From 0ae1ae54a8facd289d6e8b7f842b02f1b7f12f1a Mon Sep 17 00:00:00 2001
From: Christopher Anderson
Date: Fri, 20 Jun 2025 21:27:01 +1000
Subject: [PATCH 001/155] AMD installation instructions for RDNA3.x+Windows
---
docs/AMD-INSTALLATION.md | 146 +++++++++++++++++++++++++++++++++++++++
1 file changed, 146 insertions(+)
create mode 100644 docs/AMD-INSTALLATION.md
diff --git a/docs/AMD-INSTALLATION.md b/docs/AMD-INSTALLATION.md
new file mode 100644
index 000000000..4f05589eb
--- /dev/null
+++ b/docs/AMD-INSTALLATION.md
@@ -0,0 +1,146 @@
+# Installation Guide
+
+This guide covers installation for specific RDNA3 and RDNA3.5 AMD APUs and GPUs
+running under Windows.
+
+tl;dr: Radeon RX 7900 GOOD, RX 9700 BAD, RX 6800 BAD. (I know, life isn't fair).
+
+Currently supported (but not necessarily tested):
+
+**gfx110x**:
+
+* Radeon RX 7600
+* Radeon RX 7700 XT
+* Radeon RX 7800 XT
+* Radeon RX 7900 GRE
+* Radeon RX 7900 XT
+* Radeon RX 7900 XTX
+
+**gfx1151**:
+
+* Ryzen 7000 series APUs (Phoenix)
+* Ryzen Z1 (e.g., handheld devices like the ROG Ally)
+
+**gfx1201**:
+
+* Ryzen 8000 series APUs (Strix Point)
+* A [frame.work](https://frame.work/au/en/desktop) desktop/laptop
+
+
+## Requirements
+
+- Python 3.11 (3.12 might work; 3.10 definitely will not)
+
+## Installation Environment
+
+This installation uses PyTorch 2.7.0 because that is what is currently available as
+pre-compiled ROCm wheels for Windows.
+
+### Installing Python
+
+Download Python 3.11 from [python.org/downloads/windows](https://www.python.org/downloads/windows/). Hit Ctrl+F and search for "3.11". Don't use this direct link: [https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe](https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe) -- that was an IQ test.
+
+After installing, make sure `python --version` works in your terminal and reports 3.11.x.
+
+If not, you probably need to fix your PATH. Go to:
+
+* Windows + Pause/Break
+* Advanced System Settings
+* Environment Variables
+* Edit your `Path` under User Variables
+
+Example correct entries:
+
+```cmd
+C:\Users\YOURNAME\AppData\Local\Programs\Python\Launcher\
+C:\Users\YOURNAME\AppData\Local\Programs\Python\Python311\Scripts\
+C:\Users\YOURNAME\AppData\Local\Programs\Python\Python311\
+```
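+
+A quick way to see which interpreter the shell actually resolves (just a sanity check, assuming a standard per-user install):
+
+```cmd
+:: List every python.exe on PATH and confirm the active version
+where python
+python --version
+```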
+
+If that doesn't work, scream into a bucket.
+
+### Installing Git
+
+Get Git from [git-scm.com/downloads/win](https://git-scm.com/downloads/win). Default install is fine.
+
+
+## Install (Windows, using `venv`)
+
+### Step 1: Download and Set Up Environment
+
+```cmd
+:: Navigate to your desired install directory
+cd \your-path-to-wan2gp
+
+:: Clone the repository
+git clone https://github.com/deepbeepmeep/Wan2GP.git
+cd Wan2GP
+
+:: Create virtual environment using Python 3.11
+python -m venv wan2gp-env
+
+:: Activate the virtual environment
+wan2gp-env\Scripts\activate
+```
+
+### Step 2: Install PyTorch
+
+The pre-compiled wheels you need are hosted at [scottt's rocm-TheRock releases](https://github.com/scottt/rocm-TheRock/releases). Find the heading that says:
+
+**Pytorch wheels for gfx110x, gfx1151, and gfx1201**
+
+Don't click this link: [https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x](https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x). It's just here to check if you're skimming.
+
+Copy the links for the binaries closest to the ones in the example below (adjust them if you're not running Python 3.11), paste them into the command, and hit Enter.
+
+```cmd
+pip install ^
+ https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torch-2.7.0a0+rocm_git3f903c3-cp311-cp311-win_amd64.whl ^
+ https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchaudio-2.7.0a0+52638ef-cp311-cp311-win_amd64.whl ^
+ https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchvision-0.22.0+9eb57cd-cp311-cp311-win_amd64.whl
+```
+
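+If the wheels installed correctly, PyTorch should report a ROCm build and see your GPU. A quick optional check (exact output varies by card):
+
+```cmd
+python -c "import torch; print(torch.__version__); print(torch.cuda.is_available()); print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'no device found')"
+```
+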
+### Step 3: Install Dependencies
+
+```cmd
+:: Install core dependencies
+pip install -r requirements.txt
+```
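+
+Optionally, confirm the dependency set resolved cleanly (a quick check, not part of the original steps):
+
+```cmd
+pip check
+python -c "import gradio, torch; print('imports OK')"
+```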
+
+## Attention Modes
+
+WanGP supports several attention implementations, but on this AMD/Windows setup only one of them will work for you:
+
+- **SDPA** (default): ships with PyTorch. On these ROCm builds it is backed by the built-in aotriton acceleration library, so it is actually pretty fast (see the quick check below).
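+
+SDPA here is PyTorch's `torch.nn.functional.scaled_dot_product_attention`. A tiny smoke test that it runs on the GPU (assuming the ROCm build exposes the card under the usual `cuda` device name, which these wheels do):
+
+```cmd
+python -c "import torch, torch.nn.functional as F; q=torch.randn(1,8,64,64,device='cuda',dtype=torch.float16); print(F.scaled_dot_product_attention(q,q,q).shape)"
+```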
+
+## Performance Profiles
+
+Choose a profile based on your hardware (see the launch example below):
+
+- **Profile 3 (LowRAM_HighVRAM)**: Loads entire model in VRAM, requires 24GB VRAM for 8-bit quantized 14B model
+- **Profile 4 (LowRAM_LowVRAM)**: Default, loads model parts as needed, slower but lower VRAM requirement
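+
+If you prefer to pin the profile at launch rather than in the UI, `wgp.py` also accepts a `--profile` flag. For example, to force Profile 3 on a 24GB card:
+
+```cmd
+python wgp.py --profile 3
+```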
+
+## Running Wan2GP
+
+On future runs, you will need to do this:
+
+```cmd
+cd \your-path-to-wan2gp\Wan2GP
+wan2gp-env\Scripts\activate.bat
+python wgp.py
+```
+
+For now, you can just run `python wgp.py`, since you're already inside the virtual environment.
+
+## Troubleshooting
+
+- If you use a high-VRAM profile, don't be a fool: make sure VAE Tiled Decoding is enabled.
+
+### Memory Issues
+
+- Use lower resolution or shorter videos
+- Enable quantization (default)
+- Use Profile 4 for lower VRAM usage
+- Consider using 1.3B models instead of 14B models
+
+For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md)
From ee170f17659d2750103773c91a766595d6393cf5 Mon Sep 17 00:00:00 2001
From: Ciprian Mandache
Date: Wed, 9 Jul 2025 22:48:30 +0300
Subject: [PATCH 002/155] dockerize and add deb-based sys runner + builder
---
Dockerfile | 92 +++++++++++++++++++++
README.md | 113 +++++++++++++++++++------
entrypoint.sh | 18 ++++
run-docker-cuda-deb.sh | 183 +++++++++++++++++++++++++++++++++++++++++
4 files changed, 379 insertions(+), 27 deletions(-)
create mode 100644 Dockerfile
create mode 100755 entrypoint.sh
create mode 100755 run-docker-cuda-deb.sh
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 000000000..927c579fd
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,92 @@
+FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
+
+# Build arg for GPU architectures - specify which CUDA compute capabilities to compile for
+# Common values:
+# 7.0 - Tesla V100
+# 7.5 - RTX 2060, 2070, 2080, Titan RTX
+# 8.0 - A100, A800 (Ampere data center)
+# 8.6 - RTX 3060, 3070, 3080, 3090 (Ampere consumer)
+# 8.9 - RTX 4070, 4080, 4090 (Ada Lovelace)
+# 9.0 - H100, H800 (Hopper data center)
+# 12.0 - RTX 5070, 5080, 5090 (Blackwell) - Note: sm_120 architecture
+#
+# Examples:
+# RTX 3060: --build-arg CUDA_ARCHITECTURES="8.6"
+# RTX 4090: --build-arg CUDA_ARCHITECTURES="8.9"
+# Multiple: --build-arg CUDA_ARCHITECTURES="8.0;8.6;8.9"
+#
+# Note: Including 8.9 or 9.0 may cause compilation issues on some setups
+# Default includes 8.0 and 8.6 for broad Ampere compatibility
+ARG CUDA_ARCHITECTURES="8.0;8.6"
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install system dependencies
+RUN apt update && \
+ apt install -y \
+ python3 python3-pip git wget curl cmake ninja-build \
+ libgl1 libglib2.0-0 ffmpeg && \
+ apt clean
+
+WORKDIR /workspace
+
+COPY requirements.txt .
+
+# Upgrade pip first
+RUN pip install --upgrade pip setuptools wheel
+
+# Install Python requirements
+RUN pip install -r requirements.txt
+
+# Install PyTorch with CUDA support
+RUN pip install --extra-index-url https://download.pytorch.org/whl/cu124 \
+ torch==2.6.0+cu124 torchvision==0.21.0+cu124
+
+# Install SageAttention from git (patch GPU detection)
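+# TORCH_CUDA_ARCH_LIST / FORCE_CUDA target the extension build at the selected
+# architectures; MAX_JOBS=1 trades compile speed for a lower peak memory footprint.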
+ENV TORCH_CUDA_ARCH_LIST="${CUDA_ARCHITECTURES}"
+ENV FORCE_CUDA="1"
+ENV MAX_JOBS="1"
+
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
WanGP by DeepBeepMeep : The best Open Source Video Generative Models Accessible to the GPU Poor
WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models with:
+
- Low VRAM requirements (as low as 6 GB of VRAM is sufficient for certain models)
- Support for old GPUs (RTX 10XX, 20xx, ...)
- Very Fast on the latest GPUs
@@ -13,40 +15,46 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
- Auto download of the required model adapted to your specific architecture
- Tools integrated to facilitate Video Generation : Mask Editor, Prompt Enhancer, Temporal and Spatial Generation, MMAudio, Video Browser, Pose / Depth / Flow extractor
- Loras Support to customize each model
-- Queuing system : make your shopping list of videos to generate and come back later
+- Queuing system : make your shopping list of videos to generate and come back later
**Discord Server to get Help from Other Users and show your Best Videos:** https://discord.gg/g7efUW9jGV
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates
+
### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** :
-**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation).
-Of course you will get as well *Multitalk* vanilla and also *Multitalk 720p* as a bonus.
+**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to _Sliding Windows_ support and _Adaptive Projected Guidance_ (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation).
+
+Of course you will get as well _Multitalk_ vanilla and also _Multitalk 720p_ as a bonus.
-And since I am mister nice guy I have enclosed as an exclusivity an *Audio Separator* that will save you time to isolate each voice when using Multitalk with two people.
+And since I am mister nice guy I have enclosed as an exclusivity an _Audio Separator_ that will save you time to isolate each voice when using Multitalk with two people.
-As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the *Share Your Best Video* channel your *Master Pieces*. The best ones will be added to the *Announcements Channel* and will bring eternal fame to its authors.
+As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the _Share Your Best Video_ channel your _Master Pieces_. The best ones will be added to the _Announcements Channel_ and will bring eternal fame to its authors.
But wait, there is more:
+
- Sliding Windows support has been added anywhere with Wan models, so imagine with text2video recently upgraded in 6.5 into a video2video, you can now upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you)
- I have added also the capability to transfer the audio of the original control video (Misc. advanced tab) and an option to preserve the fps into the generated video, so from now on you will be to upsample / restore your old families video and keep the audio at their original pace. Be aware that the duration will be limited to 1000 frames as I still need to add streaming support for unlimited video sizes.
Also, of interest too:
+
- Extract video info from Videos that have not been generated by WanGP, even better you can also apply post processing (Upsampling / MMAudio) on non WanGP videos
- Force the generated video fps to your liking, works wery well with Vace when using a Control Video
- Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time)
### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features:
+
- View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations
-- In one click use the newly generated video as a Control Video or Source Video to be continued
-- Manage multiple settings for the same model and switch between them using a dropdown box
+- In one click use the newly generated video as a Control Video or Source Video to be continued
+- Manage multiple settings for the same model and switch between them using a dropdown box
- WanGP will keep the last generated videos in the Gallery and will remember the last model you used if you restart the app but kept the Web page open
- Custom resolutions : add a file in the WanGP folder with the list of resolutions you want to see in WanGP (look at the instruction readme in this folder)
Taking care of your life is not enough, you want new stuff to play with ?
-- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the *Extensions* tab of the WanGP *Configuration* to enable MMAudio
+
+- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the _Extensions_ tab of the WanGP _Configuration_ to enable MMAudio
- Forgot to upsample your video during the generation ? want to try another MMAudio variation ? Fear not you can also apply upsampling or add an MMAudio track once the video generation is done. Even better you can ask WangGP for multiple variations of MMAudio to pick the one you like best
- MagCache support: a new step skipping approach, supposed to be better than TeaCache. Makes a difference if you usually generate with a high number of steps
- SageAttention2++ support : not just the compatibility but also a slightly reduced VRAM usage
@@ -58,24 +66,29 @@ Taking care of your life is not enough, you want new stuff to play with ?
**If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one**
### June 23 2025: WanGP v6.3, Vace Unleashed. Thought we couldnt squeeze Vace even more ?
+
- Multithreaded preprocessing when possible for faster generations
- Multithreaded frames Lanczos Upsampling as a bonus
-- A new Vace preprocessor : *Flow* to extract fluid motion
-- Multi Vace Controlnets: you can now transfer several properties at the same time. This opens new possibilities to explore, for instance if you transfer *Human Movement* and *Shapes* at the same time for some reasons the lighting of your character will take into account much more the environment of your character.
+- A new Vace preprocessor : _Flow_ to extract fluid motion
+- Multi Vace Controlnets: you can now transfer several properties at the same time. This opens new possibilities to explore, for instance if you transfer _Human Movement_ and _Shapes_ at the same time for some reasons the lighting of your character will take into account much more the environment of your character.
- Injected Frames Outpainting, in case you missed it in WanGP 6.21
Don't know how to use all of the Vace features ? Check the Vace Guide embedded in WanGP as it has also been updated.
-
### June 19 2025: WanGP v6.2, Vace even more Powercharged
-👋 Have I told you that I am a big fan of Vace ? Here are more goodies to unleash its power:
-- If you ever wanted to watch Star Wars in 4:3, just use the new *Outpainting* feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is *Outpainting* can be combined with all the other Vace modifications, for instance you can change the main character of your favorite movie at the same time
-- More processing can combined at the same time (for instance the depth process can be applied outside the mask)
+
+👋 Have I told you that I am a big fan of Vace ? Here are more goodies to unleash its power:
+
+- If you ever wanted to watch Star Wars in 4:3, just use the new _Outpainting_ feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is _Outpainting_ can be combined with all the other Vace modifications, for instance you can change the main character of your favorite movie at the same time
+- More processing can combined at the same time (for instance the depth process can be applied outside the mask)
- Upgraded the depth extractor to Depth Anything 2 which is much more detailed
-As a bonus, I have added two finetunes based on the Safe-Forcing technology (which requires only 4 steps to generate a video): Wan 2.1 text2video Self-Forcing and Vace Self-Forcing. I know there is Lora around but the quality of the Lora is worse (at least with Vace) compared to the full model. Don't hesitate to share your opinion about this on the discord server.
+As a bonus, I have added two finetunes based on the Safe-Forcing technology (which requires only 4 steps to generate a video): Wan 2.1 text2video Self-Forcing and Vace Self-Forcing. I know there is Lora around but the quality of the Lora is worse (at least with Vace) compared to the full model. Don't hesitate to share your opinion about this on the discord server.
+
### June 17 2025: WanGP v6.1, Vace Powercharged
+
👋 Lots of improvements for Vace the Mother of all Models:
+
- masks can now be combined with on the fly processing of a control video, for instance you can extract the motion of a specific person defined by a mask
- on the fly modification of masks : reversed masks (with the same mask you can modify the background instead of the people covered by the masks), enlarged masks (you can cover more area if for instance the person you are trying to inject is larger than the one in the mask), ...
- view these modified masks directly inside WanGP during the video generation to check they are really as expected
@@ -87,52 +100,62 @@ Of course all these new stuff work on all Vace finetunes (including Vace Fusioni
Thanks also to Reevoy24 for adding a Notfication sound at the end of a generation and for fixing the background color of the current generation summary.
### June 12 2025: WanGP v6.0
-👋 *Finetune models*: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add by yourself the support for this model in WanGP by just creating a finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them.
+
+👋 _Finetune models_: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add by yourself the support for this model in WanGP by just creating a finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them.
To celebrate the new finetunes support, here are a few finetune gifts (directly accessible from the model selection menu):
-- *Fast Hunyuan Video* : generate model t2v in only 6 steps
-- *Hunyuan Vido AccVideo* : generate model t2v in only 5 steps
-- *Wan FusioniX*: it is a combo of AccVideo / CausVid ans other models and can generate high quality Wan videos in only 8 steps
+
+- _Fast Hunyuan Video_ : generate model t2v in only 6 steps
+- _Hunyuan Vido AccVideo_ : generate model t2v in only 5 steps
+- _Wan FusioniX_: it is a combo of AccVideo / CausVid ans other models and can generate high quality Wan videos in only 8 steps
One more thing...
-The new finetune system can be used to combine complementaty models : what happens when you combine Fusionix Text2Video and Vace Control Net ?
+The new finetune system can be used to combine complementaty models : what happens when you combine Fusionix Text2Video and Vace Control Net ?
You get **Vace FusioniX**: the Ultimate Vace Model, Fast (10 steps, no need for guidance) and with a much better quality Video than the original slower model (despite being the best Control Net out there). Here goes one more finetune...
-Check the *Finetune Guide* to create finetune models definitions and share them on the WanGP discord server.
+Check the _Finetune Guide_ to create finetune models definitions and share them on the WanGP discord server.
### June 11 2025: WanGP v5.5
-👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\
-*Hunyuan Video Custom Edit*: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping his poses. Similar to Vace but less restricted than the Wan models in terms of content...
+👋 _Hunyuan Video Custom Audio_: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\
+_Hunyuan Video Custom Edit_: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping his poses. Similar to Vace but less restricted than the Wan models in terms of content...
### June 6 2025: WanGP v5.41
+
👋 Bonus release: Support for **AccVideo** Lora to speed up x2 Video generations in Wan models. Check the Loras documentation to get the usage instructions of AccVideo.\
-You will need to do a *pip install -r requirements.txt*
+You will need to do a _pip install -r requirements.txt_
### June 6 2025: WanGP v5.4
+
👋 World Exclusive : **Hunyuan Video Avatar** Support ! You won't need 80 GB of VRAM nor 32 GB oF VRAM, just 10 GB of VRAM will be sufficient to generate up to 15s of high quality speech / song driven Video at a high speed with no quality degradation. Support for TeaCache included.\
Here is a link to the original repo where you will find some very interesting documentation and examples. https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar. Kudos to the Hunyuan Video Avatar team for the best model of its kind.\
Also many thanks to Reevoy24 for his repackaging / completing the documentation
### May 28 2025: WanGP v5.31
+
👋 Added **Phantom 14B**, a model that you can use to transfer objects / people in the video. My preference goes to Vace that remains the king of controlnets.
VACE improvements: Better sliding window transitions, image mask support in Matanyone, new Extend Video feature, and enhanced background removal options.
### May 26, 2025: WanGP v5.3
+
👋 Settings management revolution! Now you can:
-- Select any generated video and click *Use Selected Video Settings* to instantly reuse its configuration
+
+- Select any generated video and click _Use Selected Video Settings_ to instantly reuse its configuration
- Drag & drop videos to automatically extract their settings metadata
- Export/import settings as JSON files for easy sharing and backup
### May 20, 2025: WanGP v5.2
+
👋 **CausVid support** - Generate videos in just 4-12 steps with the new distilled Wan model! Also added experimental MoviiGen for 1080p generation (20GB+ VRAM required). Check the Loras documentation to get the usage instructions of CausVid.
### May 18, 2025: WanGP v5.1
+
👋 **LTX Video 13B Distilled** - Generate high-quality videos in less than one minute!
### May 17, 2025: WanGP v5.0
+
👋 **One App to Rule Them All!** Added Hunyuan Video and LTX Video support, plus Vace 14B and integrated prompt enhancer.
See full changelog: **[Changelog](docs/CHANGELOG.md)**
@@ -150,6 +173,7 @@ See full changelog: **[Changelog](docs/CHANGELOG.md)**
**One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/)
**Manual installation:**
+
```bash
git clone https://github.com/deepbeepmeep/Wan2GP.git
cd Wan2GP
@@ -160,6 +184,7 @@ pip install -r requirements.txt
```
**Run the application:**
+
```bash
python wgp.py # Text-to-video (default)
python wgp.py --i2v # Image-to-video
@@ -168,6 +193,7 @@ python wgp.py --i2v # Image-to-video
**Update the application:**
If using Pinokio use Pinokio to update otherwise:
Get in the directory where WanGP is installed and:
+
```bash
git pull
pip install -r requirements.txt
@@ -175,16 +201,48 @@ pip install -r requirements.txt
## 📦 Installation
+### 🐳 Docker Installation
+
+**For Debian-based systems (Ubuntu, Debian, etc.):**
+
+```bash
+./run-docker-cuda-deb.sh
+```
+
+This automated script will:
+
+- Detect your GPU model and VRAM automatically
+- Select optimal CUDA architecture for your GPU
+- Install NVIDIA Docker runtime if needed
+- Build a Docker image with all dependencies
+- Run WanGP with optimal settings for your hardware
+
+**Docker environment includes:**
+
+- NVIDIA CUDA 12.4.1 with cuDNN support
+- PyTorch 2.6.0 with CUDA 12.4 support
+- SageAttention compiled for your specific GPU architecture
+- Optimized environment variables for performance (TF32, threading, etc.)
+- Automatic cache directory mounting for faster subsequent runs
+- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally
+
+**Supported GPUs:** RTX 50XX, RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more.
+
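+If you prefer to run the steps by hand, the script boils down to roughly this (a condensed sketch; pick the CUDA architecture and profile that match your GPU):
+
+```bash
+# Build the image for a single compute capability (e.g. 8.6 for an RTX 3090)
+docker build --build-arg CUDA_ARCHITECTURES="8.6" -t deepbeepmeep/wan2gp .
+
+# Run it with GPU access, the web UI on port 7860, and the repo mounted as the workdir
+docker run --rm -it --gpus all --runtime=nvidia -p 7860:7860 \
+  -v "$(pwd):/workspace" deepbeepmeep/wan2gp --profile 2 --attention sage
+```
+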
+### Manual Installation
+
For detailed installation instructions for different GPU generations:
+
- **[Installation Guide](docs/INSTALLATION.md)** - Complete setup instructions for RTX 10XX to RTX 50XX
## 🎯 Usage
### Basic Usage
+
- **[Getting Started Guide](docs/GETTING_STARTED.md)** - First steps and basic usage
- **[Models Overview](docs/MODELS.md)** - Available models and their capabilities
### Advanced Features
+
- **[Loras Guide](docs/LORAS.md)** - Using and managing Loras for customization
- **[Finetunes](docs/FINETUNES.md)** - Add manually new models to WanGP
- **[VACE ControlNet](docs/VACE.md)** - Advanced video control and manipulation
@@ -198,6 +256,7 @@ For detailed installation instructions for different GPU generations:
## 🔗 Related Projects
### Other Models for the GPU Poor
+
- **[HuanyuanVideoGP](https://github.com/deepbeepmeep/HunyuanVideoGP)** - One of the best open source Text to Video generators
- **[Hunyuan3D-2GP](https://github.com/deepbeepmeep/Hunyuan3D-2GP)** - Image to 3D and text to 3D tool
- **[FluxFillGP](https://github.com/deepbeepmeep/FluxFillGP)** - Inpainting/outpainting tools based on Flux
@@ -209,4 +268,4 @@ For detailed installation instructions for different GPU generations:
Made with ❤️ by DeepBeepMeep
-
+
diff --git a/entrypoint.sh b/entrypoint.sh
new file mode 100755
index 000000000..4c363abd1
--- /dev/null
+++ b/entrypoint.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+export HOME=/home/user
+export PYTHONUNBUFFERED=1
+export HF_HOME=/home/user/.cache/huggingface
+
+export OMP_NUM_THREADS=$(nproc)
+export MKL_NUM_THREADS=$(nproc)
+export OPENBLAS_NUM_THREADS=$(nproc)
+export NUMEXPR_NUM_THREADS=$(nproc)
+
+export TORCH_ALLOW_TF32_CUBLAS=1
+export TORCH_ALLOW_TF32_CUDNN=1
+
+# Disable audio warnings in Docker
+export SDL_AUDIODRIVER=dummy
+export PULSE_RUNTIME_PATH=/tmp/pulse-runtime
+
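+# Any extra arguments given to `docker run ... deepbeepmeep/wan2gp <args>` arrive here as "$*"
+# and are forwarded to wgp.py (e.g. --profile, --attention, --compile from run-docker-cuda-deb.sh).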
+exec su -p user -c "python3 wgp.py --listen $*"
diff --git a/run-docker-cuda-deb.sh b/run-docker-cuda-deb.sh
new file mode 100755
index 000000000..1a6201aa4
--- /dev/null
+++ b/run-docker-cuda-deb.sh
@@ -0,0 +1,183 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# ───────────────────────── helpers ─────────────────────────
+
+install_nvidia_smi_if_missing() {
+ if command -v nvidia-smi &>/dev/null; then
+ return
+ fi
+
+ echo "⚠️ nvidia-smi not found. Installing nvidia-utils…"
+ if [ "$EUID" -ne 0 ]; then
+ SUDO='sudo'
+ else
+ SUDO=''
+ fi
+
+ $SUDO apt-get update
+ $SUDO apt-get install -y nvidia-utils-535 || $SUDO apt-get install -y nvidia-utils
+
+ if ! command -v nvidia-smi &>/dev/null; then
+ echo "❌ Failed to install nvidia-smi. Cannot detect GPU architecture."
+ exit 1
+ fi
+ echo "✅ nvidia-smi installed successfully."
+}
+
+detect_gpu_name() {
+ install_nvidia_smi_if_missing
+ nvidia-smi --query-gpu=name --format=csv,noheader,nounits | head -1
+}
+
+map_gpu_to_arch() {
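+    # Maps the nvidia-smi GPU name to a CUDA compute capability,
+    # e.g. "NVIDIA GeForce RTX 3090" -> 8.6, "NVIDIA H100 PCIe" -> 9.0.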
+ local name="$1"
+ case "$name" in
+ *"RTX 50"* | *"5090"* | *"5080"* | *"5070"*) echo "12.0" ;;
+ *"H100"* | *"H800"*) echo "9.0" ;;
+ *"RTX 40"* | *"4090"* | *"4080"* | *"4070"* | *"4060"*) echo "8.9" ;;
+ *"RTX 30"* | *"3090"* | *"3080"* | *"3070"* | *"3060"*) echo "8.6" ;;
+ *"A100"* | *"A800"* | *"A40"*) echo "8.0" ;;
+ *"Tesla V100"*) echo "7.0" ;;
+ *"RTX 20"* | *"2080"* | *"2070"* | *"2060"* | *"Titan RTX"*) echo "7.5" ;;
+ *"GTX 16"* | *"1660"* | *"1650"*) echo "7.5" ;;
+ *"GTX 10"* | *"1080"* | *"1070"* | *"1060"* | *"Tesla P100"*) echo "6.1" ;;
+ *"Tesla K80"* | *"Tesla K40"*) echo "3.7" ;;
+ *)
+ echo "❌ Unknown GPU model: $name"
+ echo "Please update the map_gpu_to_arch function for this model."
+ exit 1
+ ;;
+ esac
+}
+
+get_gpu_vram() {
+ install_nvidia_smi_if_missing
+ # Get VRAM in MB, convert to GB
+ local vram_mb=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | head -1)
+ echo $((vram_mb / 1024))
+}
+
+map_gpu_to_profile() {
+ local name="$1"
+ local vram_gb="$2"
+
+ # WanGP Profile descriptions from the actual UI:
+ # Profile 1: HighRAM_HighVRAM - 48GB+ RAM, 24GB+ VRAM (fastest for short videos, RTX 3090/4090)
+ # Profile 2: HighRAM_LowVRAM - 48GB+ RAM, 12GB+ VRAM (recommended, most versatile)
+ # Profile 3: LowRAM_HighVRAM - 32GB+ RAM, 24GB+ VRAM (RTX 3090/4090 with limited RAM)
+ # Profile 4: LowRAM_LowVRAM - 32GB+ RAM, 12GB+ VRAM (default, little VRAM or longer videos)
+ # Profile 5: VerylowRAM_LowVRAM - 16GB+ RAM, 10GB+ VRAM (fail safe, slow but works)
+
+ case "$name" in
+ # High-end data center GPUs with 24GB+ VRAM - Profile 1 (HighRAM_HighVRAM)
+ *"RTX 50"* | *"5090"* | *"A100"* | *"A800"* | *"H100"* | *"H800"*)
+ if [ "$vram_gb" -ge 24 ]; then
+ echo "1" # HighRAM_HighVRAM - fastest for short videos
+ else
+ echo "2" # HighRAM_LowVRAM - most versatile
+ fi
+ ;;
+ # High-end consumer GPUs (RTX 3090/4090) - Profile 1 or 3
+ *"RTX 40"* | *"4090"* | *"RTX 30"* | *"3090"*)
+ if [ "$vram_gb" -ge 24 ]; then
+ echo "3" # LowRAM_HighVRAM - good for limited RAM systems
+ else
+ echo "2" # HighRAM_LowVRAM - most versatile
+ fi
+ ;;
+ # Mid-range GPUs (RTX 3070/3080/4070/4080) - Profile 2 recommended
+ *"4080"* | *"4070"* | *"3080"* | *"3070"* | *"RTX 20"* | *"2080"* | *"2070"*)
+ if [ "$vram_gb" -ge 12 ]; then
+ echo "2" # HighRAM_LowVRAM - recommended for these GPUs
+ else
+ echo "4" # LowRAM_LowVRAM - default for little VRAM
+ fi
+ ;;
+ # Lower-end GPUs with 6-12GB VRAM - Profile 4 or 5
+ *"4060"* | *"3060"* | *"2060"* | *"GTX 16"* | *"1660"* | *"1650"*)
+ if [ "$vram_gb" -ge 10 ]; then
+ echo "4" # LowRAM_LowVRAM - default
+ else
+ echo "5" # VerylowRAM_LowVRAM - fail safe
+ fi
+ ;;
+ # Older/lower VRAM GPUs - Profile 5 (fail safe)
+ *"GTX 10"* | *"1080"* | *"1070"* | *"1060"* | *"Tesla"*)
+ echo "5" # VerylowRAM_LowVRAM - fail safe
+ ;;
+ *)
+ echo "4" # LowRAM_LowVRAM - default fallback
+ ;;
+ esac
+}
+
+# ───────────────────────── main ────────────────────────────
+
+GPU_NAME=$(detect_gpu_name)
+echo "🔍 Detected GPU: $GPU_NAME"
+
+VRAM_GB=$(get_gpu_vram)
+echo "🧠 Detected VRAM: ${VRAM_GB}GB"
+
+CUDA_ARCH=$(map_gpu_to_arch "$GPU_NAME")
+echo "🚀 Using CUDA architecture: $CUDA_ARCH"
+
+PROFILE=$(map_gpu_to_profile "$GPU_NAME" "$VRAM_GB")
+echo "⚙️ Selected profile: $PROFILE"
+
+docker build --build-arg CUDA_ARCHITECTURES="$CUDA_ARCH" -t deepbeepmeep/wan2gp .
+
+# sudo helper for later commands
+if [ "$EUID" -ne 0 ]; then
+ SUDO='sudo'
+else
+ SUDO=''
+fi
+
+# Ensure NVIDIA runtime is available
+if ! docker info 2>/dev/null | grep -q 'Runtimes:.*nvidia'; then
+ echo "⚠️ NVIDIA Docker runtime not found. Installing nvidia-docker2…"
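+    # nvidia-docker2 is the legacy packaging; on newer distros the same runtime ships
+    # as nvidia-container-toolkit, so adjust this block if the old repo is unavailable.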
+ $SUDO apt-get update
+ $SUDO apt-get install -y curl ca-certificates gnupg
+ curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | $SUDO apt-key add -
+ distribution=$(
+ . /etc/os-release
+ echo $ID$VERSION_ID
+ )
+ curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list |
+ $SUDO tee /etc/apt/sources.list.d/nvidia-docker.list
+ $SUDO apt-get update
+ $SUDO apt-get install -y nvidia-docker2
+ echo "🔄 Restarting Docker service…"
+ $SUDO systemctl restart docker
+ echo "✅ NVIDIA Docker runtime installed."
+else
+ echo "✅ NVIDIA Docker runtime found."
+fi
+
+# Prepare cache dirs & volume mounts
+cache_dirs=(numba matplotlib huggingface torch)
+cache_mounts=()
+for d in "${cache_dirs[@]}"; do
+ mkdir -p "$HOME/.cache/$d"
+ chmod 700 "$HOME/.cache/$d"
+ cache_mounts+=(-v "$HOME/.cache/$d:/home/user/.cache/$d")
+done
+
+echo "🔧 Optimization settings:"
+echo " Profile: $PROFILE"
+
+# Run the container
+docker run --rm -it \
+ --name wan2gp \
+ --gpus all \
+ --runtime=nvidia \
+ -p 7860:7860 \
+ -v "$(pwd):/workspace" \
+ "${cache_mounts[@]}" \
+ deepbeepmeep/wan2gp \
+ --profile "$PROFILE" \
+ --attention sage \
+ --compile \
+ --perc-reserved-mem-max 1
From 3307defa7c005e8566170d0784254aa9829a1759 Mon Sep 17 00:00:00 2001
From: Ciprian Mandache
Date: Sat, 30 Aug 2025 06:51:17 +0300
Subject: [PATCH 003/155] add more checks + logging + upgrade peft
---
entrypoint.sh | 100 +++++++++++++++++++++++++++++++++++++++++
requirements.txt | 8 ++--
run-docker-cuda-deb.sh | 27 +++++++++++
3 files changed, 131 insertions(+), 4 deletions(-)
diff --git a/entrypoint.sh b/entrypoint.sh
index 4c363abd1..9af052d07 100755
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -15,4 +15,104 @@ export TORCH_ALLOW_TF32_CUDNN=1
export SDL_AUDIODRIVER=dummy
export PULSE_RUNTIME_PATH=/tmp/pulse-runtime
+# ═══════════════════════════ CUDA DEBUG CHECKS ═══════════════════════════
+
+echo "🔍 CUDA Environment Debug Information:"
+echo "═══════════════════════════════════════════════════════════════════════"
+
+# Check CUDA driver on host (if accessible)
+if command -v nvidia-smi >/dev/null 2>&1; then
+ echo "✅ nvidia-smi available"
+ echo "📊 GPU Information:"
+ nvidia-smi --query-gpu=name,driver_version,memory.total,memory.free --format=csv,noheader,nounits 2>/dev/null || echo "❌ nvidia-smi failed to query GPU"
+ echo "🏃 Running Processes:"
+ nvidia-smi --query-compute-apps=pid,name,used_memory --format=csv,noheader,nounits 2>/dev/null || echo "ℹ️ No running CUDA processes"
+else
+ echo "❌ nvidia-smi not available in container"
+fi
+
+# Check CUDA runtime libraries
+echo ""
+echo "🔧 CUDA Runtime Check:"
+if ls /usr/local/cuda*/lib*/libcudart.so* >/dev/null 2>&1; then
+ echo "✅ CUDA runtime libraries found:"
+ ls /usr/local/cuda*/lib*/libcudart.so* 2>/dev/null
+else
+ echo "❌ CUDA runtime libraries not found"
+fi
+
+# Check CUDA devices
+echo ""
+echo "🖥️ CUDA Device Files:"
+if ls /dev/nvidia* >/dev/null 2>&1; then
+ echo "✅ NVIDIA device files found:"
+ ls -la /dev/nvidia* 2>/dev/null
+else
+ echo "❌ No NVIDIA device files found - Docker may not have GPU access"
+fi
+
+# Check CUDA environment variables
+echo ""
+echo "🌍 CUDA Environment Variables:"
+echo " CUDA_HOME: ${CUDA_HOME:-not set}"
+echo " CUDA_ROOT: ${CUDA_ROOT:-not set}"
+echo " CUDA_PATH: ${CUDA_PATH:-not set}"
+echo " LD_LIBRARY_PATH: ${LD_LIBRARY_PATH:-not set}"
+echo " TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-not set}"
+echo " CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-not set}"
+
+# Check PyTorch CUDA availability
+echo ""
+echo "🐍 PyTorch CUDA Check:"
+python3 -c "
+import sys
+try:
+ import torch
+ print('✅ PyTorch imported successfully')
+ print(f' Version: {torch.__version__}')
+ print(f' CUDA available: {torch.cuda.is_available()}')
+ if torch.cuda.is_available():
+ print(f' CUDA version: {torch.version.cuda}')
+ print(f' cuDNN version: {torch.backends.cudnn.version()}')
+ print(f' Device count: {torch.cuda.device_count()}')
+ for i in range(torch.cuda.device_count()):
+ props = torch.cuda.get_device_properties(i)
+ print(f' Device {i}: {props.name} (SM {props.major}.{props.minor}, {props.total_memory//1024//1024}MB)')
+ else:
+ print('❌ CUDA not available to PyTorch')
+ print(' This could mean:')
+ print(' - CUDA runtime not properly installed')
+ print(' - GPU not accessible to container')
+ print(' - Driver/runtime version mismatch')
+except ImportError as e:
+ print(f'❌ Failed to import PyTorch: {e}')
+except Exception as e:
+ print(f'❌ PyTorch CUDA check failed: {e}')
+" 2>&1
+
+# Check for common CUDA issues
+echo ""
+echo "🩺 Common Issue Diagnostics:"
+
+# Check if running with proper Docker flags
+if [ ! -e /dev/nvidia0 ] && [ ! -e /dev/nvidiactl ]; then
+ echo "❌ No NVIDIA device nodes - container likely missing --gpus all or --runtime=nvidia"
+fi
+
+# Check CUDA library paths
+if [ -z "$LD_LIBRARY_PATH" ] || ! echo "$LD_LIBRARY_PATH" | grep -q cuda; then
+ echo "⚠️ LD_LIBRARY_PATH may not include CUDA libraries"
+fi
+
+# Check permissions on device files
+if ls /dev/nvidia* >/dev/null 2>&1; then
+ if ! ls -la /dev/nvidia* | grep -q "rw-rw-rw-\|rw-r--r--"; then
+ echo "⚠️ NVIDIA device files may have restrictive permissions"
+ fi
+fi
+
+echo "═══════════════════════════════════════════════════════════════════════"
+echo "🚀 Starting application..."
+echo ""
+
exec su -p user -c "python3 wgp.py --listen $*"
diff --git a/requirements.txt b/requirements.txt
index ea0c582f6..312dc9400 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,13 +12,13 @@ easydict
ftfy
dashscope
imageio-ffmpeg
-# flash_attn
-gradio==5.23.0
+# flash_attn
+gradio==5.23.0
numpy>=1.23.5,<2
einops
moviepy==1.0.3
mmgp==3.5.1
-peft==0.15.0
+peft==0.17.0
mutagen
pydantic==2.10.6
decord
@@ -46,4 +46,4 @@ soundfile
ffmpeg-python
pyannote.audio
# num2words
-# spacy
\ No newline at end of file
+# spacy
diff --git a/run-docker-cuda-deb.sh b/run-docker-cuda-deb.sh
index 1a6201aa4..b35e9cc6b 100755
--- a/run-docker-cuda-deb.sh
+++ b/run-docker-cuda-deb.sh
@@ -114,6 +114,25 @@ map_gpu_to_profile() {
# ───────────────────────── main ────────────────────────────
+echo "🔧 NVIDIA CUDA Setup Check:"
+
+# NVIDIA driver check
+if command -v nvidia-smi &>/dev/null; then
+ DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader,nounits | head -1)
+ echo "✅ NVIDIA Driver: $DRIVER_VERSION"
+
+ # Quick CUDA 12.4 compatibility check
+ if [[ "$DRIVER_VERSION" =~ ^([0-9]+) ]]; then
+ MAJOR=${BASH_REMATCH[1]}
+    if [ "$MAJOR" -lt 525 ]; then
+      echo "⚠️ Driver $DRIVER_VERSION may not support CUDA 12.4 (need 525+)"
+ fi
+ fi
+else
+ echo "❌ nvidia-smi not found - no NVIDIA drivers"
+ exit 1
+fi
+
GPU_NAME=$(detect_gpu_name)
echo "🔍 Detected GPU: $GPU_NAME"
@@ -156,6 +175,14 @@ else
echo "✅ NVIDIA Docker runtime found."
fi
+# Quick NVIDIA runtime test
+echo "🧪 Testing NVIDIA runtime..."
+if timeout 15s docker run --rm --gpus all --runtime=nvidia nvidia/cuda:12.4.1-runtime-ubuntu22.04 nvidia-smi >/dev/null 2>&1; then
+ echo "✅ NVIDIA runtime working"
+else
+ echo "❌ NVIDIA runtime test failed - check driver/runtime compatibility"
+fi
+
# Prepare cache dirs & volume mounts
cache_dirs=(numba matplotlib huggingface torch)
cache_mounts=()
From 808ef51688dfd30288c3a99f369448d0e76b9a1f Mon Sep 17 00:00:00 2001
From: Ciprian Mandache
Date: Sat, 30 Aug 2025 07:02:35 +0300
Subject: [PATCH 004/155] fix readme
---
README.md | 133 ++++++++++++++++++++----------------------------------
1 file changed, 49 insertions(+), 84 deletions(-)
diff --git a/README.md b/README.md
index 50c03ee2e..fb04ce13d 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,11 @@
# WanGP
----
-
+-----
WanGP by DeepBeepMeep : The best Open Source Video Generative Models Accessible to the GPU Poor
WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models with:
-
- Low VRAM requirements (as low as 6 GB of VRAM is sufficient for certain models)
- Support for old GPUs (RTX 10XX, 20xx, ...)
- Very Fast on the latest GPUs
@@ -45,39 +43,31 @@ Also in the news:
### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me
Maybe you knew that already but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that it is *CFG* is set to 1). This helps to get much faster generations but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase.
-### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me
-
-Maybe you knew that already but most _Loras accelerators_ we use today (Causvid, FusioniX) don't use _Guidance_ at all (that it is _CFG_ is set to 1). This helps to get much faster generations but the downside is that _Negative Prompts_ are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the _Negative Prompt_ during the _attention_ processing phase.
-
-So WanGP 6.7 gives you NAG, but not any NAG, a _Low VRAM_ implementation, the default one ends being VRAM greedy. You will find NAG in the _General_ advanced tab for most Wan models.
+So WanGP 6.7 gives you NAG, but not any NAG, a *Low VRAM* implementation, the default one ends being VRAM greedy. You will find NAG in the *General* advanced tab for most Wan models.
Use NAG especially when Guidance is set to 1. To turn it on set the **NAG scale** to something around 10. There are other NAG parameters **NAG tau** and **NAG alpha** which I recommend to change only if you don't get good results by just playing with the NAG scale. Don't hesitate to share on this discord server the best combinations for these 3 parameters.
The authors of NAG claim that NAG can also be used when using a Guidance (CFG > 1) and to improve the prompt adherence.
### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** :
+**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation).
-**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to _Sliding Windows_ support and _Adaptive Projected Guidance_ (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation).
+Of course you will get as well *Multitalk* vanilla and also *Multitalk 720p* as a bonus.
-Of course you will get as well _Multitalk_ vanilla and also _Multitalk 720p_ as a bonus.
+And since I am mister nice guy I have enclosed as an exclusivity an *Audio Separator* that will save you time to isolate each voice when using Multitalk with two people.
-And since I am mister nice guy I have enclosed as an exclusivity an _Audio Separator_ that will save you time to isolate each voice when using Multitalk with two people.
-
-As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the _Share Your Best Video_ channel your _Master Pieces_. The best ones will be added to the _Announcements Channel_ and will bring eternal fame to its authors.
+As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the *Share Your Best Video* channel your *Master Pieces*. The best ones will be added to the *Announcements Channel* and will bring eternal fame to its authors.
But wait, there is more:
-
- Sliding Windows support has been added anywhere with Wan models, so imagine with text2video recently upgraded in 6.5 into a video2video, you can now upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you)
- I have added also the capability to transfer the audio of the original control video (Misc. advanced tab) and an option to preserve the fps into the generated video, so from now on you will be to upsample / restore your old families video and keep the audio at their original pace. Be aware that the duration will be limited to 1000 frames as I still need to add streaming support for unlimited video sizes.
Also, of interest too:
-
- Extract video info from Videos that have not been generated by WanGP, even better you can also apply post processing (Upsampling / MMAudio) on non WanGP videos
- Force the generated video fps to your liking, works wery well with Vace when using a Control Video
- Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time)
### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features:
-
- View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations
- In one click use the newly generated video as a Control Video or Source Video to be continued
- Manage multiple settings for the same model and switch between them using a dropdown box
@@ -85,8 +75,7 @@ Also, of interest too:
- Custom resolutions : add a file in the WanGP folder with the list of resolutions you want to see in WanGP (look at the instruction readme in this folder)
Taking care of your life is not enough, you want new stuff to play with ?
-
-- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the _Extensions_ tab of the WanGP _Configuration_ to enable MMAudio
+- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the *Extensions* tab of the WanGP *Configuration* to enable MMAudio
- Forgot to upsample your video during the generation ? want to try another MMAudio variation ? Fear not you can also apply upsampling or add an MMAudio track once the video generation is done. Even better you can ask WangGP for multiple variations of MMAudio to pick the one you like best
- MagCache support: a new step skipping approach, supposed to be better than TeaCache. Makes a difference if you usually generate with a high number of steps
- SageAttention2++ support : not just the compatibility but also a slightly reduced VRAM usage
@@ -98,29 +87,24 @@ Taking care of your life is not enough, you want new stuff to play with ?
**If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one**
### June 23 2025: WanGP v6.3, Vace Unleashed. Thought we couldnt squeeze Vace even more ?
-
- Multithreaded preprocessing when possible for faster generations
- Multithreaded frames Lanczos Upsampling as a bonus
-- A new Vace preprocessor : _Flow_ to extract fluid motion
-- Multi Vace Controlnets: you can now transfer several properties at the same time. This opens new possibilities to explore, for instance if you transfer _Human Movement_ and _Shapes_ at the same time for some reasons the lighting of your character will take into account much more the environment of your character.
+- A new Vace preprocessor : *Flow* to extract fluid motion
+- Multi Vace Controlnets: you can now transfer several properties at the same time. This opens new possibilities to explore, for instance if you transfer *Human Movement* and *Shapes* at the same time for some reasons the lighting of your character will take into account much more the environment of your character.
- Injected Frames Outpainting, in case you missed it in WanGP 6.21
Don't know how to use all of the Vace features ? Check the Vace Guide embedded in WanGP as it has also been updated.
-### June 19 2025: WanGP v6.2, Vace even more Powercharged
+### June 19 2025: WanGP v6.2, Vace even more Powercharged
👋 Have I told you that I am a big fan of Vace ? Here are more goodies to unleash its power:
-
-- If you ever wanted to watch Star Wars in 4:3, just use the new _Outpainting_ feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is _Outpainting_ can be combined with all the other Vace modifications, for instance you can change the main character of your favorite movie at the same time
-- More processing can combined at the same time (for instance the depth process can be applied outside the mask)
+- If you ever wanted to watch Star Wars in 4:3, just use the new *Outpainting* feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is *Outpainting* can be combined with all the other Vace modifications, for instance you can change the main character of your favorite movie at the same time
+- More processing can combined at the same time (for instance the depth process can be applied outside the mask)
- Upgraded the depth extractor to Depth Anything 2 which is much more detailed
As a bonus, I have added two finetunes based on the Safe-Forcing technology (which requires only 4 steps to generate a video): Wan 2.1 text2video Self-Forcing and Vace Self-Forcing. I know there is Lora around but the quality of the Lora is worse (at least with Vace) compared to the full model. Don't hesitate to share your opinion about this on the discord server.
-
### June 17 2025: WanGP v6.1, Vace Powercharged
-
👋 Lots of improvements for Vace the Mother of all Models:
-
- masks can now be combined with on the fly processing of a control video, for instance you can extract the motion of a specific person defined by a mask
- on the fly modification of masks : reversed masks (with the same mask you can modify the background instead of the people covered by the masks), enlarged masks (you can cover more area if for instance the person you are trying to inject is larger than the one in the mask), ...
- view these modified masks directly inside WanGP during the video generation to check they are really as expected
@@ -132,62 +116,52 @@ Of course all these new stuff work on all Vace finetunes (including Vace Fusioni
Thanks also to Reevoy24 for adding a Notfication sound at the end of a generation and for fixing the background color of the current generation summary.
### June 12 2025: WanGP v6.0
-
-👋 _Finetune models_: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add by yourself the support for this model in WanGP by just creating a finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them.
+👋 *Finetune models*: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add by yourself the support for this model in WanGP by just creating a finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them.
To celebrate the new finetunes support, here are a few finetune gifts (directly accessible from the model selection menu):
-
-- _Fast Hunyuan Video_ : generate model t2v in only 6 steps
-- _Hunyuan Vido AccVideo_ : generate model t2v in only 5 steps
-- _Wan FusioniX_: it is a combo of AccVideo / CausVid ans other models and can generate high quality Wan videos in only 8 steps
+- *Fast Hunyuan Video* : generate model t2v in only 6 steps
+- *Hunyuan Vido AccVideo* : generate model t2v in only 5 steps
+- *Wan FusioniX*: it is a combo of AccVideo / CausVid ans other models and can generate high quality Wan videos in only 8 steps
One more thing...
-The new finetune system can be used to combine complementaty models : what happens when you combine Fusionix Text2Video and Vace Control Net ?
+The new finetune system can be used to combine complementaty models : what happens when you combine Fusionix Text2Video and Vace Control Net ?
You get **Vace FusioniX**: the Ultimate Vace Model, Fast (10 steps, no need for guidance) and with a much better quality Video than the original slower model (despite being the best Control Net out there). Here goes one more finetune...
-Check the _Finetune Guide_ to create finetune models definitions and share them on the WanGP discord server.
+Check the *Finetune Guide* to create finetune models definitions and share them on the WanGP discord server.
### June 11 2025: WanGP v5.5
+👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\
+*Hunyuan Video Custom Edit*: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping his poses. Similar to Vace but less restricted than the Wan models in terms of content...
-👋 _Hunyuan Video Custom Audio_: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\
-_Hunyuan Video Custom Edit_: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping his poses. Similar to Vace but less restricted than the Wan models in terms of content...
### June 6 2025: WanGP v5.41
-
👋 Bonus release: Support for **AccVideo** Lora to speed up x2 Video generations in Wan models. Check the Loras documentation to get the usage instructions of AccVideo.\
-You will need to do a _pip install -r requirements.txt_
+You will need to do a *pip install -r requirements.txt*
### June 6 2025: WanGP v5.4
-
👋 World Exclusive : **Hunyuan Video Avatar** Support ! You won't need 80 GB of VRAM nor 32 GB oF VRAM, just 10 GB of VRAM will be sufficient to generate up to 15s of high quality speech / song driven Video at a high speed with no quality degradation. Support for TeaCache included.\
Here is a link to the original repo where you will find some very interesting documentation and examples. https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar. Kudos to the Hunyuan Video Avatar team for the best model of its kind.\
Also many thanks to Reevoy24 for his repackaging / completing the documentation
### May 28 2025: WanGP v5.31
-
👋 Added **Phantom 14B**, a model that you can use to transfer objects / people in the video. My preference goes to Vace that remains the king of controlnets.
VACE improvements: Better sliding window transitions, image mask support in Matanyone, new Extend Video feature, and enhanced background removal options.
### May 26, 2025: WanGP v5.3
-
👋 Settings management revolution! Now you can:
-
-- Select any generated video and click _Use Selected Video Settings_ to instantly reuse its configuration
+- Select any generated video and click *Use Selected Video Settings* to instantly reuse its configuration
- Drag & drop videos to automatically extract their settings metadata
- Export/import settings as JSON files for easy sharing and backup
### May 20, 2025: WanGP v5.2
-
👋 **CausVid support** - Generate videos in just 4-12 steps with the new distilled Wan model! Also added experimental MoviiGen for 1080p generation (20GB+ VRAM required). Check the Loras documentation to get the usage instructions of CausVid.
### May 18, 2025: WanGP v5.1
-
👋 **LTX Video 13B Distilled** - Generate high-quality videos in less than one minute!
### May 17, 2025: WanGP v5.0
-
👋 **One App to Rule Them All!** Added Hunyuan Video and LTX Video support, plus Vace 14B and integrated prompt enhancer.
See full changelog: **[Changelog](docs/CHANGELOG.md)**
@@ -202,10 +176,36 @@ See full changelog: **[Changelog](docs/CHANGELOG.md)**
## 🚀 Quick Start
+### 🐳 Docker:
+
+**For Debian-based systems (Ubuntu, Debian, etc.):**
+
+```bash
+./run-docker-cuda-deb.sh
+```
+
+This automated script will:
+
+- Detect your GPU model and VRAM automatically
+- Select optimal CUDA architecture for your GPU
+- Install NVIDIA Docker runtime if needed
+- Build a Docker image with all dependencies
+- Run WanGP with optimal settings for your hardware
+
+**Docker environment includes:**
+
+- NVIDIA CUDA 12.4.1 with cuDNN support
+- PyTorch 2.6.0 with CUDA 12.4 support
+- SageAttention compiled for your specific GPU architecture
+- Optimized environment variables for performance (TF32, threading, etc.)
+- Automatic cache directory mounting for faster subsequent runs
+- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally
+
+**Supported GPUs:** RTX 50XX, RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more.
+
**One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/)
**Manual installation:**
-
```bash
git clone https://github.com/deepbeepmeep/Wan2GP.git
cd Wan2GP
@@ -216,7 +216,6 @@ pip install -r requirements.txt
```
**Run the application:**
-
```bash
python wgp.py # Text-to-video (default)
python wgp.py --i2v # Image-to-video
@@ -225,7 +224,6 @@ python wgp.py --i2v # Image-to-video
**Update the application:**
If you are using Pinokio, update through Pinokio; otherwise:
Go to the directory where WanGP is installed and run:
-
```bash
git pull
pip install -r requirements.txt
@@ -233,48 +231,16 @@ pip install -r requirements.txt
## 📦 Installation
-### 🐳 Docker Installation
-
-**For Debian-based systems (Ubuntu, Debian, etc.):**
-
-```bash
-./run-docker-cuda-deb.sh
-```
-
-This automated script will:
-
-- Detect your GPU model and VRAM automatically
-- Select optimal CUDA architecture for your GPU
-- Install NVIDIA Docker runtime if needed
-- Build a Docker image with all dependencies
-- Run WanGP with optimal settings for your hardware
-
-**Docker environment includes:**
-
-- NVIDIA CUDA 12.4.1 with cuDNN support
-- PyTorch 2.6.0 with CUDA 12.4 support
-- SageAttention compiled for your specific GPU architecture
-- Optimized environment variables for performance (TF32, threading, etc.)
-- Automatic cache directory mounting for faster subsequent runs
-- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally
-
-**Supported GPUs:** RTX 50XX, RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more.
-
-### Manual Installation
-
For detailed installation instructions for different GPU generations:
-
- **[Installation Guide](docs/INSTALLATION.md)** - Complete setup instructions for RTX 10XX to RTX 50XX
## 🎯 Usage
### Basic Usage
-
- **[Getting Started Guide](docs/GETTING_STARTED.md)** - First steps and basic usage
- **[Models Overview](docs/MODELS.md)** - Available models and their capabilities
### Advanced Features
-
- **[Loras Guide](docs/LORAS.md)** - Using and managing Loras for customization
- **[Finetunes](docs/FINETUNES.md)** - Manually add new models to WanGP
- **[VACE ControlNet](docs/VACE.md)** - Advanced video control and manipulation
@@ -288,7 +254,6 @@ For detailed installation instructions for different GPU generations:
## 🔗 Related Projects
### Other Models for the GPU Poor
-
- **[HunyuanVideoGP](https://github.com/deepbeepmeep/HunyuanVideoGP)** - One of the best open source Text to Video generators
- **[Hunyuan3D-2GP](https://github.com/deepbeepmeep/Hunyuan3D-2GP)** - Image to 3D and text to 3D tool
- **[FluxFillGP](https://github.com/deepbeepmeep/FluxFillGP)** - Inpainting/outpainting tools based on Flux
From 1ae91d36f1d7135e326b8c1129584482c69df44c Mon Sep 17 00:00:00 2001
From: deepbeepmeep <84379123+deepbeepmeep@users.noreply.github.com>
Date: Tue, 2 Sep 2025 02:50:49 +0200
Subject: [PATCH 005/155] fixed InfiniteTalk starting from last frame
---
wgp.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/wgp.py b/wgp.py
index a151252ba..0345e642e 100644
--- a/wgp.py
+++ b/wgp.py
@@ -4739,7 +4739,7 @@ def remove_temp_filenames(temp_filenames_list):
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ]
if infinitetalk and video_guide is not None:
- src_image = get_video_frame(video_guide, aligned_guide_start_frame-1, return_last_if_missing = True, return_PIL = True)
+ src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True)
new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size)
src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
refresh_preview["video_guide"] = src_image
@@ -9290,4 +9290,4 @@ def create_ui():
else:
url = "http://" + server_name
webbrowser.open(url + ":" + str(server_port), new = 0, autoraise = True)
- demo.launch(favicon_path="favicon.png", server_name=server_name, server_port=server_port, share=args.share, allowed_paths=list({save_path, image_save_path}))
\ No newline at end of file
+ demo.launch(favicon_path="favicon.png", server_name=server_name, server_port=server_port, share=args.share, allowed_paths=list({save_path, image_save_path}))
From 898b542cc6df6246b614a5ce93135482517062ff Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 2 Sep 2025 20:39:31 +0200
Subject: [PATCH 006/155] pain reliever
---
README.md | 11 +
models/hyvideo/hunyuan_handler.py | 2 +-
models/hyvideo/modules/utils.py | 2 +-
models/qwen/transformer_qwenimage.py | 3 +-
models/wan/any2video.py | 11 +-
models/wan/modules/model.py | 27 +-
models/wan/multitalk/attention.py | 26 +-
models/wan/multitalk/multitalk_utils.py | 111 ++++--
models/wan/wan_handler.py | 22 ++
requirements.txt | 5 +-
shared/attention.py | 73 +++-
shared/gradio/gallery.py | 496 ++++++++++++++++++++++++
shared/utils/qwen_vl_utils.py | 2 -
wgp.py | 190 ++++-----
14 files changed, 799 insertions(+), 182 deletions(-)
create mode 100644 shared/gradio/gallery.py
diff --git a/README.md b/README.md
index 47736da47..be76cbcc8 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,17 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
+### September 2 2025: WanGP v8.3 - At last the pain stops
+
+- This single new feature should give you the strength to face all the potential bugs of this new release:
+**Image Management (multiple additions or deletions, reordering) for Start Images / End Images / Image References.**
+
+- Unofficial **Video to Video (non-sparse this time) for InfiniteTalk**. Use the Strength Noise slider to decide how much of the original window's motion you want to keep. I have also *greatly reduced the VRAM requirements for Multitalk / InfiniteTalk* (especially the multi-speaker version and when generating at 1080p).
+
+- **Experimental Sage 3 Attention support**: you will need to earn this one: first you need a Blackwell GPU (RTX 50xx), then you will have to compile Sage 3, install it, and cross your fingers that nothing crashes.
+
+
+
### August 29 2025: WanGP v8.21 - Here Goes Your Weekend
- **InfiniteTalk Video to Video**: this feature can be used for video dubbing. Keep in mind that it is a *Sparse Video to Video*, that is, internally only one image per sliding window is used. However, thanks to the new *Smooth Transition* mode, each new clip is connected to the previous one and all the camera work is done by InfiniteTalk. If you don't get any transition, increase the number of frames of a sliding window (81 frames recommended)
diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py
index d95bd7e35..da60f72cb 100644
--- a/models/hyvideo/hunyuan_handler.py
+++ b/models/hyvideo/hunyuan_handler.py
@@ -53,7 +53,7 @@ def query_model_def(base_model_type, model_def):
if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True
- if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]:
+ if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]:
extra_model_def["one_image_ref_needed"] = True
return extra_model_def
diff --git a/models/hyvideo/modules/utils.py b/models/hyvideo/modules/utils.py
index 02a733e1b..c26399728 100644
--- a/models/hyvideo/modules/utils.py
+++ b/models/hyvideo/modules/utils.py
@@ -14,7 +14,7 @@
)
-@lru_cache
+# @lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
return block_mask
diff --git a/models/qwen/transformer_qwenimage.py b/models/qwen/transformer_qwenimage.py
index 6d908064a..8ec34ed01 100644
--- a/models/qwen/transformer_qwenimage.py
+++ b/models/qwen/transformer_qwenimage.py
@@ -204,7 +204,7 @@ def forward(self, video_fhw, txt_seq_lens, device):
frame, height, width = fhw
rope_key = f"{idx}_{height}_{width}"
- if not torch.compiler.is_compiling():
+ if not torch.compiler.is_compiling() and False:
if rope_key not in self.rope_cache:
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
video_freq = self.rope_cache[rope_key]
@@ -224,7 +224,6 @@ def forward(self, video_fhw, txt_seq_lens, device):
return vid_freqs, txt_freqs
- @functools.lru_cache(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0):
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index ee58c0a09..8eb2ffef2 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -19,7 +19,8 @@
import torchvision.transforms.functional as TF
import torch.nn.functional as F
from .distributed.fsdp import shard_model
-from .modules.model import WanModel, clear_caches
+from .modules.model import WanModel
+from mmgp.offload import get_cache, clear_caches
from .modules.t5 import T5EncoderModel
from .modules.vae import WanVAE
from .modules.vae2_2 import Wan2_2_VAE
@@ -496,6 +497,8 @@ def generate(self,
text_len = self.model.text_len
context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0)
context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0)
+ if input_video is not None: height, width = input_video.shape[-2:]
+
# NAG_prompt = "static, low resolution, blurry"
# context_NAG = self.text_encoder([NAG_prompt], self.device)[0]
# context_NAG = context_NAG.to(self.dtype)
@@ -530,9 +533,10 @@ def generate(self,
if image_start is None:
if infinitetalk:
if input_frames is not None:
- image_ref = input_frames[:, -1]
- if input_video is None: input_video = input_frames[:, -1:]
+ image_ref = input_frames[:, 0]
+ if input_video is None: input_video = input_frames[:, 0:1]
new_shot = "Q" in video_prompt_type
+ denoising_strength = 0.5
else:
if pre_video_frame is None:
new_shot = True
@@ -888,6 +892,7 @@ def clear():
latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents
else:
latent_noise_factor = t / 1000
+ latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor
if vace:
overlap_noise_factor = overlap_noise / 1000
for zz in z:
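The line added in this hunk re-noises the overlapped latents in proportion to the current timestep, so that each sliding window starts from a softly noised copy of the previous one instead of a hard copy. A minimal sketch of that blend (the helper name is hypothetical; the `t / 1000` scale mirrors the hunk):

```python
import torch

def blend_overlap_latents(latents, overlap_latents, t, num_train_timesteps=1000):
    # Hypothetical helper mirroring the added line: early in denoising (large t)
    # the overlap region is mostly fresh noise, late in denoising it converges to
    # the previous window's latents, which smooths sliding-window transitions.
    noise_factor = t / num_train_timesteps
    overlap = overlap_latents.shape[2]
    latents[:, :, :overlap] = (
        overlap_latents * (1.0 - noise_factor)
        + torch.randn_like(overlap_latents) * noise_factor
    )
    return latents
```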
diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py
index b9c98bf81..d3dc783e3 100644
--- a/models/wan/modules/model.py
+++ b/models/wan/modules/model.py
@@ -12,6 +12,7 @@
import numpy as np
from typing import Union,Optional
from mmgp import offload
+from mmgp.offload import get_cache, clear_caches
from shared.attention import pay_attention
from torch.backends.cuda import sdp_kernel
from ..multitalk.multitalk_utils import get_attn_map_with_target
@@ -19,22 +20,6 @@
__all__ = ['WanModel']
-def get_cache(cache_name):
- all_cache = offload.shared_state.get("_cache", None)
- if all_cache is None:
- all_cache = {}
- offload.shared_state["_cache"]= all_cache
- cache = offload.shared_state.get(cache_name, None)
- if cache is None:
- cache = {}
- offload.shared_state[cache_name] = cache
- return cache
-
-def clear_caches():
- all_cache = offload.shared_state.get("_cache", None)
- if all_cache is not None:
- all_cache.clear()
-
def sinusoidal_embedding_1d(dim, position):
# preprocess
assert dim % 2 == 0
@@ -579,19 +564,23 @@ def forward(
y = self.norm_x(x)
y = y.to(attention_dtype)
if ref_images_count == 0:
- x += self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map)
+ ylist= [y]
+ del y
+ x += self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map)
else:
y_shape = y.shape
y = y.reshape(y_shape[0], grid_sizes[0], -1)
y = y[:, ref_images_count:]
y = y.reshape(y_shape[0], -1, y_shape[-1])
grid_sizes_alt = [grid_sizes[0]-ref_images_count, *grid_sizes[1:]]
- y = self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map)
+ ylist= [y]
+ y = None
+ y = self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map)
y = y.reshape(y_shape[0], grid_sizes[0]-ref_images_count, -1)
x = x.reshape(y_shape[0], grid_sizes[0], -1)
x[:, ref_images_count:] += y
x = x.reshape(y_shape[0], -1, y_shape[-1])
- del y
+ del y
y = self.norm2(x)
diff --git a/models/wan/multitalk/attention.py b/models/wan/multitalk/attention.py
index 27d488f6d..669b5c1a7 100644
--- a/models/wan/multitalk/attention.py
+++ b/models/wan/multitalk/attention.py
@@ -221,13 +221,16 @@ def __init__(
self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
- def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
+ def forward(self, xlist: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor:
N_t, N_h, N_w = shape
+ x = xlist[0]
+ xlist.clear()
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
# get q for hidden_state
B, N, C = x.shape
q = self.q_linear(x)
+ del x
q_shape = (B, N, self.num_heads, self.head_dim)
q = q.view(q_shape).permute((0, 2, 1, 3))
@@ -247,9 +250,6 @@ def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=No
q = rearrange(q, "B H M K -> B M H K")
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
-
- attn_bias = None
- # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,)
qkv_list = [q, encoder_k, encoder_v]
q = encoder_k = encoder_v = None
x = pay_attention(qkv_list)
@@ -302,7 +302,7 @@ def __init__(
self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)
def forward(self,
- x: torch.Tensor,
+ xlist: torch.Tensor,
encoder_hidden_states: torch.Tensor,
shape=None,
x_ref_attn_map=None,
@@ -310,14 +310,17 @@ def forward(self,
encoder_hidden_states = encoder_hidden_states.squeeze(0)
if x_ref_attn_map == None:
- return super().forward(x, encoder_hidden_states, shape)
+ return super().forward(xlist, encoder_hidden_states, shape)
N_t, _, _ = shape
+ x = xlist[0]
+ xlist.clear()
x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)
# get q for hidden_state
B, N, C = x.shape
q = self.q_linear(x)
+ del x
q_shape = (B, N, self.num_heads, self.head_dim)
q = q.view(q_shape).permute((0, 2, 1, 3))
@@ -339,7 +342,9 @@ def forward(self,
normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N
q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
- q = self.rope_1d(q, normalized_pos)
+ qlist = [q]
+ del q
+ q = self.rope_1d(qlist, normalized_pos, "q")
q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
_, N_a, _ = encoder_hidden_states.shape
@@ -347,7 +352,7 @@ def forward(self,
encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim)
encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4))
encoder_k, encoder_v = encoder_kv.unbind(0)
-
+ del encoder_kv
if self.qk_norm:
encoder_k = self.add_k_norm(encoder_k)
@@ -356,13 +361,14 @@ def forward(self,
per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2
encoder_pos = torch.concat([per_frame]*N_t, dim=0)
encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
- encoder_k = self.rope_1d(encoder_k, encoder_pos)
+ enclist = [encoder_k]
+ del encoder_k
+ encoder_k = self.rope_1d(enclist, encoder_pos, "encoder_k")
encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)
q = rearrange(q, "B H M K -> B M H K")
encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
encoder_v = rearrange(encoder_v, "B H M K -> B M H K")
- # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,)
qkv_list = [q, encoder_k, encoder_v]
q = encoder_k = encoder_v = None
x = pay_attention(qkv_list)
diff --git a/models/wan/multitalk/multitalk_utils.py b/models/wan/multitalk/multitalk_utils.py
index 6e2b2c307..7851c452b 100644
--- a/models/wan/multitalk/multitalk_utils.py
+++ b/models/wan/multitalk/multitalk_utils.py
@@ -16,7 +16,7 @@
import binascii
import os.path as osp
from skimage import color
-
+from mmgp.offload import get_cache, clear_caches
VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
ASPECT_RATIO_627 = {
@@ -73,42 +73,70 @@ def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
# @torch.compile
-def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None):
-
- ref_k = ref_k.to(visual_q.dtype).to(visual_q.device)
+def calculate_x_ref_attn_map_per_head(visual_q, ref_k, ref_target_masks, ref_images_count, attn_bias=None):
+ dtype = visual_q.dtype
+ ref_k = ref_k.to(dtype).to(visual_q.device)
scale = 1.0 / visual_q.shape[-1] ** 0.5
visual_q = visual_q * scale
visual_q = visual_q.transpose(1, 2)
ref_k = ref_k.transpose(1, 2)
- attn = visual_q @ ref_k.transpose(-2, -1)
+ visual_q_shape = visual_q.shape
+ visual_q = visual_q.view(-1, visual_q_shape[-1] )
+ number_chunks = visual_q_shape[-2]*ref_k.shape[-2] / 53090100 * 2
+ chunk_size = int(visual_q_shape[-2] / number_chunks)
+ chunks =torch.split(visual_q, chunk_size)
+ maps_lists = [ [] for _ in ref_target_masks]
+ for q_chunk in chunks:
+ attn = q_chunk @ ref_k.transpose(-2, -1)
+ x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens
+ del attn
+ ref_target_masks = ref_target_masks.to(dtype)
+ x_ref_attn_map_source = x_ref_attn_map_source.to(dtype)
+
+ for class_idx, ref_target_mask in enumerate(ref_target_masks):
+ ref_target_mask = ref_target_mask[None, None, None, ...]
+ x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
+ x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
+ maps_lists[class_idx].append(x_ref_attnmap)
+
+ del x_ref_attn_map_source
- if attn_bias is not None: attn += attn_bias
+ x_ref_attn_maps = []
+ for class_idx, maps_list in enumerate(maps_lists):
+ attn_map_fuse = torch.concat(maps_list, dim= -1)
+ attn_map_fuse = attn_map_fuse.view(1, visual_q_shape[1], -1).squeeze(1)
+ x_ref_attn_maps.append( attn_map_fuse )
- x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens
+ return torch.concat(x_ref_attn_maps, dim=0)
+
+def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count):
+ dtype = visual_q.dtype
+ ref_k = ref_k.to(dtype).to(visual_q.device)
+ scale = 1.0 / visual_q.shape[-1] ** 0.5
+ visual_q = visual_q * scale
+ visual_q = visual_q.transpose(1, 2)
+ ref_k = ref_k.transpose(1, 2)
+ attn = visual_q @ ref_k.transpose(-2, -1)
+
+ x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens
+ del attn
x_ref_attn_maps = []
- ref_target_masks = ref_target_masks.to(visual_q.dtype)
- x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype)
+ ref_target_masks = ref_target_masks.to(dtype)
+ x_ref_attn_map_source = x_ref_attn_map_source.to(dtype)
for class_idx, ref_target_mask in enumerate(ref_target_masks):
ref_target_mask = ref_target_mask[None, None, None, ...]
x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H
-
- if mode == 'mean':
- x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens
- elif mode == 'max':
- x_ref_attnmap = x_ref_attnmap.max(-1) # B, x_seqlens
-
+ x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens (mean of heads)
x_ref_attn_maps.append(x_ref_attnmap)
- del attn
del x_ref_attn_map_source
return torch.concat(x_ref_attn_maps, dim=0)
-
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0):
"""Args:
query (torch.tensor): B M H K
@@ -120,6 +148,11 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli
N_t, N_h, N_w = shape
x_seqlens = N_h * N_w
+ if x_seqlens <= 1508:
+ split_num = 10 # 540p
+ else:
+ split_num = 20 if x_seqlens <= 3600 else 40 # 720p / 1080p
+
ref_k = ref_k[:, :x_seqlens]
if ref_images_count > 0 :
visual_q_shape = visual_q.shape
@@ -133,9 +166,14 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli
split_chunk = heads // split_num
- for i in range(split_num):
- x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
- x_ref_attn_maps += x_ref_attn_maps_perhead
+ if split_chunk == 1:
+ for i in range(split_num):
+ x_ref_attn_maps_perhead = calculate_x_ref_attn_map_per_head(visual_q[:, :, i:(i+1), :], ref_k[:, :, i:(i+1), :], ref_target_masks, ref_images_count)
+ x_ref_attn_maps += x_ref_attn_maps_perhead
+ else:
+ for i in range(split_num):
+ x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count)
+ x_ref_attn_maps += x_ref_attn_maps_perhead
x_ref_attn_maps /= split_num
return x_ref_attn_maps
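The new `calculate_x_ref_attn_map_per_head` above avoids materializing the full `x_seqlens × ref_seqlens` attention matrix for every head at once: query rows are split into chunks and the masked softmax averages are accumulated chunk by chunk. A simplified single-head sketch of that pattern (the chunk size and shapes are illustrative, not the constants used in the hunk):

```python
import torch

def chunked_ref_attn_map(visual_q, ref_k, ref_mask, chunk_rows=4096):
    # visual_q: (rows, dim) already-scaled queries for one head
    # ref_k:    (ref_len, dim) reference-token keys for the same head
    # ref_mask: (ref_len,) 0/1 mask selecting one speaker's reference tokens
    outputs = []
    for q_chunk in torch.split(visual_q, chunk_rows):
        attn = (q_chunk @ ref_k.T).softmax(-1)                 # (chunk, ref_len)
        outputs.append((attn * ref_mask).sum(-1) / ref_mask.sum())
    return torch.cat(outputs)                                   # (rows,)
```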
@@ -158,7 +196,6 @@ def __init__(self,
self.base = 10000
- @lru_cache(maxsize=32)
def precompute_freqs_cis_1d(self, pos_indices):
freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
@@ -167,7 +204,7 @@ def precompute_freqs_cis_1d(self, pos_indices):
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
return freqs
- def forward(self, x, pos_indices):
+ def forward(self, qlist, pos_indices, cache_entry = None):
"""1D RoPE.
Args:
@@ -176,16 +213,26 @@ def forward(self, x, pos_indices):
Returns:
query with the same shape as input.
"""
- freqs_cis = self.precompute_freqs_cis_1d(pos_indices)
-
- x_ = x.float()
-
- freqs_cis = freqs_cis.float().to(x.device)
- cos, sin = freqs_cis.cos(), freqs_cis.sin()
- cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
- x_ = (x_ * cos) + (rotate_half(x_) * sin)
-
- return x_.type_as(x)
+ xq= qlist[0]
+ qlist.clear()
+ cache = get_cache("multitalk_rope")
+ freqs_cis= cache.get(cache_entry, None)
+ if freqs_cis is None:
+ freqs_cis = cache[cache_entry] = self.precompute_freqs_cis_1d(pos_indices)
+ cos, sin = freqs_cis.cos().unsqueeze(0).unsqueeze(0), freqs_cis.sin().unsqueeze(0).unsqueeze(0)
+ # cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
+ # real * cos - imag * sin
+ # imag * cos + real * sin
+ xq_dtype = xq.dtype
+ xq_out = xq.to(torch.float)
+ xq = None
+ xq_rot = rotate_half(xq_out)
+ xq_out *= cos
+ xq_rot *= sin
+ xq_out += xq_rot
+ del xq_rot
+ xq_out = xq_out.to(xq_dtype)
+ return xq_out
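The rewritten RoPE `forward` above replaces the removed `lru_cache` (which cannot key on tensor arguments) with an explicit cache keyed by `cache_entry` ("q", "encoder_k"), and rotates a float32 copy of the query to limit peak memory. A minimal sketch of the caching pattern, assuming the common half-split RoPE convention (a plain dict stands in for `get_cache("multitalk_rope")`; the exact rotation layout in the module may differ):

```python
import torch

_ROPE_CACHE = {}  # stands in for mmgp.offload.get_cache("multitalk_rope")

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_1d(xq, pos_indices, cache_entry, head_dim, base=10000.0):
    # Frequencies are computed once per cache_entry and reused across steps,
    # which is the point of swapping lru_cache for an explicit cache.
    freqs_cis = _ROPE_CACHE.get(cache_entry)
    if freqs_cis is None:
        freqs = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        freqs = torch.outer(pos_indices.float(), freqs)
        freqs_cis = _ROPE_CACHE[cache_entry] = torch.cat([freqs, freqs], dim=-1)
    cos = freqs_cis.cos()[None, None]   # (1, 1, seq, head_dim)
    sin = freqs_cis.sin()[None, None]
    q32 = xq.float()                    # (B, H, seq, head_dim)
    return (q32 * cos + rotate_half(q32) * sin).to(xq.dtype)
```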
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index c3ad01294..4d59aa354 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -1,5 +1,6 @@
import torch
import numpy as np
+import gradio as gr
def test_class_i2v(base_model_type):
return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ]
@@ -116,6 +117,11 @@ def query_model_def(base_model_type, model_def):
extra_model_def["no_background_removal"] = True
# extra_model_def["at_least_one_image_ref_needed"] = True
+
+ if base_model_type in ["phantom_1.3B", "phantom_14B"]:
+ extra_model_def["one_image_ref_needed"] = True
+
+
return extra_model_def
@staticmethod
@@ -235,6 +241,14 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
if "I" in video_prompt_type:
video_prompt_type = video_prompt_type.replace("KI", "QKI")
ui_defaults["video_prompt_type"] = video_prompt_type
+
+ if settings_version < 2.28:
+        if base_model_type in ["infinitetalk"]:
+ video_prompt_type = ui_defaults.get("video_prompt_type", "")
+ if "U" in video_prompt_type:
+ video_prompt_type = video_prompt_type.replace("U", "RU")
+ ui_defaults["video_prompt_type"] = video_prompt_type
+
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
@@ -309,3 +323,11 @@ def validate_generative_settings(base_model_type, model_def, inputs):
if ("V" in image_prompt_type or "L" in image_prompt_type) and image_refs is None:
video_prompt_type = video_prompt_type.replace("I", "").replace("K","")
inputs["video_prompt_type"] = video_prompt_type
+
+
+ if base_model_type in ["vace_standin_14B"]:
+ image_refs = inputs["image_refs"]
+ video_prompt_type = inputs["video_prompt_type"]
+ if image_refs is not None and len(image_refs) == 1 and "K" in video_prompt_type:
+ gr.Info("Warning, Ref Image for Standin Missing: if 'Landscape and then People or Objects' is selected beside the Landscape Image Ref there should be another Image Ref that contains a Face.")
+
diff --git a/requirements.txt b/requirements.txt
index a2e1fde85..7704d6db9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,7 +20,8 @@ soundfile
mutagen
pyloudnorm
librosa==0.11.0
-
+speechbrain==1.0.3
+
# UI & interaction
gradio==5.23.0
dashscope
@@ -43,7 +44,7 @@ pydantic==2.10.6
# Math & modeling
torchdiffeq>=0.2.5
tensordict>=0.6.1
-mmgp==3.5.10
+mmgp==3.5.11
peft==0.15.0
matplotlib
diff --git a/shared/attention.py b/shared/attention.py
index a95332dc5..cc6ece02d 100644
--- a/shared/attention.py
+++ b/shared/attention.py
@@ -3,6 +3,7 @@
from importlib.metadata import version
from mmgp import offload
import torch.nn.functional as F
+import warnings
major, minor = torch.cuda.get_device_capability(None)
bfloat16_supported = major >= 8
@@ -42,34 +43,51 @@ def sageattn_varlen_wrapper(
sageattn_varlen_wrapper = None
-import warnings
-
try:
- from sageattention import sageattn
- from .sage2_core import sageattn as alt_sageattn, is_sage2_supported
+ from .sage2_core import sageattn as sageattn2, is_sage2_supported
sage2_supported = is_sage2_supported()
except ImportError:
- sageattn = None
- alt_sageattn = None
+ sageattn2 = None
sage2_supported = False
-# @torch.compiler.disable()
-def sageattn_wrapper(
+@torch.compiler.disable()
+def sageattn2_wrapper(
qkv_list,
attention_length
):
q,k, v = qkv_list
- if True:
- qkv_list = [q,k,v]
- del q, k ,v
- o = alt_sageattn(qkv_list, tensor_layout="NHD")
- else:
- o = sageattn(q, k, v, tensor_layout="NHD")
- del q, k ,v
+ qkv_list = [q,k,v]
+ del q, k ,v
+ o = sageattn2(qkv_list, tensor_layout="NHD")
+ qkv_list.clear()
+
+ return o
+try:
+ from sageattn import sageattn_blackwell as sageattn3
+except ImportError:
+ sageattn3 = None
+
+@torch.compiler.disable()
+def sageattn3_wrapper(
+ qkv_list,
+ attention_length
+ ):
+ q,k, v = qkv_list
+ # qkv_list = [q,k,v]
+ # del q, k ,v
+ # o = sageattn3(qkv_list, tensor_layout="NHD")
+ q = q.transpose(1,2)
+ k = k.transpose(1,2)
+ v = v.transpose(1,2)
+ o = sageattn3(q, k, v)
+ o = o.transpose(1,2)
qkv_list.clear()
return o
+
+
+
# try:
# if True:
# from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda
@@ -94,7 +112,7 @@ def sageattn_wrapper(
# return o
# except ImportError:
-# sageattn = sageattn_qk_int8_pv_fp8_window_cuda
+# sageattn2 = sageattn_qk_int8_pv_fp8_window_cuda
@torch.compiler.disable()
def sdpa_wrapper(
@@ -124,21 +142,28 @@ def get_attention_modes():
ret.append("xformers")
if sageattn_varlen_wrapper != None:
ret.append("sage")
- if sageattn != None and version("sageattention").startswith("2") :
+ if sageattn2 != None and version("sageattention").startswith("2") :
ret.append("sage2")
+ if sageattn3 != None: # and version("sageattention").startswith("3") :
+ ret.append("sage3")
return ret
def get_supported_attention_modes():
ret = get_attention_modes()
+ major, minor = torch.cuda.get_device_capability()
+ if major < 10:
+ if "sage3" in ret:
+ ret.remove("sage3")
+
if not sage2_supported:
if "sage2" in ret:
ret.remove("sage2")
- major, minor = torch.cuda.get_device_capability()
if major < 7:
if "sage" in ret:
ret.remove("sage")
+
return ret
__all__ = [
@@ -201,7 +226,7 @@ def pay_attention(
from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
- if b > 1 and k_lens != None and attn in ("sage2", "sdpa"):
+ if b > 1 and k_lens != None and attn in ("sage2", "sage3", "sdpa"):
assert attention_mask == None
# Poor's man var k len attention
assert q_lens == None
@@ -234,7 +259,7 @@ def pay_attention(
q_chunks, k_chunks, v_chunks = None, None, None
o = torch.cat(o, dim = 0)
return o
- elif (q_lens != None or k_lens != None) and attn in ("sage2", "sdpa"):
+ elif (q_lens != None or k_lens != None) and attn in ("sage2", "sage3", "sdpa"):
assert b == 1
szq = q_lens[0].item() if q_lens != None else lq
szk = k_lens[0].item() if k_lens != None else lk
@@ -284,13 +309,19 @@ def pay_attention(
max_seqlen_q=lq,
max_seqlen_kv=lk,
).unflatten(0, (b, lq))
+ elif attn=="sage3":
+ import math
+ if cross_attn or True:
+ qkv_list = [q,k,v]
+ del q,k,v
+ x = sageattn3_wrapper(qkv_list, lq)
elif attn=="sage2":
import math
if cross_attn or True:
qkv_list = [q,k,v]
del q,k,v
- x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0)
+ x = sageattn2_wrapper(qkv_list, lq) #.unsqueeze(0)
# else:
# layer = offload.shared_state["layer"]
# embed_sizes = offload.shared_state["embed_sizes"]
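The attention refactor above renames the Sage 2 path, adds a Sage 3 wrapper (Blackwell only, with an NHD→HND transpose around the call), and filters out modes the current GPU cannot run. A small sketch of that capability gating (the function name and the preference order are assumptions; the capability thresholds mirror the checks in the hunk):

```python
import torch

def pick_attention_mode(installed=("sdpa", "sage", "sage2", "sage3")):
    # Sage 3 needs compute capability >= 10 (Blackwell / RTX 50xx),
    # Sage (varlen) needs >= 7; sdpa is always the fallback.
    major, _ = torch.cuda.get_device_capability()
    supported = list(installed)
    if major < 10 and "sage3" in supported:
        supported.remove("sage3")
    if major < 7 and "sage" in supported:
        supported.remove("sage")
    for mode in ("sage3", "sage2", "sage", "sdpa"):  # newest surviving backend wins
        if mode in supported:
            return mode
    return "sdpa"
```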
diff --git a/shared/gradio/gallery.py b/shared/gradio/gallery.py
new file mode 100644
index 000000000..6bd86a8d3
--- /dev/null
+++ b/shared/gradio/gallery.py
@@ -0,0 +1,496 @@
+from __future__ import annotations
+import os, io, tempfile, mimetypes
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal
+
+import gradio as gr
+import PIL
+from PIL import Image as PILImage
+
+FilePath = str
+ImageLike = Union["PIL.Image.Image", Any]
+
+IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"}
+VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v", ".mpeg", ".mpg", ".ogv"}
+
+def get_state(state):
+ return state if isinstance(state, dict) else state.value
+
+def get_list( objs):
+ if objs is None:
+ return []
+ return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs]
+
+class AdvancedMediaGallery:
+ def __init__(
+ self,
+ label: str = "Media",
+ *,
+ media_mode: Literal["image", "video"] = "image",
+ height = None,
+ columns: Union[int, Tuple[int, ...]] = 6,
+ show_label: bool = True,
+ initial: Optional[Sequence[Union[FilePath, ImageLike]]] = None,
+ elem_id: Optional[str] = None,
+ elem_classes: Optional[Sequence[str]] = ("adv-media-gallery",),
+ accept_filter: bool = True, # restrict Add-button dialog to allowed extensions
+ single_image_mode: bool = False, # start in single-image mode (Add replaces)
+ ):
+ assert media_mode in ("image", "video")
+ self.label = label
+ self.media_mode = media_mode
+ self.height = height
+ self.columns = columns
+ self.show_label = show_label
+ self.elem_id = elem_id
+ self.elem_classes = list(elem_classes) if elem_classes else None
+ self.accept_filter = accept_filter
+
+ items = self._normalize_initial(initial or [], media_mode)
+
+ # Components (filled on mount)
+ self.container: Optional[gr.Column] = None
+ self.gallery: Optional[gr.Gallery] = None
+ self.upload_btn: Optional[gr.UploadButton] = None
+ self.btn_remove: Optional[gr.Button] = None
+ self.btn_left: Optional[gr.Button] = None
+ self.btn_right: Optional[gr.Button] = None
+ self.btn_clear: Optional[gr.Button] = None
+
+ # Single dict state
+ self.state: Optional[gr.State] = None
+ self._initial_state: Dict[str, Any] = {
+ "items": items,
+ "selected": (len(items) - 1) if items else None,
+ "single": bool(single_image_mode),
+ "mode": self.media_mode,
+ }
+
+ # ---------------- helpers ----------------
+
+ def _normalize_initial(self, items: Sequence[Union[FilePath, ImageLike]], mode: str) -> List[Any]:
+ out: List[Any] = []
+ if mode == "image":
+ for it in items:
+ p = self._ensure_image_item(it)
+ if p is not None:
+ out.append(p)
+ else:
+ for it in items:
+            if isinstance(it, tuple): it = it[0]
+ if isinstance(it, str) and self._is_video_path(it):
+ out.append(os.path.abspath(it))
+ return out
+
+ def _ensure_image_item(self, item: Union[FilePath, ImageLike]) -> Optional[Any]:
+ # Accept a path to an image, or a PIL.Image/np.ndarray -> save temp PNG and return its path
+ if isinstance(item, tuple): item = item[0]
+ if isinstance(item, str):
+ return os.path.abspath(item) if self._is_image_path(item) else None
+ if PILImage is None:
+ return None
+ try:
+ if isinstance(item, PILImage.Image):
+ img = item
+ else:
+ import numpy as np # type: ignore
+ if isinstance(item, np.ndarray):
+ img = PILImage.fromarray(item)
+ elif hasattr(item, "read"):
+ data = item.read()
+ img = PILImage.open(io.BytesIO(data)).convert("RGBA")
+ else:
+ return None
+ tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+ img.save(tmp.name)
+ return tmp.name
+ except Exception:
+ return None
+
+ @staticmethod
+ def _extract_path(obj: Any) -> Optional[str]:
+ # Try to get a filesystem path (for mode filtering); otherwise None.
+ if isinstance(obj, str):
+ return obj
+ try:
+ import pathlib
+ if isinstance(obj, pathlib.Path): # type: ignore
+ return str(obj)
+ except Exception:
+ pass
+ if isinstance(obj, dict):
+ return obj.get("path") or obj.get("name")
+ for attr in ("path", "name"):
+ if hasattr(obj, attr):
+ try:
+ val = getattr(obj, attr)
+ if isinstance(val, str):
+ return val
+ except Exception:
+ pass
+ return None
+
+ @staticmethod
+ def _is_image_path(p: str) -> bool:
+ ext = os.path.splitext(p)[1].lower()
+ if ext in IMAGE_EXTS:
+ return True
+ mt, _ = mimetypes.guess_type(p)
+ return bool(mt and mt.startswith("image/"))
+
+ @staticmethod
+ def _is_video_path(p: str) -> bool:
+ ext = os.path.splitext(p)[1].lower()
+ if ext in VIDEO_EXTS:
+ return True
+ mt, _ = mimetypes.guess_type(p)
+ return bool(mt and mt.startswith("video/"))
+
+ def _filter_items_by_mode(self, items: List[Any]) -> List[Any]:
+ # Enforce image-only or video-only collection regardless of how files were added.
+ out: List[Any] = []
+ if self.media_mode == "image":
+ for it in items:
+ p = self._extract_path(it)
+ if p is None:
+ # No path: likely an image object added programmatically => keep
+ out.append(it)
+ elif self._is_image_path(p):
+ out.append(os.path.abspath(p))
+ else:
+ for it in items:
+ p = self._extract_path(it)
+ if p is not None and self._is_video_path(p):
+ out.append(os.path.abspath(p))
+ return out
+
+ @staticmethod
+ def _concat_and_optionally_dedupe(cur: List[Any], add: List[Any]) -> List[Any]:
+ # Keep it simple: dedupe by path when available, else allow duplicates.
+ seen_paths = set()
+ def key(x: Any) -> Optional[str]:
+ if isinstance(x, str): return os.path.abspath(x)
+ try:
+ import pathlib
+ if isinstance(x, pathlib.Path): # type: ignore
+ return os.path.abspath(str(x))
+ except Exception:
+ pass
+ if isinstance(x, dict):
+ p = x.get("path") or x.get("name")
+ return os.path.abspath(p) if isinstance(p, str) else None
+ for attr in ("path", "name"):
+ if hasattr(x, attr):
+ try:
+ v = getattr(x, attr)
+ return os.path.abspath(v) if isinstance(v, str) else None
+ except Exception:
+ pass
+ return None
+
+ out: List[Any] = []
+ for lst in (cur, add):
+ for it in lst:
+ k = key(it)
+ if k is None or k not in seen_paths:
+ out.append(it)
+ if k is not None:
+ seen_paths.add(k)
+ return out
+
+ @staticmethod
+ def _paths_from_payload(payload: Any) -> List[Any]:
+ # Return as raw objects (paths/dicts/UploadedFile) to feed Gallery directly.
+ if payload is None:
+ return []
+ if isinstance(payload, (list, tuple, set)):
+ return list(payload)
+ return [payload]
+
+ # ---------------- event handlers ----------------
+
+ def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) :
+ # Mirror the selected index into state and the gallery (server-side selected_index)
+ idx = None
+ if evt is not None and hasattr(evt, "index"):
+ ix = evt.index
+ if isinstance(ix, int):
+ idx = ix
+ elif isinstance(ix, (tuple, list)) and ix and isinstance(ix[0], int):
+ if isinstance(self.columns, int) and len(ix) >= 2:
+ idx = ix[0] * max(1, int(self.columns)) + ix[1]
+ else:
+ idx = ix[0]
+ st = get_state(state)
+ n = len(get_list(gallery))
+ sel = idx if (idx is not None and 0 <= idx < n) else None
+ st["selected"] = sel
+ return gr.update(selected_index=sel), st
+
+ def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
+ # Fires when users add/drag/drop/delete via the Gallery itself.
+ items_filtered = self._filter_items_by_mode(list(value or []))
+ st = get_state(state)
+ st["items"] = items_filtered
+ # Keep selection if still valid, else default to last
+ old_sel = st.get("selected", None)
+ if old_sel is None or not (0 <= old_sel < len(items_filtered)):
+ new_sel = (len(items_filtered) - 1) if items_filtered else None
+ else:
+ new_sel = old_sel
+ st["selected"] = new_sel
+ return gr.update(value=items_filtered, selected_index=new_sel), st
+
+
+ def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery):
+ """
+ Insert added items right AFTER the currently selected index.
+ Keeps the same ordering as chosen in the file picker, dedupes by path,
+ and re-selects the last inserted item.
+ """
+ # New items (respect image/video mode)
+ new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
+
+ st = get_state(state)
+ cur: List[Any] = get_list(gallery)
+ sel = st.get("selected", None)
+ if sel is None:
+ sel = (len(cur) -1) if len(cur)>0 else 0
+ single = bool(st.get("single", False))
+
+ # Nothing to add: keep as-is
+ if not new_items:
+ return gr.update(value=cur, selected_index=st.get("selected")), st
+
+ # Single-image mode: replace
+ if single:
+ st["items"] = [new_items[-1]]
+ st["selected"] = 0
+ return gr.update(value=st["items"], selected_index=0), st
+
+ # ---------- helpers ----------
+ def key_of(it: Any) -> Optional[str]:
+ # Prefer class helper if present
+ if hasattr(self, "_extract_path"):
+ p = self._extract_path(it) # type: ignore
+ else:
+ p = it if isinstance(it, str) else None
+ if p is None and isinstance(it, dict):
+ p = it.get("path") or it.get("name")
+ if p is None and hasattr(it, "path"):
+ try: p = getattr(it, "path")
+ except Exception: p = None
+ if p is None and hasattr(it, "name"):
+ try: p = getattr(it, "name")
+ except Exception: p = None
+ return os.path.abspath(p) if isinstance(p, str) else None
+
+ # Dedupe the incoming batch by path, preserve order
+ seen_new = set()
+ incoming: List[Any] = []
+ for it in new_items:
+ k = key_of(it)
+ if k is None or k not in seen_new:
+ incoming.append(it)
+ if k is not None:
+ seen_new.add(k)
+
+ # Remove any existing occurrences of the incoming items from current list,
+ # BUT keep the currently selected item even if it's also in incoming.
+ cur_clean: List[Any] = []
+ # sel_item = cur[sel] if (sel is not None and 0 <= sel < len(cur)) else None
+ # for idx, it in enumerate(cur):
+ # k = key_of(it)
+ # if it is sel_item:
+ # cur_clean.append(it)
+ # continue
+ # if k is not None and k in seen_new:
+ # continue # drop duplicate; we'll reinsert at the target spot
+ # cur_clean.append(it)
+
+ # # Compute insertion position: right AFTER the (possibly shifted) selected item
+ # if sel_item is not None:
+ # # find sel_item's new index in cur_clean
+ # try:
+ # pos_sel = cur_clean.index(sel_item)
+ # except ValueError:
+ # # Shouldn't happen, but fall back to end
+ # pos_sel = len(cur_clean) - 1
+ # insert_pos = pos_sel + 1
+ # else:
+ # insert_pos = len(cur_clean) # no selection -> append at end
+ insert_pos = min(sel, len(cur) -1)
+ cur_clean = cur
+ # Build final list and selection
+ merged = cur_clean[:insert_pos+1] + incoming + cur_clean[insert_pos+1:]
+ new_sel = insert_pos + len(incoming) # select the last inserted item
+
+ st["items"] = merged
+ st["selected"] = new_sel
+ return gr.update(value=merged, selected_index=new_sel), st
+
+ def _on_remove(self, state: Dict[str, Any], gallery) :
+ st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
+ if sel is None or not (0 <= sel < len(items)):
+ return gr.update(value=items, selected_index=st.get("selected")), st
+ items.pop(sel)
+ if not items:
+ st["items"] = []; st["selected"] = None
+ return gr.update(value=[], selected_index=None), st
+ new_sel = min(sel, len(items) - 1)
+ st["items"] = items; st["selected"] = new_sel
+ return gr.update(value=items, selected_index=new_sel), st
+
+ def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
+ st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
+ if sel is None or not (0 <= sel < len(items)):
+ return gr.update(value=items, selected_index=sel), st
+ j = sel + delta
+ if j < 0 or j >= len(items):
+ return gr.update(value=items, selected_index=sel), st
+ items[sel], items[j] = items[j], items[sel]
+ st["items"] = items; st["selected"] = j
+ return gr.update(value=items, selected_index=j), st
+
+ def _on_clear(self, state: Dict[str, Any]) :
+ st = {"items": [], "selected": None, "single": state.get("single", False), "mode": self.media_mode}
+ return gr.update(value=[], selected_index=None), st
+
+ def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) :
+ st = get_state(state); st["single"] = bool(to_single)
+ items: List[Any] = list(st["items"]); sel = st.get("selected", None)
+ if st["single"]:
+ keep = items[sel] if (sel is not None and 0 <= sel < len(items)) else (items[-1] if items else None)
+ items = [keep] if keep is not None else []
+ sel = 0 if items else None
+ st["items"] = items; st["selected"] = sel
+
+ upload_update = gr.update(file_count=("single" if st["single"] else "multiple"))
+ left_update = gr.update(visible=not st["single"])
+ right_update = gr.update(visible=not st["single"])
+ clear_update = gr.update(visible=not st["single"])
+ gallery_update= gr.update(value=items, selected_index=sel)
+
+ return upload_update, left_update, right_update, clear_update, gallery_update, st
+
+ # ---------------- build & wire ----------------
+
+ def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False):
+ if parent is not None:
+ with parent:
+ col = self._build_ui()
+ else:
+ col = self._build_ui()
+ if not update_form:
+ self._wire_events()
+ return col
+
+ def _build_ui(self) -> gr.Column:
+ with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col:
+ self.container = col
+
+ self.state = gr.State(dict(self._initial_state))
+
+ self.gallery = gr.Gallery(
+ label=self.label,
+ value=self._initial_state["items"],
+ height=self.height,
+ columns=self.columns,
+ show_label=self.show_label,
+ preview= True,
+ selected_index=self._initial_state["selected"], # server-side selection
+ )
+
+ # One-line controls
+ exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None
+ with gr.Row(equal_height=True, elem_classes=["amg-controls"]):
+ self.upload_btn = gr.UploadButton(
+ "Set" if self._initial_state["single"] else "Add",
+ file_types=exts,
+ file_count=("single" if self._initial_state["single"] else "multiple"),
+ variant="primary",
+ size="sm",
+ min_width=1,
+ )
+ self.btn_remove = gr.Button("Remove", size="sm", min_width=1)
+ self.btn_left = gr.Button("◀ Left", size="sm", visible=not self._initial_state["single"], min_width=1)
+ self.btn_right = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1)
+ self.btn_clear = gr.Button("Clear", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)
+
+ return col
+
+ def _wire_events(self):
+ # Selection: mirror into state and keep gallery.selected_index in sync
+ self.gallery.select(
+ self._on_select,
+ inputs=[self.state, self.gallery],
+ outputs=[self.gallery, self.state],
+ )
+
+ # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
+ self.gallery.change(
+ self._on_gallery_change,
+ inputs=[self.gallery, self.state],
+ outputs=[self.gallery, self.state],
+ )
+
+ # Add via UploadButton
+ self.upload_btn.upload(
+ self._on_add,
+ inputs=[self.upload_btn, self.state, self.gallery],
+ outputs=[self.gallery, self.state],
+ )
+
+ # Remove selected
+ self.btn_remove.click(
+ self._on_remove,
+ inputs=[self.state, self.gallery],
+ outputs=[self.gallery, self.state],
+ )
+
+ # Reorder using selected index, keep same item selected
+ self.btn_left.click(
+ lambda st, gallery: self._on_move(-1, st, gallery),
+ inputs=[self.state, self.gallery],
+ outputs=[self.gallery, self.state],
+ )
+ self.btn_right.click(
+ lambda st, gallery: self._on_move(+1, st, gallery),
+ inputs=[self.state, self.gallery],
+ outputs=[self.gallery, self.state],
+ )
+
+ # Clear all
+ self.btn_clear.click(
+ self._on_clear,
+ inputs=[self.state],
+ outputs=[self.gallery, self.state],
+ )
+
+ # ---------------- public API ----------------
+
+ def set_one_image_mode(self, enabled: bool = True):
+ """Toggle single-image mode at runtime."""
+ return (
+ self._on_toggle_single,
+ [gr.State(enabled), self.state],
+ [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state],
+ )
+
+ def get_toggable_elements(self):
+ return [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state]
+
+# import gradio as gr
+
+# with gr.Blocks() as demo:
+# amg = AdvancedMediaGallery(media_mode="image", height=190, columns=8)
+# amg.mount()
+# g = amg.gallery
+# # buttons to switch modes live (optional)
+# def process(g):
+# pass
+# with gr.Row():
+# gr.Button("toto").click(process, g)
+# gr.Button("ONE image").click(*amg.set_one_image_mode(True))
+# gr.Button("MULTI image").click(*amg.set_one_image_mode(False))
+
+# demo.launch()
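Stripped of the Gradio plumbing, the `_on_add` handler above reduces to: dedupe the incoming batch by absolute path, splice it in right after the currently selected item, and select the last inserted item. A plain-list sketch of that splice (hypothetical helper, paths only):

```python
import os

def insert_after_selected(current, incoming, selected):
    # Dedupe the new batch by path while keeping the file-picker order.
    seen, batch = set(), []
    for path in incoming:
        key = os.path.abspath(path)
        if key not in seen:
            seen.add(key)
            batch.append(path)
    # Insert right after the selected item and select the last inserted one.
    pos = min(selected, len(current) - 1) if current else -1
    merged = current[:pos + 1] + batch + current[pos + 1:]
    return merged, pos + len(batch)

# insert_after_selected(["a.png", "b.png"], ["c.png"], selected=0)
# -> (["a.png", "c.png", "b.png"], 1)
```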
diff --git a/shared/utils/qwen_vl_utils.py b/shared/utils/qwen_vl_utils.py
index 3c682e6ad..3fe02cc1a 100644
--- a/shared/utils/qwen_vl_utils.py
+++ b/shared/utils/qwen_vl_utils.py
@@ -9,7 +9,6 @@
import sys
import time
import warnings
-from functools import lru_cache
from io import BytesIO
import requests
@@ -257,7 +256,6 @@ def _read_video_decord(ele: dict,) -> torch.Tensor:
FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None)
-@lru_cache(maxsize=1)
def get_video_reader_backend() -> str:
if FORCE_QWENVL_VIDEO_READER is not None:
video_reader_backend = FORCE_QWENVL_VIDEO_READER
diff --git a/wgp.py b/wgp.py
index a151252ba..c757ad87f 100644
--- a/wgp.py
+++ b/wgp.py
@@ -49,6 +49,7 @@
from preprocessing.matanyone import app as matanyone_app
from tqdm import tqdm
import requests
+from shared.gradio.gallery import AdvancedMediaGallery
# import torch._dynamo as dynamo
# dynamo.config.recompile_limit = 2000 # default is 256
@@ -58,9 +59,9 @@
AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
-target_mmgp_version = "3.5.10"
-WanGP_version = "8.21"
-settings_version = 2.27
+target_mmgp_version = "3.5.11"
+WanGP_version = "8.3"
+settings_version = 2.28
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -466,7 +467,7 @@ def process_prompt_and_add_tasks(state, model_choice):
image_mask = None
if "G" in video_prompt_type:
- gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start a Step no {int(num_inference_steps * (1. - denoising_strength))} ")
+ gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(num_inference_steps * (1. - denoising_strength))} ")
else:
denoising_strength = 1.0
if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]:
@@ -4733,46 +4734,61 @@ def remove_temp_filenames(temp_filenames_list):
# special treatment for start frame position when alignment to the first frame is requested, as otherwise the start frame number would be negative due to overlapped frames (previously compensated for later with padding)
audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length)
+
if video_guide is not None:
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ]
- if infinitetalk and video_guide is not None:
- src_image = get_video_frame(video_guide, aligned_guide_start_frame-1, return_last_if_missing = True, return_PIL = True)
- new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size)
- src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
- refresh_preview["video_guide"] = src_image
- src_video = convert_image_to_tensor(src_image).unsqueeze(1)
- if sample_fit_canvas != None:
- image_size = src_video.shape[-2:]
- sample_fit_canvas = None
- if ltxv and video_guide is not None:
- preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
- status_info = "Extracting " + processes_names[preprocess_type]
- send_cmd("progress", [0, get_latest_status(state, status_info)])
- # start one frame ealier to facilitate latents merging later
- src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
- if src_video != None:
- src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ]
- refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
- src_video = src_video.permute(3, 0, 1, 2)
- src_video = src_video.float().div_(127.5).sub_(1.) # c, f, h, w
- if sample_fit_canvas != None:
- image_size = src_video.shape[-2:]
- sample_fit_canvas = None
- if t2v and "G" in video_prompt_type:
- video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps)
- if video_guide_processed == None:
- src_video = pre_video_guide
- else:
- if sample_fit_canvas != None:
- image_size = video_guide_processed.shape[-3: -1]
+ if ltxv:
+ preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
+ status_info = "Extracting " + processes_names[preprocess_type]
+ send_cmd("progress", [0, get_latest_status(state, status_info)])
+           # start one frame earlier to facilitate latents merging later
+ src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
+ if src_video != None:
+ src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ]
+ refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
+ src_video = src_video.permute(3, 0, 1, 2)
+ src_video = src_video.float().div_(127.5).sub_(1.) # c, f, h, w
+ if sample_fit_canvas != None:
+ image_size = src_video.shape[-2:]
+ sample_fit_canvas = None
+
+ elif hunyuan_custom_edit:
+ if "P" in video_prompt_type:
+ progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
+ else:
+ progress_args = [0, get_latest_status(state,"Extracting Video and Mask")]
+
+ send_cmd("progress", progress_args)
+ src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
+ refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
+ if src_mask != None:
+ refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
+
+ elif "R" in video_prompt_type: # sparse video to video
+ src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True)
+ new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size)
+ src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
+ refresh_preview["video_guide"] = src_image
+ src_video = convert_image_to_tensor(src_image).unsqueeze(1)
+ if sample_fit_canvas != None:
+ image_size = src_video.shape[-2:]
sample_fit_canvas = None
- src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
- if pre_video_guide != None:
- src_video = torch.cat( [pre_video_guide, src_video], dim=1)
+
+ elif "G" in video_prompt_type: # video to video
+ video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps)
+ if video_guide_processed is None:
+ src_video = pre_video_guide
+ else:
+ if sample_fit_canvas != None:
+ image_size = video_guide_processed.shape[-3: -1]
+ sample_fit_canvas = None
+ src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
+ if pre_video_guide != None:
+ src_video = torch.cat( [pre_video_guide, src_video], dim=1)
if vace :
image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
@@ -4834,17 +4850,8 @@ def remove_temp_filenames(temp_filenames_list):
if sample_fit_canvas != None:
image_size = src_video[0].shape[-2:]
sample_fit_canvas = None
- elif hunyuan_custom_edit:
- if "P" in video_prompt_type:
- progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
- else:
- progress_args = [0, get_latest_status(state,"Extracting Video and Mask")]
- send_cmd("progress", progress_args)
- src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
- refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
- if src_mask != None:
- refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
+
if len(refresh_preview) > 0:
new_inputs= locals()
new_inputs.update(refresh_preview)
@@ -6013,10 +6020,16 @@ def video_to_source_video(state, input_file_list, choice):
def image_to_ref_image_add(state, input_file_list, choice, target, target_name):
file_list, file_settings_list = get_file_list(state, input_file_list)
if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update()
- gr.Info(f"Selected Image was added to {target_name}")
- if target == None:
- target =[]
- target.append( file_list[choice])
+ model_type = state["model_type"]
+ model_def = get_model_def(model_type)
+ if model_def.get("one_image_ref_needed", False):
+ gr.Info(f"Selected Image was set to {target_name}")
+ target =[file_list[choice]]
+ else:
+ gr.Info(f"Selected Image was added to {target_name}")
+ if target == None:
+ target =[]
+ target.append( file_list[choice])
return target
def image_to_ref_image_set(state, input_file_list, choice, target, target_name):
@@ -6229,6 +6242,7 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw
if not "WanGP" in configs.get("type", ""): configs = None
except:
configs = None
+ if configs is None: return None, False
current_model_filename = state["model_filename"]
@@ -6615,11 +6629,12 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt
return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)
def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt):
- video_prompt_type = del_in_sequence(video_prompt_type, "UVQKI")
+ video_prompt_type = del_in_sequence(video_prompt_type, "RGUVQKI")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt)
control_video_visible = "V" in video_prompt_type
ref_images_visible = "I" in video_prompt_type
- return video_prompt_type, gr.update(visible = control_video_visible), gr.update(visible = ref_images_visible )
+ denoising_strength_visible = "G" in video_prompt_type
+ return video_prompt_type, gr.update(visible = control_video_visible), gr.update(visible = ref_images_visible ), gr.update(visible = denoising_strength_visible )
# def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
# video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
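Note: the hunks above rely on wgp.py's single-string flag convention — `video_prompt_type` is a string of one-letter toggles, and `del_in_sequence` / `add_to_sequence` / `filter_letters` edit that string. Below is a simplified sketch of the pattern with guessed implementations of the three helpers (the real ones live in wgp.py and may differ in detail).

```python
# Guessed stand-ins for the wgp.py helpers, shown only to illustrate the flag pattern.
def del_in_sequence(seq: str, letters: str) -> str:
    return "".join(c for c in seq if c not in letters)

def add_to_sequence(seq: str, letters: str) -> str:
    return seq + "".join(c for c in letters if c not in seq)

def filter_letters(seq: str, letters: str) -> str:
    return "".join(c for c in seq if c in letters)

video_prompt_type = "NVA"                                          # existing flags
video_prompt_type = del_in_sequence(video_prompt_type, "RGUVQKI")  # clear the alt-mode letters
video_prompt_type = add_to_sequence(video_prompt_type, "GUV")      # pick "Video to Video - Smooth"
print(video_prompt_type, "G" in video_prompt_type)                 # "G" present -> denoising strength visible
```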
@@ -6996,6 +7011,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
v2i_switch_supported = (vace or t2v or standin) and not image_outputs
ti2v_2_2 = base_model_type in ["ti2v_2_2"]
+ def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ):
+ with gr.Row(visible = visible) as gallery_row:
+ gallery_amg = AdvancedMediaGallery(media_mode="image", height=None, columns=4, label=label, initial = value , single_image_mode = single_image_mode )
+ gallery_amg.mount(update_form=update_form)
+ return gallery_row, gallery_amg.gallery, [gallery_row] + gallery_amg.get_toggable_elements()
+
image_mode_value = ui_defaults.get("image_mode", 1 if image_outputs else 0 )
if not v2i_switch_supported and not image_outputs:
image_mode_value = 0
@@ -7009,15 +7030,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
with gr.Tab("Text to Image", id = "t2i", elem_classes="compact_tab"):
pass
-
with gr.Column(visible= test_class_i2v(model_type) or hunyuan_i2v or diffusion_forcing or ltxv or recammaster or vace or ti2v_2_2) as image_prompt_column:
if vace or infinitetalk:
image_prompt_type_value= ui_defaults.get("image_prompt_type","")
image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value
image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= not image_outputs , scale= 3)
- image_start = gr.Gallery(visible = False)
- image_end = gr.Gallery(visible = False)
+ image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
+ image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
video_source = gr.Video(label= "Video Source", visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None))
model_mode = gr.Dropdown(visible = False)
keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VLG"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" )
@@ -7034,13 +7054,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
image_prompt_type_choices += [("Continue Video", "V")]
image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3)
- # image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
- image_start = gr.Gallery(preview= True,
- label="Images as starting points for new videos", type ="pil", #file_types= "image",
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value)
- image_end = gr.Gallery(preview= True,
- label="Images as ending points for new videos", type ="pil", #file_types= "image",
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
+ image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
+ image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value )
video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
if not diffusion_forcing:
model_mode = gr.Dropdown(
@@ -7061,8 +7076,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= "V" in image_prompt_type_value, scale = 2, label= "Truncate Video beyond this number of Frames of Video (empty=Keep All)" )
elif recammaster:
image_prompt_type = gr.Radio(choices=[("Source Video", "V")], value="V")
- image_start = gr.Gallery(value = None, visible = False)
- image_end = gr.Gallery(value = None, visible= False)
+ image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
+ image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
video_source = gr.Video(label= "Video Source", visible = True, value= ui_defaults.get("video_source", None),)
model_mode = gr.Dropdown(
choices=[
@@ -7095,21 +7110,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3)
any_start_image = True
any_end_image = True
- image_start = gr.Gallery(preview= True,
- label="Images as starting points for new videos", type ="pil", #file_types= "image",
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value)
-
- image_end = gr.Gallery(preview= True,
- label="Images as ending points for new videos", type ="pil", #file_types= "image",
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
+ image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
+ image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value )
if hunyuan_i2v:
video_source = gr.Video(value=None, visible=False)
else:
video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
else:
image_prompt_type = gr.Radio(choices=[("", "")], value="")
- image_start = gr.Gallery(value=None)
- image_end = gr.Gallery(value=None)
+ image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
+ image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
video_source = gr.Video(value=None, visible=False)
model_mode = gr.Dropdown(value=None, visible=False)
keep_frames_video_source = gr.Text(visible=False)
@@ -7184,12 +7194,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
if infinitetalk:
video_prompt_type_video_guide_alt = gr.Dropdown(
choices=[
- ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "UV"),
- ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QUV"),
- ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
+ ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
+ ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"),
+ ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"),
+ ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"),
+ ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"),
],
- value=filter_letters(video_prompt_type_value, "UVQKI"),
+ value=filter_letters(video_prompt_type_value, "RGUVQKI"),
label="Video to Video", scale = 3, visible= True, show_label= False,
)
else:
@@ -7318,11 +7330,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or infinitetalk or (flux or qwen) and model_reference_image
- image_refs = gr.Gallery(preview= True, label ="Start Image" if hunyuan_video_avatar else "Reference Images" + (" (each Image will start a new Clip)" if infinitetalk else ""),
- type ="pil", show_label= True,
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
- value= ui_defaults.get("image_refs", None),
- )
+
+ image_refs_single_image_mode = model_def.get("one_image_ref_needed", False)
+ image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "")
+ image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode)
frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" )
image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs)
@@ -7935,7 +7946,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row,
video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn,
NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs,
- min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] # presets_column,
+ min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] + image_start_extra + image_end_extra + image_refs_extra # presets_column,
if update_form:
locals_dict = locals()
gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
@@ -7953,11 +7964,11 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ])
audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type])
audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row])
- image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] )
+ image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start_row, image_end_row, video_source, keep_frames_video_source] )
# video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
- video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col])
+ video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col])
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
- video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs ])
+ video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs_row, denoising_strength ])
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type])
multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt])
@@ -8348,6 +8359,7 @@ def check(mode):
("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
+ ("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3"),
],
value= attention_mode,
label="Attention Type",
@@ -8663,7 +8675,7 @@ def generate_about_tab():
gr.Markdown("- Blackforest Labs for the innovative Flux image generators (https://github.com/black-forest-labs/flux)")
gr.Markdown("- Alibaba Qwen Team for their state of the art Qwen Image generators (https://github.com/QwenLM/Qwen-Image)")
gr.Markdown("- Lightricks for their super fast LTX Video models (https://github.com/Lightricks/LTX-Video)")
- gr.Markdown("- Hugging Face for the providing hosting for the models and developing must have open source libraries such as Tranformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)")
+ gr.Markdown("- Hugging Face for providing hosting for the models and developing must have open source libraries such as Tranformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)")
gr.Markdown(" Huge acknowledgments to these great open source projects used in WanGP:")
gr.Markdown("- Rife: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
gr.Markdown("- DwPose: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
@@ -8672,7 +8684,7 @@ def generate_about_tab():
gr.Markdown("- Pyannote: speaker diarization (https://github.com/pyannote/pyannote-audio)")
gr.Markdown(" Special thanks to the following people for their support:")
- gr.Markdown("- Cocktail Peanuts : QA and simple installation via Pinokio.computer")
+ gr.Markdown("- Cocktail Peanuts : QA dpand simple installation via Pinokio.computer")
gr.Markdown("- Tophness : created (former) multi tabs and queuing frameworks")
gr.Markdown("- AmericanPresidentJimmyCarter : added original support for Skip Layer Guidance")
gr.Markdown("- Remade_AI : for their awesome Loras collection")
From f2db023a3d314023736d2a5d5cd369b51ff4febd Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 2 Sep 2025 21:03:16 +0200
Subject: [PATCH 007/155] a release without a game-breaking bug is not a
release
---
shared/gradio/gallery.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/shared/gradio/gallery.py b/shared/gradio/gallery.py
index 6bd86a8d3..ef53d4d86 100644
--- a/shared/gradio/gallery.py
+++ b/shared/gradio/gallery.py
@@ -397,6 +397,7 @@ def _build_ui(self) -> gr.Column:
columns=self.columns,
show_label=self.show_label,
preview= True,
+ type="pil",
selected_index=self._initial_state["selected"], # server-side selection
)
From 959ab9e0c13a31cb8973a43c1d7fa96149901319 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 2 Sep 2025 23:13:02 +0200
Subject: [PATCH 008/155] no more pain
---
README.md | 3 ++-
shared/gradio/gallery.py | 20 +++++++++++++-------
wgp.py | 30 ++++++++++++++++++------------
3 files changed, 33 insertions(+), 20 deletions(-)
diff --git a/README.md b/README.md
index be76cbcc8..eef872837 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
-### September 2 2025: WanGP v8.3 - At last the pain stops
+### September 2 2025: WanGP v8.31 - At last the pain stops
- This single new feature should give you the strength to face all the potential bugs of this new release:
**Images Management (multiple additions or deletions, reordering) for Start Images / End Images / Images References.**
@@ -30,6 +30,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
- **Experimental Sage 3 Attention support**: you will need to deserve this one, first you need a Blackwell GPU (RTX50xx), then you will have to compile Sage 3, install it and cross your fingers that there isn't any crash.
+*update 8.31: one shouldnt talk about bugs if one dont want to attract bugs*
### August 29 2025: WanGP v8.21 - Here Goes Your Weekend
diff --git a/shared/gradio/gallery.py b/shared/gradio/gallery.py
index ef53d4d86..2f30ef16e 100644
--- a/shared/gradio/gallery.py
+++ b/shared/gradio/gallery.py
@@ -224,7 +224,9 @@ def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) :
n = len(get_list(gallery))
sel = idx if (idx is not None and 0 <= idx < n) else None
st["selected"] = sel
- return gr.update(selected_index=sel), st
+ # return gr.update(selected_index=sel), st
+ # return gr.update(), st
+ return st
def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
# Fires when users add/drag/drop/delete via the Gallery itself.
@@ -238,8 +240,10 @@ def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
else:
new_sel = old_sel
st["selected"] = new_sel
- return gr.update(value=items_filtered, selected_index=new_sel), st
+ # return gr.update(value=items_filtered, selected_index=new_sel), st
+ # return gr.update(value=items_filtered), st
+ return gr.update(), st
def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery):
"""
@@ -338,7 +342,8 @@ def _on_remove(self, state: Dict[str, Any], gallery) :
return gr.update(value=[], selected_index=None), st
new_sel = min(sel, len(items) - 1)
st["items"] = items; st["selected"] = new_sel
- return gr.update(value=items, selected_index=new_sel), st
+ # return gr.update(value=items, selected_index=new_sel), st
+ return gr.update(value=items), st
def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
@@ -352,8 +357,8 @@ def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
return gr.update(value=items, selected_index=j), st
def _on_clear(self, state: Dict[str, Any]) :
- st = {"items": [], "selected": None, "single": state.get("single", False), "mode": self.media_mode}
- return gr.update(value=[], selected_index=None), st
+ st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode}
+ return gr.update(value=[], selected_index=0), st
def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) :
st = get_state(state); st["single"] = bool(to_single)
@@ -397,7 +402,8 @@ def _build_ui(self) -> gr.Column:
columns=self.columns,
show_label=self.show_label,
preview= True,
- type="pil",
+ # type="pil",
+ file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS),
selected_index=self._initial_state["selected"], # server-side selection
)
@@ -424,7 +430,7 @@ def _wire_events(self):
self.gallery.select(
self._on_select,
inputs=[self.state, self.gallery],
- outputs=[self.gallery, self.state],
+ outputs=[self.state],
)
# Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
diff --git a/wgp.py b/wgp.py
index ec51a74ff..6527357b5 100644
--- a/wgp.py
+++ b/wgp.py
@@ -60,7 +60,7 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.5.11"
-WanGP_version = "8.3"
+WanGP_version = "8.31"
settings_version = 2.28
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -186,6 +186,17 @@ def compute_sliding_window_no(current_video_length, sliding_window_size, discard
return 1 + math.ceil(left_after_first_window / (sliding_window_size - discard_last_frames - reuse_frames))
+def clean_image_list(gradio_list):
+ if not isinstance(gradio_list, list): gradio_list = [gradio_list]
+ gradio_list = [ tup[0] if isinstance(tup, tuple) else tup for tup in gradio_list ]
+
+ if any( not isinstance(image, (Image.Image, str)) for image in gradio_list): return None
+ if any( isinstance(image, str) and not has_image_file_extension(image) for image in gradio_list): return None
+ gradio_list = [ convert_image( Image.open(img) if isinstance(img, str) else img ) for img in gradio_list ]
+ return gradio_list
+
+
+
def process_prompt_and_add_tasks(state, model_choice):
if state.get("validate_success",0) != 1:
@@ -436,11 +447,10 @@ def process_prompt_and_add_tasks(state, model_choice):
if image_refs == None or len(image_refs) == 0:
gr.Info("You must provide at least one Refererence Image")
return
- if any(isinstance(image[0], str) for image in image_refs) :
+ image_refs = clean_image_list(image_refs)
+ if image_refs == None :
gr.Info("A Reference Image should be an Image")
return
- if isinstance(image_refs, list):
- image_refs = [ convert_image(tup[0]) for tup in image_refs ]
else:
image_refs = None
@@ -497,12 +507,10 @@ def process_prompt_and_add_tasks(state, model_choice):
if image_start == None or isinstance(image_start, list) and len(image_start) == 0:
gr.Info("You must provide a Start Image")
return
- if not isinstance(image_start, list):
- image_start = [image_start]
- if not all( not isinstance(img[0], str) for img in image_start) :
+ image_start = clean_image_list(image_start)
+ if image_start == None :
gr.Info("Start Image should be an Image")
return
- image_start = [ convert_image(tup[0]) for tup in image_start ]
else:
image_start = None
@@ -510,15 +518,13 @@ def process_prompt_and_add_tasks(state, model_choice):
if image_end == None or isinstance(image_end, list) and len(image_end) == 0:
gr.Info("You must provide an End Image")
return
- if not isinstance(image_end, list):
- image_end = [image_end]
- if not all( not isinstance(img[0], str) for img in image_end) :
+ image_end = clean_image_list(image_end)
+ if image_end == None :
gr.Info("End Image should be an Image")
return
if len(image_start) != len(image_end):
gr.Info("The number of Start and End Images should be the same ")
return
- image_end = [ convert_image(tup[0]) for tup in image_end ]
else:
image_end = None
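Note: a standalone sketch of what the new `clean_image_list` helper does with a Gradio gallery value — unwrap `(image, caption)` tuples, reject anything that is neither a PIL image nor an image file path, and return converted PIL images (or `None` so the caller can show an error). `convert_image` and `has_image_file_extension` below are simplified stand-ins for the wgp.py versions.

```python
import os
from PIL import Image

IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"}

def has_image_file_extension(filename):
    return os.path.splitext(filename)[-1] in IMAGE_EXTS      # simplified stand-in

def convert_image(img):
    return img.convert("RGB")                                 # simplified stand-in

def clean_image_list(gradio_list):
    if not isinstance(gradio_list, list):
        gradio_list = [gradio_list]
    # Gallery items may arrive as (image, caption) tuples; keep only the image part.
    gradio_list = [tup[0] if isinstance(tup, tuple) else tup for tup in gradio_list]
    if any(not isinstance(image, (Image.Image, str)) for image in gradio_list):
        return None                                           # unexpected payload
    if any(isinstance(image, str) and not has_image_file_extension(image) for image in gradio_list):
        return None                                           # e.g. a video path in an image slot
    return [convert_image(Image.open(img) if isinstance(img, str) else img) for img in gradio_list]

print(clean_image_list([(Image.new("RGB", (64, 64)), "ref #1"), Image.new("RGB", (32, 32))]))  # two PIL images
print(clean_image_list(["clip.mp4"]))                                                          # None
```

This is why the three validation sites above collapse to `x = clean_image_list(x)` followed by a single `None` check.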
From 2ac31cc12a825ca3c6ea15116f2b79d7115914e8 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 2 Sep 2025 23:16:02 +0200
Subject: [PATCH 009/155] no more pain
---
README.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index eef872837..77f7fa895 100644
--- a/README.md
+++ b/README.md
@@ -27,10 +27,10 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
- Unofficial **Video to Video (Non Sparse this time) for InfinitTalk**. Use the Strength Noise slider to decide how much motion of the original window you want to keep. I have also *greatly reduced the VRAM requirements for Multitalk / Infinitalk* (especially the multispeakers version & when generating at 1080p).
-- **Experimental Sage 3 Attention support**: you will need to deserve this one, first you need a Blackwell GPU (RTX50xx), then you will have to compile Sage 3, install it and cross your fingers that there isn't any crash.
+- **Experimental Sage 3 Attention support**: you will need to deserve this one, first you need a Blackwell GPU (RTX50xx) and to request access to the Sage 3 GitHub repo, then you will have to compile Sage 3, install it and cross your fingers ...
-*update 8.31: one shouldnt talk about bugs if one dont want to attract bugs*
+*update 8.31: one shouldnt talk about bugs if one doesn't want to attract bugs*
### August 29 2025: WanGP v8.21 - Here Goes Your Weekend
From 0871a3be580f1453a7c294d87b1882547ac69083 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Wed, 3 Sep 2025 09:50:52 +0200
Subject: [PATCH 010/155] a few fixes
---
defaults/flux_dev_uso.json | 2 +-
models/wan/wan_handler.py | 4 ++--
requirements.txt | 2 +-
wgp.py | 39 +++++++++++++++++++++++---------------
4 files changed, 28 insertions(+), 19 deletions(-)
diff --git a/defaults/flux_dev_uso.json b/defaults/flux_dev_uso.json
index ab5ac5459..8b5dbb603 100644
--- a/defaults/flux_dev_uso.json
+++ b/defaults/flux_dev_uso.json
@@ -10,7 +10,7 @@
"reference_image": true,
"flux-model": "flux-dev-uso"
},
- "prompt": "add a hat",
+ "prompt": "the man is wearing a hat",
"embedded_guidance_scale": 4,
"resolution": "1024x1024",
"batch_size": 1
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index 4d59aa354..6d91fe243 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -118,8 +118,8 @@ def query_model_def(base_model_type, model_def):
# extra_model_def["at_least_one_image_ref_needed"] = True
- if base_model_type in ["phantom_1.3B", "phantom_14B"]:
- extra_model_def["one_image_ref_needed"] = True
+ # if base_model_type in ["phantom_1.3B", "phantom_14B"]:
+ # extra_model_def["one_image_ref_needed"] = True
return extra_model_def
diff --git a/requirements.txt b/requirements.txt
index 7704d6db9..cdc21ad53 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -44,7 +44,7 @@ pydantic==2.10.6
# Math & modeling
torchdiffeq>=0.2.5
tensordict>=0.6.1
-mmgp==3.5.11
+mmgp==3.5.12
peft==0.15.0
matplotlib
diff --git a/wgp.py b/wgp.py
index 6527357b5..f11afa94f 100644
--- a/wgp.py
+++ b/wgp.py
@@ -59,8 +59,8 @@
AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
-target_mmgp_version = "3.5.11"
-WanGP_version = "8.31"
+target_mmgp_version = "3.5.12"
+WanGP_version = "8.32"
settings_version = 2.28
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -1361,6 +1361,14 @@ def _parse_args():
help="save proprocessed audio track with extract speakers for debugging or editing"
)
+ parser.add_argument(
+ "--vram-safety-coefficient",
+ type=float,
+ default=0.8,
+ help="max VRAM (between 0 and 1) that should be allocated to preloaded models"
+ )
+
+
parser.add_argument(
"--share",
action="store_true",
@@ -2767,6 +2775,7 @@ def load_models(model_type, override_profile = -1):
transformer_dtype = torch.bfloat16 if "bf16" in model_filename or "BF16" in model_filename else transformer_dtype
transformer_dtype = torch.float16 if "fp16" in model_filename or"FP16" in model_filename else transformer_dtype
perc_reserved_mem_max = args.perc_reserved_mem_max
+ vram_safety_coefficient = args.vram_safety_coefficient
model_file_list = [model_filename]
model_type_list = [model_type]
module_type_list = [None]
@@ -2803,7 +2812,7 @@ def load_models(model_type, override_profile = -1):
loras_transformer = ["transformer"]
if "transformer2" in pipe:
loras_transformer += ["transformer2"]
- offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs)
+ offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , vram_safety_coefficient = vram_safety_coefficient , convertWeightsFloatTo = transformer_dtype, **kwargs)
if len(args.gpu) > 0:
torch.set_default_device(args.gpu)
transformer_type = model_type
@@ -4164,6 +4173,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri
download_models()
gen = get_gen_info(state)
+ original_process_status = None
while True:
with gen_lock:
process_status = gen.get("process_status", None)
@@ -6116,8 +6126,7 @@ def has_video_file_extension(filename):
def has_image_file_extension(filename):
extension = os.path.splitext(filename)[-1]
- return extension in [".jpeg", ".jpg", ".png", ".webp", ".bmp", ".tiff"]
-
+ return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]
def add_videos_to_gallery(state, input_file_list, choice, files_to_load):
gen = get_gen_info(state)
if files_to_load == None:
@@ -7016,10 +7025,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
any_reference_image = False
v2i_switch_supported = (vace or t2v or standin) and not image_outputs
ti2v_2_2 = base_model_type in ["ti2v_2_2"]
-
+ gallery_height = 350
def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ):
with gr.Row(visible = visible) as gallery_row:
- gallery_amg = AdvancedMediaGallery(media_mode="image", height=None, columns=4, label=label, initial = value , single_image_mode = single_image_mode )
+ gallery_amg = AdvancedMediaGallery(media_mode="image", height=gallery_height, columns=4, label=label, initial = value , single_image_mode = single_image_mode )
gallery_amg.mount(update_form=update_form)
return gallery_row, gallery_amg.gallery, [gallery_row] + gallery_amg.get_toggable_elements()
@@ -7044,7 +7053,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
- video_source = gr.Video(label= "Video Source", visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None))
+ video_source = gr.Video(label= "Video Source", height = gallery_height, visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None))
model_mode = gr.Dropdown(visible = False)
keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VLG"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" )
any_video_source = True
@@ -7062,7 +7071,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value )
- video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
+ video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
if not diffusion_forcing:
model_mode = gr.Dropdown(
choices=[
@@ -7084,7 +7093,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
image_prompt_type = gr.Radio(choices=[("Source Video", "V")], value="V")
image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
- video_source = gr.Video(label= "Video Source", visible = True, value= ui_defaults.get("video_source", None),)
+ video_source = gr.Video(label= "Video Source", height = gallery_height, visible = True, value= ui_defaults.get("video_source", None),)
model_mode = gr.Dropdown(
choices=[
("Pan Right", 1),
@@ -7121,7 +7130,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
if hunyuan_i2v:
video_source = gr.Video(value=None, visible=False)
else:
- video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
+ video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
else:
image_prompt_type = gr.Radio(choices=[("", "")], value="")
image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
@@ -7311,8 +7320,8 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
visible = False,
label="Start / Reference Images", scale = 2
)
- image_guide = gr.Image(label= "Control Image", type ="pil", visible= image_outputs and "V" in video_prompt_type_value, value= ui_defaults.get("image_guide", None))
- video_guide = gr.Video(label= "Control Video", visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None))
+ image_guide = gr.Image(label= "Control Image", height = gallery_height, type ="pil", visible= image_outputs and "V" in video_prompt_type_value, value= ui_defaults.get("image_guide", None))
+ video_guide = gr.Video(label= "Control Video", height = gallery_height, visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None))
denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label="Denoising Strength (the Lower the Closer to the Control Video)", visible = "G" in video_prompt_type_value, show_reset_button= False)
keep_frames_video_guide_visible = not image_outputs and "V" in video_prompt_type_value and not model_def.get("keep_frames_video_guide_not_supported", False)
@@ -7331,8 +7340,8 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
video_guide_outpainting_left = gr.Slider(0, 100, value= video_guide_outpainting_list[2], step=5, label="Left %", show_reset_button= False)
video_guide_outpainting_right = gr.Slider(0, 100, value= video_guide_outpainting_list[3], step=5, label="Right %", show_reset_button= False)
any_image_mask = image_outputs and vace
- image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_outputs and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("image_mask", None))
- video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None))
+ image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_outputs and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("image_mask", None))
+ video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("video_mask", None))
mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or infinitetalk or (flux or qwen) and model_reference_image
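Note: a small sketch of the new flag's mechanics — argparse turns `--vram-safety-coefficient` into `args.vram_safety_coefficient`, which `load_models` then forwards to `offload.profile`. The parse below is illustrative; how mmgp actually uses the coefficient is not shown here.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--vram-safety-coefficient",
    type=float,
    default=0.8,
    help="maximum fraction of VRAM (between 0 and 1) that may be allocated to preloaded models",
)

# e.g. launching with a more conservative budget than the 0.8 default
args = parser.parse_args(["--vram-safety-coefficient", "0.7"])
print(args.vram_safety_coefficient)   # 0.7, later passed as vram_safety_coefficient=... to offload.profile
```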
From a60eea23713c41dc89d7287383d5bf025334f494 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Wed, 3 Sep 2025 19:39:17 +0200
Subject: [PATCH 011/155] fixes and polish
---
models/flux/flux_handler.py | 21 +++++++++---
models/flux/flux_main.py | 67 ++++++++++++++++++++++++++++++++-----
models/flux/sampling.py | 1 -
models/qwen/qwen_handler.py | 1 +
models/wan/wan_handler.py | 2 ++
shared/utils/utils.py | 6 ++--
wgp.py | 35 ++++++++++++-------
7 files changed, 105 insertions(+), 28 deletions(-)
diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py
index 162ec4c69..b8b7b9b42 100644
--- a/models/flux/flux_handler.py
+++ b/models/flux/flux_handler.py
@@ -26,12 +26,13 @@ def query_model_def(base_model_type, model_def):
model_def_output["no_background_removal"] = True
model_def_output["image_ref_choices"] = {
- "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "I"),
- ("Up to two Images are Style Images", "IJ")],
- "default": "I",
- "letters_filter": "IJ",
+ "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"),
+ ("Up to two Images are Style Images", "KIJ")],
+ "default": "KI",
+ "letters_filter": "KIJ",
"label": "Reference Images / Style Images"
}
+ model_def_output["lock_image_refs_ratios"] = True
return model_def_output
@@ -107,6 +108,16 @@ def load_model(model_filename, model_type, base_model_type, model_def, quantizeT
pipe["feature_embedder"] = flux_model.feature_embedder
return flux_model, pipe
+ @staticmethod
+ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
+ flux_model = model_def.get("flux-model", "flux-dev")
+ flux_uso = flux_model == "flux-dev-uso"
+ if flux_uso and settings_version < 2.29:
+ video_prompt_type = ui_defaults.get("video_prompt_type", "")
+ if "I" in video_prompt_type:
+ video_prompt_type = video_prompt_type.replace("I", "KI")
+ ui_defaults["video_prompt_type"] = video_prompt_type
+
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
flux_model = model_def.get("flux-model", "flux-dev")
@@ -116,6 +127,6 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
})
if model_def.get("reference_image", False):
ui_defaults.update({
- "video_prompt_type": "I" if flux_uso else "KI",
+ "video_prompt_type": "KI",
})
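Note: a tiny sketch of the settings migration added in `fix_settings` above — presets saved before settings_version 2.29 used "I" for Flux USO reference images, while the newer code expects "KI". The function name below is illustrative; the real logic mutates `ui_defaults` in place.

```python
# Sketch of the Flux USO preset migration (hypothetical wrapper around the patch's logic).
def migrate_uso_prompt_type(video_prompt_type: str, settings_version: float) -> str:
    if settings_version < 2.29 and "I" in video_prompt_type:
        video_prompt_type = video_prompt_type.replace("I", "KI")
    return video_prompt_type

print(migrate_uso_prompt_type("IJ", 2.28))   # "KIJ"
print(migrate_uso_prompt_type("KI", 2.29))   # unchanged for current settings
```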
diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py
index 55a2b910a..9bb8e7379 100644
--- a/models/flux/flux_main.py
+++ b/models/flux/flux_main.py
@@ -9,6 +9,9 @@
from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack
from .modules.layers import get_linear_split_map
from transformers import SiglipVisionModel, SiglipImageProcessor
+import torchvision.transforms.functional as TVF
+import math
+from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image
from .util import (
aspect_ratio_to_height_width,
@@ -21,6 +24,44 @@
from PIL import Image
+def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1):
+ target_height_ref1 = int(target_height_ref1 // 64 * 64)
+ target_width_ref1 = int(target_width_ref1 // 64 * 64)
+ h, w = image.shape[-2:]
+ if h < target_height_ref1 or w < target_width_ref1:
+ # compute the aspect ratio
+ aspect_ratio = w / h
+ if h < target_height_ref1:
+ new_h = target_height_ref1
+ new_w = new_h * aspect_ratio
+ if new_w < target_width_ref1:
+ new_w = target_width_ref1
+ new_h = new_w / aspect_ratio
+ else:
+ new_w = target_width_ref1
+ new_h = new_w / aspect_ratio
+ if new_h < target_height_ref1:
+ new_h = target_height_ref1
+ new_w = new_h * aspect_ratio
+ else:
+ aspect_ratio = w / h
+ tgt_aspect_ratio = target_width_ref1 / target_height_ref1
+ if aspect_ratio > tgt_aspect_ratio:
+ new_h = target_height_ref1
+ new_w = new_h * aspect_ratio
+ else:
+ new_w = target_width_ref1
+ new_h = new_w / aspect_ratio
+ # resize the image with TVF.resize
+ image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w)))
+ # compute the center-crop offsets
+ top = (image.shape[-2] - target_height_ref1) // 2
+ left = (image.shape[-1] - target_width_ref1) // 2
+ # center-crop with TVF.crop
+ image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1)
+ return image
+
+
def stitch_images(img1, img2):
# Resize img2 to match img1's height
width1, height1 = img1.size
@@ -129,11 +170,11 @@ def generate(
if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
device="cuda"
flux_dev_uso = self.name in ['flux-dev-uso']
- image_stiching = not self.name in ['flux-dev-uso']
-
+ image_stiching = not self.name in ['flux-dev-uso'] #and False
+ # image_refs_relative_size = 100
+ crop = False
input_ref_images = [] if input_ref_images is None else input_ref_images[:]
ref_style_imgs = []
-
if "I" in video_prompt_type and len(input_ref_images) > 0:
if flux_dev_uso :
if "J" in video_prompt_type:
@@ -148,7 +189,7 @@ def generate(
if "K" in video_prompt_type :
w, h = input_ref_images[0].size
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
-
+ # actual rescale will happen in prepare_kontext
for new_img in input_ref_images[1:]:
stiched = stitch_images(stiched, new_img)
input_ref_images = [stiched]
@@ -157,14 +198,24 @@ def generate(
if "K" in video_prompt_type:
# image latents tiling method
w, h = input_ref_images[0].size
- height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
- input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS)
+ if crop :
+ img = convert_image_to_tensor(input_ref_images[0])
+ img = resize_and_centercrop_image(img, height, width)
+ input_ref_images[0] = convert_tensor_to_image(img)
+ else:
+ height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
+ input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS)
first_ref = 1
for i in range(first_ref,len(input_ref_images)):
w, h = input_ref_images[i].size
- image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
- input_ref_images[0] = input_ref_images[0].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
+ if crop:
+ img = convert_image_to_tensor(input_ref_images[i])
+ img = resize_and_centercrop_image(img, int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100))
+ input_ref_images[i] = convert_tensor_to_image(img)
+ else:
+ image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
+ input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
else:
input_ref_images = None
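Note: a simplified, self-contained sketch of the crop path added above. It collapses both branches of `resize_and_centercrop_image` into a single scale-to-cover step followed by a center crop; the real helper keeps the two cases separate but produces the same kind of result.

```python
import math
import torch
import torchvision.transforms.functional as TVF

def center_fit(image, target_h, target_w):
    # Snap the target to the 64-pixel grid, as the original helper does.
    target_h, target_w = (target_h // 64) * 64, (target_w // 64) * 64
    h, w = image.shape[-2:]
    # Scale so the image covers the target box, then center-crop the overflow.
    scale = max(target_h / h, target_w / w)
    image = TVF.resize(image, [math.ceil(h * scale), math.ceil(w * scale)])
    top = (image.shape[-2] - target_h) // 2
    left = (image.shape[-1] - target_w) // 2
    return TVF.crop(image, top, left, target_h, target_w)

ref = torch.rand(3, 720, 1280)           # C, H, W reference image tensor in [0, 1]
print(center_fit(ref, 512, 512).shape)   # torch.Size([3, 512, 512])
```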
diff --git a/models/flux/sampling.py b/models/flux/sampling.py
index 5534e9f6b..f43ae15e0 100644
--- a/models/flux/sampling.py
+++ b/models/flux/sampling.py
@@ -153,7 +153,6 @@ def prepare_kontext(
# Kontext is trained on specific resolutions, using one of them is recommended
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
-
width = 2 * int(width / 16)
height = 2 * int(height / 16)
diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py
index c6004e164..6fc488aec 100644
--- a/models/qwen/qwen_handler.py
+++ b/models/qwen/qwen_handler.py
@@ -15,6 +15,7 @@ def query_model_def(base_model_type, model_def):
("Default", "default"),
("Lightning", "lightning")],
"guidance_max_phases" : 1,
+ "lock_image_refs_ratios": True,
}
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index 6d91fe243..9adc3a884 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -117,6 +117,8 @@ def query_model_def(base_model_type, model_def):
extra_model_def["no_background_removal"] = True
# extra_model_def["at_least_one_image_ref_needed"] = True
+ if base_model_type in ["standin"] or vace_class:
+ extra_model_def["lock_image_refs_ratios"] = True
# if base_model_type in ["phantom_1.3B", "phantom_14B"]:
# extra_model_def["one_image_ref_needed"] = True
diff --git a/shared/utils/utils.py b/shared/utils/utils.py
index a55807a17..7ddf1eb6e 100644
--- a/shared/utils/utils.py
+++ b/shared/utils/utils.py
@@ -18,11 +18,11 @@
import tempfile
import subprocess
import json
-
+from functools import lru_cache
from PIL import Image
-
+video_info_cache = []
def seed_everything(seed: int):
random.seed(seed)
np.random.seed(seed)
@@ -77,7 +77,9 @@ def truncate_for_filesystem(s, max_bytes=255):
else: r = m - 1
return s[:l]
+@lru_cache(maxsize=100)
def get_video_info(video_path):
+ global video_info_cache
import cv2
cap = cv2.VideoCapture(video_path)
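Note: a sketch of the caching pattern introduced above — `lru_cache` memoises `get_video_info` by path so repeated UI refreshes don't re-open the file with cv2 every time. The body below is a dummy stand-in for the real probe.

```python
from functools import lru_cache

@lru_cache(maxsize=100)
def get_video_info(video_path: str):
    print(f"probing {video_path}")   # stand-in for the cv2.VideoCapture probe
    return 16, 81                    # dummy (fps, frame count)

get_video_info("clip.mp4")   # probes the file
get_video_info("clip.mp4")   # served from the cache, no second probe
```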
diff --git a/wgp.py b/wgp.py
index f11afa94f..a3cc51010 100644
--- a/wgp.py
+++ b/wgp.py
@@ -60,8 +60,8 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.5.12"
-WanGP_version = "8.32"
-settings_version = 2.28
+WanGP_version = "8.33"
+settings_version = 2.29
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -3313,6 +3313,7 @@ def check(src, cond):
if not all_letters(src, pos): return False
if neg is not None and any_letters(src, neg): return False
return True
+ image_outputs = configs.get("image_mode",0) == 1
map_video_prompt = {"V" : "Control Video", ("VA", "U") : "Mask Video", "I" : "Reference Images"}
map_image_prompt = {"V" : "Source Video", "L" : "Last Video", "S" : "Start Image", "E" : "End Image"}
map_audio_prompt = {"A" : "Audio Source", "B" : "Audio Source #2"}
@@ -3364,6 +3365,7 @@ def check(src, cond):
if multiple_submodels:
video_guidance_scale += f" + Model Switch at {video_switch_threshold if video_model_switch_phase ==1 else video_switch_threshold2}"
video_flow_shift = configs.get("flow_shift", None)
+ if image_outputs: video_flow_shift = None
video_video_guide_outpainting = configs.get("video_guide_outpainting", "")
video_outpainting = ""
if len(video_video_guide_outpainting) > 0 and not video_video_guide_outpainting.startswith("#") \
@@ -4545,7 +4547,8 @@ def remove_temp_filenames(temp_filenames_list):
send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
from shared.utils.utils import resize_and_remove_background
- image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (any_background_ref or vace or standin) ) # no fit for vace ref images as it is done later
+ # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested
+ image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (any_background_ref or model_def.get("lock_image_refs_ratios", False)) ) # no fit for vace ref images as it is done later
update_task_thumbnails(task, locals())
send_cmd("output")
joint_pass = boost ==1 #and profile != 1 and profile != 3
@@ -5912,6 +5915,8 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
if target == "settings":
return inputs
+ image_outputs = inputs.get("image_mode",0) == 1
+
pop=[]
if "force_fps" in inputs and len(inputs["force_fps"])== 0:
pop += ["force_fps"]
@@ -5977,7 +5982,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
if guidance_max_phases < 3 or guidance_phases < 3:
pop += ["guidance3_scale", "switch_threshold2", "model_switch_phase"]
- if ltxv:
+ if ltxv or image_outputs:
pop += ["flow_shift"]
if model_def.get("no_negative_prompt", False) :
@@ -6876,11 +6881,15 @@ def detect_auto_save_form(state, evt:gr.SelectData):
return gr.update()
def compute_video_length_label(fps, current_video_length):
- return f"Number of frames ({fps} frames = 1s), current duration: {(current_video_length / fps):.1f}s",
+ if fps is None:
+ return f"Number of frames"
+ else:
+ return f"Number of frames ({fps} frames = 1s), current duration: {(current_video_length / fps):.1f}s",
-def refresh_video_length_label(state, current_video_length):
- fps = get_model_fps(get_base_model_type(state["model_type"]))
- return gr.update(label= compute_video_length_label(fps, current_video_length))
+def refresh_video_length_label(state, current_video_length, force_fps, video_guide, video_source):
+ base_model_type = get_base_model_type(state["model_type"])
+ computed_fps = get_computed_fps(force_fps, base_model_type , video_guide, video_source )
+ return gr.update(label= compute_video_length_label(computed_fps, current_video_length))
def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None):
global inputs_names #, advanced
@@ -7469,8 +7478,9 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97)
+ computed_fps = get_computed_fps(ui_defaults.get("force_fps",""), base_model_type , video_guide, video_source )
video_length = gr.Slider(min_frames, get_max_frames(737 if test_any_sliding_window(base_model_type) else 337), value=current_video_length,
- step=frames_step, label=compute_video_length_label(fps, current_video_length) , visible = True, interactive= True)
+ step=frames_step, label=compute_video_length_label(computed_fps, current_video_length) , visible = True, interactive= True)
with gr.Row(visible = not lock_inference_steps) as inference_steps_row:
num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps", visible = True)
@@ -7643,7 +7653,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
with gr.Column(visible = (t2v or vace) and not fantasy) as audio_prompt_type_remux_row:
- gr.Markdown("You may transfer the exising audio tracks of a Control Video")
+ gr.Markdown("You may transfer the existing audio tracks of a Control Video")
audio_prompt_type_remux = gr.Dropdown(
choices=[
("No Remux", ""),
@@ -7955,7 +7965,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, save_lset_prompt_drop, advanced_row, speed_tab, audio_tab, mmaudio_col, quality_tab,
sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, audio_guide_row, RIFLEx_setting_col,
- video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row,
+ video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux, audio_prompt_type_remux_row,
video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right,
video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row,
video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row,
@@ -7975,7 +7985,8 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
resolution.change(fn=record_last_resolution, inputs=[state, resolution])
- video_length.release(fn=refresh_video_length_label, inputs=[state, video_length ], outputs = video_length, trigger_mode="always_last" )
+ # video_length.release(fn=refresh_video_length_label, inputs=[state, video_length ], outputs = video_length, trigger_mode="always_last" )
+ gr.on(triggers=[video_length.release, force_fps.change, video_guide.change, video_source.change], fn=refresh_video_length_label, inputs=[state, video_length, force_fps, video_guide, video_source] , outputs = video_length, trigger_mode="always_last" )
guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ])
audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type])
audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row])
From e5abb1b9bc27cfe7f05562651aa23f14bba3d5e9 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 4 Sep 2025 01:10:42 +0200
Subject: [PATCH 012/155] fixed pytorch compilation
---
models/wan/any2video.py | 7 ++++---
requirements.txt | 2 +-
wgp.py | 6 +++---
3 files changed, 8 insertions(+), 7 deletions(-)
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index 8eb2ffef2..e7d54ef29 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -454,7 +454,8 @@ def generate(self,
timesteps.append(0.)
timesteps = [torch.tensor([t], device=self.device) for t in timesteps]
if self.use_timestep_transform:
- timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1]
+ timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1]
+ timesteps = torch.tensor(timesteps)
sample_scheduler = None
elif sample_solver == 'causvid':
sample_scheduler = FlowMatchScheduler(num_inference_steps=sampling_steps, shift=shift, sigma_min=0, extra_one_step=True)
@@ -1016,8 +1017,8 @@ def clear():
if sample_solver == "euler":
dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1])
- dt = dt / self.num_timesteps
- latents = latents - noise_pred * dt[:, None, None, None, None]
+ dt = dt.item() / self.num_timesteps
+ latents = latents - noise_pred * dt
else:
latents = sample_scheduler.step(
noise_pred[:, :, :target_shape[1]],
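For context, the euler branch now divides a plain Python float (via `.item()`) instead of broadcasting a 5-D tensor, which appears to be what the compilation fix targets. A toy, self-contained sketch of the resulting update (the schedule values and tensor shape below are made up for illustration):

```python
import torch

# Illustrative scalar-dt Euler step, mirroring the patched branch above.
timesteps = [torch.tensor([t]) for t in (1000.0, 750.0, 500.0, 250.0)]  # toy schedule
num_timesteps = 1000
latents = torch.randn(1, 16, 4, 8, 8)       # toy latent video tensor
noise_pred = torch.randn_like(latents)      # stand-in for the model's velocity prediction

i = 1
dt = timesteps[i] if i == len(timesteps) - 1 else (timesteps[i] - timesteps[i + 1])
dt = dt.item() / num_timesteps              # Python float instead of dt[:, None, None, None, None]
latents = latents - noise_pred * dt         # plain scalar multiply
```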
diff --git a/requirements.txt b/requirements.txt
index cdc21ad53..0ae0ab32f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -44,7 +44,7 @@ pydantic==2.10.6
# Math & modeling
torchdiffeq>=0.2.5
tensordict>=0.6.1
-mmgp==3.5.12
+mmgp==3.6.0
peft==0.15.0
matplotlib
diff --git a/wgp.py b/wgp.py
index a3cc51010..c57cc48cd 100644
--- a/wgp.py
+++ b/wgp.py
@@ -59,8 +59,8 @@
AUTOSAVE_FILENAME = "queue.zip"
PROMPT_VARS_MAX = 10
-target_mmgp_version = "3.5.12"
-WanGP_version = "8.33"
+target_mmgp_version = "3.6.0"
+WanGP_version = "8.34"
settings_version = 2.29
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -8509,7 +8509,7 @@ def check(mode):
("Off", "" ),
],
value= compile,
- label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)",
+                    label="Compile Transformer: up to 10-20% faster, useful only if you run multiple generations with the same number of frames / resolution",
interactive= not lock_ui_compile
)
From 56a51b79f00a2cd370b179412618c3167fae8e10 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 4 Sep 2025 01:25:19 +0200
Subject: [PATCH 013/155] add multitalk support for audio in mp4
---
models/wan/multitalk/multitalk.py | 25 ++++++++++++++++++++++++-
1 file changed, 24 insertions(+), 1 deletion(-)
diff --git a/models/wan/multitalk/multitalk.py b/models/wan/multitalk/multitalk.py
index fbf9175ae..b04f65f32 100644
--- a/models/wan/multitalk/multitalk.py
+++ b/models/wan/multitalk/multitalk.py
@@ -59,7 +59,30 @@ def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=160
audio_emb = audio_emb.cpu().detach()
return audio_emb
-
+
+def extract_audio_from_video(filename, sample_rate):
+ raw_audio_path = filename.split('/')[-1].split('.')[0]+'.wav'
+ ffmpeg_command = [
+ "ffmpeg",
+ "-y",
+ "-i",
+ str(filename),
+ "-vn",
+ "-acodec",
+ "pcm_s16le",
+ "-ar",
+ "16000",
+ "-ac",
+ "2",
+ str(raw_audio_path),
+ ]
+ subprocess.run(ffmpeg_command, check=True)
+ human_speech_array, sr = librosa.load(raw_audio_path, sr=sample_rate)
+ human_speech_array = loudness_norm(human_speech_array, sr)
+ os.remove(raw_audio_path)
+
+ return human_speech_array
+
def audio_prepare_single(audio_path, sample_rate=16000, duration = 0):
ext = os.path.splitext(audio_path)[1].lower()
if ext in ['.mp4', '.mov', '.avi', '.mkv']:
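The helper shells out to ffmpeg to dump the audio track as a 16-bit PCM stereo WAV (note that the ffmpeg arguments hardcode "16000" regardless of the `sample_rate` parameter), reloads it with librosa at the requested rate, loudness-normalizes it, and deletes the temporary file. A rough standalone equivalent, with the project-internal `loudness_norm` step left out:

```python
import os
import subprocess

import librosa

def extract_audio(video_path: str, sample_rate: int = 16000):
    """Illustrative equivalent of extract_audio_from_video (without loudness normalization)."""
    wav_path = os.path.splitext(os.path.basename(video_path))[0] + ".wav"
    subprocess.run(
        ["ffmpeg", "-y", "-i", video_path,
         "-vn",                                   # drop the video stream
         "-acodec", "pcm_s16le",                  # 16-bit PCM
         "-ar", "16000", "-ac", "2",              # 16 kHz, stereo (as hardcoded in the patch)
         wav_path],
        check=True,
    )
    audio, sr = librosa.load(wav_path, sr=sample_rate)  # librosa resamples to the requested rate
    os.remove(wav_path)
    return audio
```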
From 13b001d4eaac075973bad596b119dfdce47d550c Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 4 Sep 2025 11:42:09 +0200
Subject: [PATCH 014/155] fixed forms errors
---
wgp.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/wgp.py b/wgp.py
index c57cc48cd..653683c15 100644
--- a/wgp.py
+++ b/wgp.py
@@ -7478,7 +7478,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97)
- computed_fps = get_computed_fps(ui_defaults.get("force_fps",""), base_model_type , video_guide, video_source )
+ computed_fps = get_computed_fps(ui_defaults.get("force_fps",""), base_model_type , ui_defaults.get("video_guide", None), ui_defaults.get("video_source", None))
video_length = gr.Slider(min_frames, get_max_frames(737 if test_any_sliding_window(base_model_type) else 337), value=current_video_length,
step=frames_step, label=compute_video_length_label(computed_fps, current_video_length) , visible = True, interactive= True)
From b37c1a26a8395d090ecf0c5e07d59db3d7b05cf4 Mon Sep 17 00:00:00 2001
From: deepbeepmeep <84379123+deepbeepmeep@users.noreply.github.com>
Date: Fri, 5 Sep 2025 09:12:01 +0200
Subject: [PATCH 015/155] removed infinitetalk forced video to video
---
models/wan/any2video.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index e7d54ef29..d394449dc 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -537,7 +537,6 @@ def generate(self,
image_ref = input_frames[:, 0]
if input_video is None: input_video = input_frames[:, 0:1]
new_shot = "Q" in video_prompt_type
- denoising_strength = 0.5
else:
if pre_video_frame is None:
new_shot = True
From 8859e816d0a9969b4a1207f3eaf7333597cf1b26 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Fri, 5 Sep 2025 14:17:21 +0200
Subject: [PATCH 016/155] Take me to outer space
---
README.md | 17 ++
models/hyvideo/hunyuan_handler.py | 4 +
models/ltx_video/ltxv_handler.py | 1 +
models/wan/any2video.py | 117 ++++++--------
models/wan/df_handler.py | 12 ++
models/wan/wan_handler.py | 54 +++++++
wgp.py | 256 ++++++++++++++----------------
7 files changed, 256 insertions(+), 205 deletions(-)
diff --git a/README.md b/README.md
index 77f7fa895..eb497f1a9 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,23 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
+### September 5 2025: WanGP v8.4 - Take me to Outer Space
+You have probably seen those short AI-generated movies created with *Nano Banana* and the *First Frame - Last Frame* feature of *Kling 2.0*. The idea is to generate an image, modify a part of it with Nano Banana, and give these two images to Kling, which generates the video between them; then use the previous Last Frame as the new First Frame, rinse and repeat, and you get a full movie.
+
+I have made it easier to do just that with *Qwen Edit* and *Wan*:
+- **End Frames can now be combined with Continue a Video** (and not just with a Start Frame)
+- **Multiple End Frames can be provided**; each End Frame will be used for a different Sliding Window
+
+You can plan all your shots in advance (one shot = one Sliding Window): I recommend using Wan 2.2 Image to Image with multiple End Frames (one for each shot / Sliding Window) and a different Text Prompt for each shot / Sliding Window (remember to enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*).
+
+The results can be quite impressive. However, Wan 2.1 & 2.2 Image to Image are restricted to a single overlap frame when using Sliding Windows, which means only one frame is reused for the motion. This may be insufficient if you are trying to connect two shots with fast movement.
+
+This is where *InfiniteTalk* comes into play. Besides being one of the best models for generating animated, audio-driven avatars, InfiniteTalk internally uses more than one motion frame, so it is quite good at maintaining motion between two shots. I have tweaked InfiniteTalk so that **its motion engine can be used even if no audio is provided**.
+So here is how to use InfiniteTalk: enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*, and if you are continuing an existing Video, set *Misc/Override Frames per Second* to *Source Video*. Each Reference Frame provided plays the same role as an End Frame, except it won't be exactly an End Frame (it corresponds more to a middle frame; the actual End Frame will differ but will be close).
+
+
+You will find below a 33s movie I created using these two methods. The quality could be much better as I haven't tuned the settings at all (I couldn't be bothered; I used 10-step generation without Loras Accelerators for most of the gens).
+
### September 2 2025: WanGP v8.31 - At last the pain stops
- This single new feature should give you the strength to face all the potential bugs of this new release:
diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py
index da60f72cb..9cbaea7d1 100644
--- a/models/hyvideo/hunyuan_handler.py
+++ b/models/hyvideo/hunyuan_handler.py
@@ -56,6 +56,10 @@ def query_model_def(base_model_type, model_def):
if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]:
extra_model_def["one_image_ref_needed"] = True
+
+ if base_model_type in ["hunyuan_i2v"]:
+ extra_model_def["image_prompt_types_allowed"] = "S"
+
return extra_model_def
@staticmethod
diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py
index d35bcd457..c89b69a43 100644
--- a/models/ltx_video/ltxv_handler.py
+++ b/models/ltx_video/ltxv_handler.py
@@ -24,6 +24,7 @@ def query_model_def(base_model_type, model_def):
extra_model_def["frames_minimum"] = 17
extra_model_def["frames_steps"] = 8
extra_model_def["sliding_window"] = True
+ extra_model_def["image_prompt_types_allowed"] = "TSEV"
return extra_model_def
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index e7d54ef29..bb91dc66b 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -537,7 +537,6 @@ def generate(self,
image_ref = input_frames[:, 0]
if input_video is None: input_video = input_frames[:, 0:1]
new_shot = "Q" in video_prompt_type
- denoising_strength = 0.5
else:
if pre_video_frame is None:
new_shot = True
@@ -556,74 +555,59 @@ def generate(self,
_ , preframes_count, height, width = input_video.shape
input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype)
if infinitetalk:
- image_for_clip = image_ref.to(input_video)
+ image_start = image_ref.to(input_video)
control_pre_frames_count = 1
- control_video = image_for_clip.unsqueeze(1)
+ control_video = image_start.unsqueeze(1)
else:
- image_for_clip = input_video[:, -1]
+ image_start = input_video[:, -1]
control_pre_frames_count = preframes_count
control_video = input_video
- lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
- if hasattr(self, "clip"):
- clip_image_size = self.clip.model.image_size
- clip_image = resize_lanczos(image_for_clip, clip_image_size, clip_image_size)[:, None, :, :]
- clip_context = self.clip.visual([clip_image]) if model_type != "flf2v_720p" else self.clip.visual([clip_image , clip_image ])
- clip_image = None
- else:
- clip_context = None
- enc = torch.concat( [control_video, torch.zeros( (3, frame_num-control_pre_frames_count, height, width),
- device=self.device, dtype= self.VAE_dtype)],
- dim = 1).to(self.device)
- color_reference_frame = image_for_clip.unsqueeze(1).clone()
+
+ color_reference_frame = image_start.unsqueeze(1).clone()
else:
- preframes_count = control_pre_frames_count = 1
- any_end_frame = image_end is not None
- add_frames_for_end_image = any_end_frame and model_type == "i2v"
- if any_end_frame:
- if add_frames_for_end_image:
- frame_num +=1
- lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
- trim_frames = 1
-
+ preframes_count = control_pre_frames_count = 1
height, width = image_start.shape[1:]
+ control_video = image_start.unsqueeze(1).to(self.device)
+ color_reference_frame = control_video.clone()
- lat_h = round(
- height // self.vae_stride[1] //
- self.patch_size[1] * self.patch_size[1])
- lat_w = round(
- width // self.vae_stride[2] //
- self.patch_size[2] * self.patch_size[2])
- height = lat_h * self.vae_stride[1]
- width = lat_w * self.vae_stride[2]
- image_start_frame = image_start.unsqueeze(1).to(self.device)
- color_reference_frame = image_start_frame.clone()
- if image_end is not None:
- img_end_frame = image_end.unsqueeze(1).to(self.device)
-
- if hasattr(self, "clip"):
- clip_image_size = self.clip.model.image_size
- image_start = resize_lanczos(image_start, clip_image_size, clip_image_size)
- if image_end is not None: image_end = resize_lanczos(image_end, clip_image_size, clip_image_size)
- if model_type == "flf2v_720p":
- clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]])
- else:
- clip_context = self.clip.visual([image_start[:, None, :, :]])
- else:
- clip_context = None
-
- if any_end_frame:
- enc= torch.concat([
- image_start_frame,
- torch.zeros( (3, frame_num-2, height, width), device=self.device, dtype= self.VAE_dtype),
- img_end_frame,
- ], dim=1).to(self.device)
+ any_end_frame = image_end is not None
+ add_frames_for_end_image = any_end_frame and model_type == "i2v"
+ if any_end_frame:
+                color_correction_strength = 0  # disable color correction, as transition frames between shots may have a completely different color level than the new shot
+ if add_frames_for_end_image:
+ frame_num +=1
+ lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2)
+ trim_frames = 1
+
+ lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
+
+ if image_end is not None:
+ img_end_frame = image_end.unsqueeze(1).to(self.device)
+
+ if hasattr(self, "clip"):
+ clip_image_size = self.clip.model.image_size
+ image_start = resize_lanczos(image_start, clip_image_size, clip_image_size)
+ image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) if image_end is not None else image_start
+ if model_type == "flf2v_720p":
+ clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]])
else:
- enc= torch.concat([
- image_start_frame,
- torch.zeros( (3, frame_num-1, height, width), device=self.device, dtype= self.VAE_dtype)
- ], dim=1).to(self.device)
+ clip_context = self.clip.visual([image_start[:, None, :, :]])
+ else:
+ clip_context = None
+
+ if any_end_frame:
+ enc= torch.concat([
+ control_video,
+ torch.zeros( (3, frame_num-control_pre_frames_count-1, height, width), device=self.device, dtype= self.VAE_dtype),
+ img_end_frame,
+ ], dim=1).to(self.device)
+ else:
+ enc= torch.concat([
+ control_video,
+ torch.zeros( (3, frame_num-control_pre_frames_count, height, width), device=self.device, dtype= self.VAE_dtype)
+ ], dim=1).to(self.device)
- image_start = image_end = image_start_frame = img_end_frame = image_for_clip = image_ref = None
+ image_start = image_end = img_end_frame = image_ref = control_video = None
msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device)
if any_end_frame:
@@ -657,12 +641,11 @@ def generate(self,
# Recam Master
if recam:
- # should be be in fact in input_frames since it is control video not a video to be extended
target_camera = model_mode
- height,width = input_video.shape[-2:]
- input_video = input_video.to(dtype=self.dtype , device=self.device)
- source_latents = self.vae.encode([input_video])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
- del input_video
+ height,width = input_frames.shape[-2:]
+ input_frames = input_frames.to(dtype=self.dtype , device=self.device)
+ source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
+ del input_frames
# Process target camera (recammaster)
from shared.utils.cammmaster_tools import get_camera_embedding
cam_emb = get_camera_embedding(target_camera)
@@ -754,7 +737,9 @@ def generate(self,
else:
target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2])
- if multitalk and audio_proj != None:
+ if multitalk:
+ if audio_proj is None:
+ audio_proj = [ torch.zeros( (1, 1, 5, 12, 768 ), dtype=self.dtype, device=self.device), torch.zeros( (1, (frame_num - 1) // 4, 8, 12, 768 ), dtype=self.dtype, device=self.device) ]
from .multitalk.multitalk import get_target_masks
audio_proj = [audio.to(self.dtype) for audio in audio_proj]
human_no = len(audio_proj[0])
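This is the hook behind the "motion engine without audio" behaviour announced later in this series: when no audio is supplied, Multitalk now receives zero-valued audio embeddings instead of bypassing the audio path. A minimal sketch of the stand-in tensors (dtype, device and frame count below are placeholders):

```python
import torch

frame_num = 81                                   # placeholder frame count
dtype, device = torch.float16, "cpu"             # the model substitutes its own dtype/device

# Silent stand-in embeddings; shapes copied from the patch above.
audio_proj = [
    torch.zeros((1, 1, 5, 12, 768), dtype=dtype, device=device),
    torch.zeros((1, (frame_num - 1) // 4, 8, 12, 768), dtype=dtype, device=device),
]
```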
diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py
index bc79e2ebc..82e704aaa 100644
--- a/models/wan/df_handler.py
+++ b/models/wan/df_handler.py
@@ -26,6 +26,18 @@ def query_model_def(base_model_type, model_def):
extra_model_def["tea_cache"] = True
extra_model_def["guidance_max_phases"] = 1
+ extra_model_def["model_modes"] = {
+ "choices": [
+ ("Synchronous", 0),
+ ("Asynchronous (better quality but around 50% extra steps added)", 5),
+ ],
+ "default": 0,
+ "label" : "Generation Type"
+ }
+
+ extra_model_def["image_prompt_types_allowed"] = "TSEV"
+
+
return extra_model_def
@staticmethod
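`model_modes` is a small declarative dict that the shared UI code renders as a dropdown; the keys match the consumer added to wgp.py later in this patch (`choices`, `default`, `label`). A sketch of how it is presumably turned into a Gradio component:

```python
import gradio as gr

model_mode_choices = {
    "choices": [
        ("Synchronous", 0),
        ("Asynchronous (better quality but around 50% extra steps added)", 5),
    ],
    "default": 0,
    "label": "Generation Type",
}

# Mirrors the generic wgp.py consumer in this patch (ui_defaults lookup omitted here).
model_mode = gr.Dropdown(
    choices=model_mode_choices["choices"],
    value=model_mode_choices["default"],
    label=model_mode_choices["label"],
    visible=True,
)
```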
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index 9adc3a884..3a5dd64a6 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -5,6 +5,9 @@
def test_class_i2v(base_model_type):
return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ]
+def text_oneframe_overlap(base_model_type):
+ return test_class_i2v(base_model_type) and not test_multitalk(base_model_type)
+
def test_class_1_3B(base_model_type):
return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"]
@@ -120,6 +123,37 @@ def query_model_def(base_model_type, model_def):
if base_model_type in ["standin"] or vace_class:
extra_model_def["lock_image_refs_ratios"] = True
+ if base_model_type in ["recam_1.3B"]:
+ extra_model_def["keep_frames_video_guide_not_supported"] = True
+ extra_model_def["model_modes"] = {
+ "choices": [
+ ("Pan Right", 1),
+ ("Pan Left", 2),
+ ("Tilt Up", 3),
+ ("Tilt Down", 4),
+ ("Zoom In", 5),
+ ("Zoom Out", 6),
+ ("Translate Up (with rotation)", 7),
+ ("Translate Down (with rotation)", 8),
+ ("Arc Left (with rotation)", 9),
+ ("Arc Right (with rotation)", 10),
+ ],
+ "default": 1,
+ "label" : "Camera Movement Type"
+ }
+ if vace_class or base_model_type in ["infinitetalk"]:
+ image_prompt_types_allowed = "TVL"
+ elif base_model_type in ["ti2v_2_2"]:
+ image_prompt_types_allowed = "TSEVL"
+ elif i2v:
+ image_prompt_types_allowed = "SEVL"
+ else:
+ image_prompt_types_allowed = ""
+ extra_model_def["image_prompt_types_allowed"] = image_prompt_types_allowed
+
+ if text_oneframe_overlap(base_model_type):
+ extra_model_def["sliding_window_defaults"] = { "overlap_min" : 1, "overlap_max" : 1, "overlap_step": 0, "overlap_default": 1}
+
# if base_model_type in ["phantom_1.3B", "phantom_14B"]:
# extra_model_def["one_image_ref_needed"] = True
@@ -251,6 +285,17 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
video_prompt_type = video_prompt_type.replace("U", "RU")
ui_defaults["video_prompt_type"] = video_prompt_type
+ if settings_version < 2.31:
+ if base_model_type in "recam_1.3B":
+ video_prompt_type = ui_defaults.get("video_prompt_type", "")
+ if not "V" in video_prompt_type:
+ video_prompt_type += "UV"
+ ui_defaults["video_prompt_type"] = video_prompt_type
+ ui_defaults["image_prompt_type"] = ""
+
+ if text_oneframe_overlap(base_model_type):
+ ui_defaults["sliding_window_overlap"] = 1
+
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
@@ -309,6 +354,15 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
"image_prompt_type": "T",
})
+ if base_model_type in ["recam_1.3B"]:
+ ui_defaults.update({
+ "video_prompt_type": "UV",
+ })
+
+ if text_oneframe_overlap(base_model_type):
+ ui_defaults.update["sliding_window_overlap"] = 1
+ ui_defaults.update["color_correction_strength"]= 0
+
if test_multitalk(base_model_type):
ui_defaults["audio_guidance_scale"] = 4
diff --git a/wgp.py b/wgp.py
index 653683c15..396e273e2 100644
--- a/wgp.py
+++ b/wgp.py
@@ -60,8 +60,8 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.34"
-settings_version = 2.29
+WanGP_version = "8.4"
+settings_version = 2.31
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -347,7 +347,7 @@ def process_prompt_and_add_tasks(state, model_choice):
model_switch_phase = inputs["model_switch_phase"]
switch_threshold = inputs["switch_threshold"]
switch_threshold2 = inputs["switch_threshold2"]
-
+ multi_prompts_gen_type = inputs["multi_prompts_gen_type"]
if len(loras_multipliers) > 0:
_, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases= guidance_phases)
@@ -445,7 +445,7 @@ def process_prompt_and_add_tasks(state, model_choice):
if "I" in video_prompt_type:
if image_refs == None or len(image_refs) == 0:
- gr.Info("You must provide at least one Refererence Image")
+ gr.Info("You must provide at least one Reference Image")
return
image_refs = clean_image_list(image_refs)
if image_refs == None :
@@ -511,9 +511,14 @@ def process_prompt_and_add_tasks(state, model_choice):
if image_start == None :
gr.Info("Start Image should be an Image")
return
+ if multi_prompts_gen_type == 1 and len(image_start) > 1:
+ gr.Info("Only one Start Image is supported")
+ return
else:
image_start = None
+ if not any_letters(image_prompt_type, "SVL"):
+ image_prompt_type = image_prompt_type.replace("E", "")
if "E" in image_prompt_type:
if image_end == None or isinstance(image_end, list) and len(image_end) == 0:
gr.Info("You must provide an End Image")
@@ -522,32 +527,35 @@ def process_prompt_and_add_tasks(state, model_choice):
if image_end == None :
gr.Info("End Image should be an Image")
return
- if len(image_start) != len(image_end):
- gr.Info("The number of Start and End Images should be the same ")
- return
+ if multi_prompts_gen_type == 0:
+ if video_source is not None:
+ if len(image_end)> 1:
+ gr.Info("If a Video is to be continued and the option 'Each Text Prompt Will create a new generated Video' is set, there can be only one End Image")
+ return
+ elif len(image_start or []) != len(image_end or []):
+                    gr.Info("The number of Start and End Images should be the same when the option 'Each Text Prompt Will create a new generated Video' is set")
+ return
else:
image_end = None
if test_any_sliding_window(model_type) and image_mode == 0:
if video_length > sliding_window_size:
- full_video_length = video_length if video_source is None else video_length + sliding_window_overlap
+ full_video_length = video_length if video_source is None else video_length + sliding_window_overlap -1
extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation"
no_windows = compute_sliding_window_no(full_video_length, sliding_window_size, sliding_window_discard_last_frames, sliding_window_overlap)
gr.Info(f"The Number of Frames to generate ({video_length}{extra}) is greater than the Sliding Window Size ({sliding_window_size}), {no_windows} Windows will be generated")
if "recam" in model_filename:
- if video_source == None:
- gr.Info("You must provide a Source Video")
+ if video_guide == None:
+ gr.Info("You must provide a Control Video")
return
-
- frames = get_resampled_video(video_source, 0, 81, get_computed_fps(force_fps, model_type , video_guide, video_source ))
+ computed_fps = get_computed_fps(force_fps, model_type , video_guide, video_source )
+ frames = get_resampled_video(video_guide, 0, 81, computed_fps)
if len(frames)<81:
- gr.Info("Recammaster source video should be at least 81 frames once the resampling at 16 fps has been done")
+ gr.Info(f"Recammaster Control video should be at least 81 frames once the resampling at {computed_fps} fps has been done")
return
-
-
if "hunyuan_custom_custom_edit" in model_filename:
if len(keep_frames_video_guide) > 0:
gr.Info("Filtering Frames with this model is not supported")
@@ -558,13 +566,13 @@ def process_prompt_and_add_tasks(state, model_choice):
gr.Info("Only one Start Image must be provided if multiple prompts are used for different windows")
return
- if image_end != None and len(image_end) > 1:
- gr.Info("Only one End Image must be provided if multiple prompts are used for different windows")
- return
+ # if image_end != None and len(image_end) > 1:
+ # gr.Info("Only one End Image must be provided if multiple prompts are used for different windows")
+ # return
override_inputs = {
"image_start": image_start[0] if image_start !=None and len(image_start) > 0 else None,
- "image_end": image_end[0] if image_end !=None and len(image_end) > 0 else None,
+ "image_end": image_end, #[0] if image_end !=None and len(image_end) > 0 else None,
"image_refs": image_refs,
"audio_guide": audio_guide,
"audio_guide2": audio_guide2,
@@ -640,19 +648,21 @@ def process_prompt_and_add_tasks(state, model_choice):
override_inputs["prompt"] = single_prompt
inputs.update(override_inputs)
add_video_task(**inputs)
+ new_prompts_count = len(prompts)
else:
+ new_prompts_count = 1
override_inputs["prompt"] = "\n".join(prompts)
inputs.update(override_inputs)
add_video_task(**inputs)
- gen["prompts_max"] = len(prompts) + gen.get("prompts_max",0)
+ gen["prompts_max"] = new_prompts_count + gen.get("prompts_max",0)
state["validate_success"] = 1
queue= gen.get("queue", [])
return update_queue_data(queue)
def get_preview_images(inputs):
- inputs_to_query = ["image_start", "image_end", "video_source", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ]
- labels = ["Start Image", "End Image", "Video Source", "Video Guide", "Image Guide", "Video Mask", "Image Mask", "Image Reference"]
+ inputs_to_query = ["image_start", "video_source", "image_end", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ]
+ labels = ["Start Image", "Video Source", "End Image", "Video Guide", "Image Guide", "Video Mask", "Image Mask", "Image Reference"]
start_image_data = None
start_image_labels = []
end_image_data = None
@@ -3454,6 +3464,8 @@ def convert_image(image):
from PIL import ImageOps
from typing import cast
+ if isinstance(image, str):
+ image = Image.open(image)
image = image.convert('RGB')
return cast(Image, ImageOps.exif_transpose(image))
@@ -4506,7 +4518,7 @@ def remove_temp_filenames(temp_filenames_list):
if test_any_sliding_window(model_type) :
if video_source is not None:
- current_video_length += sliding_window_overlap
+ current_video_length += sliding_window_overlap - 1
sliding_window = current_video_length > sliding_window_size
reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
else:
@@ -4690,8 +4702,7 @@ def remove_temp_filenames(temp_filenames_list):
while not abort:
enable_RIFLEx = RIFLEx_setting == 0 and current_video_length > (6* get_model_fps(base_model_type)+1) or RIFLEx_setting == 1
- if sliding_window:
- prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
+ prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1]
new_extra_windows = gen.get("extra_windows",0)
gen["extra_windows"] = 0
extra_windows += new_extra_windows
@@ -4722,15 +4733,13 @@ def remove_temp_filenames(temp_filenames_list):
image_start_tensor = image_start.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
image_start_tensor = convert_image_to_tensor(image_start_tensor)
pre_video_guide = prefix_video = image_start_tensor.unsqueeze(1)
- if image_end is not None:
- image_end_tensor = image_end.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
- image_end_tensor = convert_image_to_tensor(image_end_tensor)
else:
if "L" in image_prompt_type:
refresh_preview["video_source"] = get_video_frame(video_source, 0)
prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, target_fps = fps, block_size = block_size )
prefix_video = prefix_video.permute(3, 0, 1, 2)
prefix_video = prefix_video.float().div_(127.5).sub_(1.) # c, f, h, w
+ new_height, new_width = prefix_video.shape[-2:]
pre_video_guide = prefix_video[:, -reuse_frames:]
pre_video_frame = convert_tensor_to_image(prefix_video[:, -1])
source_video_overlap_frames_count = pre_video_guide.shape[1]
@@ -4739,7 +4748,14 @@ def remove_temp_filenames(temp_filenames_list):
image_size = pre_video_guide.shape[-2:]
sample_fit_canvas = None
guide_start_frame = prefix_video.shape[1]
-
+ if image_end is not None:
+ image_end_list= image_end if isinstance(image_end, list) else [image_end]
+ if len(image_end_list) >= window_no:
+ new_height, new_width = image_size
+ image_end_tensor =image_end_list[window_no-1].resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
+ image_end_tensor = convert_image_to_tensor(image_end_tensor)
+ image_end_list= None
+
window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count)
guide_end_frame = guide_start_frame + current_video_length - (source_video_overlap_frames_count if window_no == 1 else reuse_frames)
alignment_shift = source_video_frames_count if reset_control_aligment else 0
@@ -4797,7 +4813,7 @@ def remove_temp_filenames(temp_filenames_list):
image_size = src_video.shape[-2:]
sample_fit_canvas = None
- elif "G" in video_prompt_type: # video to video
+ else: # video to video
video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps)
if video_guide_processed is None:
src_video = pre_video_guide
@@ -5298,7 +5314,7 @@ def process_tasks(state):
while True:
with gen_lock:
process_status = gen.get("process_status", None)
- if process_status is None:
+ if process_status is None or process_status == "process:main":
gen["process_status"] = "process:main"
break
time.sleep(1)
@@ -6570,11 +6586,13 @@ def any_letters(source_str, letters):
return True
return False
-def filter_letters(source_str, letters):
+def filter_letters(source_str, letters, default= ""):
ret = ""
for letter in letters:
if letter in source_str:
ret += letter
+ if len(ret) == 0:
+ return default
return ret
def add_to_sequence(source_str, letters):
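The new `default` argument matters for the radio initialisation later in this patch, where a saved `image_prompt_type` containing none of the allowed letters must still resolve to a valid selection. A minimal restatement of the behaviour:

```python
def filter_letters(source_str, letters, default=""):
    # Same semantics as the patched helper: keep the letters present in source_str, in `letters` order.
    ret = "".join(letter for letter in letters if letter in source_str)
    return ret if ret else default

assert filter_letters("SVE", "SVL") == "SV"
assert filter_letters("E", "SVL") == ""                 # previous behaviour
assert filter_letters("E", "SVL", default="S") == "S"   # new fallback
```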
@@ -6601,9 +6619,18 @@ def refresh_audio_prompt_type_sources(state, audio_prompt_type, audio_prompt_typ
audio_prompt_type = add_to_sequence(audio_prompt_type, audio_prompt_type_sources)
return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in audio_prompt_type))
-def refresh_image_prompt_type(state, image_prompt_type):
- any_video_source = len(filter_letters(image_prompt_type, "VLG"))>0
- return gr.update(visible = "S" in image_prompt_type ), gr.update(visible = "E" in image_prompt_type ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source)
+def refresh_image_prompt_type_radio(state, image_prompt_type, image_prompt_type_radio):
+ image_prompt_type = del_in_sequence(image_prompt_type, "VLTS")
+ image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio)
+ any_video_source = len(filter_letters(image_prompt_type, "VL"))>0
+ end_visible = any_letters(image_prompt_type, "SVL")
+ return image_prompt_type, gr.update(visible = "S" in image_prompt_type ), gr.update(visible = end_visible and ("E" in image_prompt_type) ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source), gr.update(visible = end_visible)
+
+def refresh_image_prompt_type_endcheckbox(state, image_prompt_type, image_prompt_type_radio, end_checkbox):
+ image_prompt_type = del_in_sequence(image_prompt_type, "E")
+ if end_checkbox: image_prompt_type += "E"
+ image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio)
+ return image_prompt_type, gr.update(visible = "E" in image_prompt_type )
def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs):
model_type = state["model_type"]
@@ -6680,9 +6707,12 @@ def get_prompt_labels(multi_prompts_gen_type, image_outputs = False):
new_line_text = "each new line of prompt will be used for a window" if multi_prompts_gen_type != 0 else "each new line of prompt will generate " + ("a new image" if image_outputs else "a new video")
return "Prompts (" + new_line_text + ", # lines = comments, ! lines = macros)", "Prompts (" + new_line_text + ", # lines = comments)"
+def get_image_end_label(multi_prompts_gen_type):
+ return "Images as ending points for new Videos in the Generation Queue" if multi_prompts_gen_type == 0 else "Images as ending points for each new Window of the same Video Generation"
+
def refresh_prompt_labels(multi_prompts_gen_type, image_mode):
prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type, image_mode == 1)
- return gr.update(label=prompt_label), gr.update(label = wizard_prompt_label)
+ return gr.update(label=prompt_label), gr.update(label = wizard_prompt_label), gr.update(label=get_image_end_label(multi_prompts_gen_type))
def show_preview_column_modal(state, column_no):
column_no = int(column_no)
@@ -7054,101 +7084,46 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
with gr.Tab("Text to Image", id = "t2i", elem_classes="compact_tab"):
pass
- with gr.Column(visible= test_class_i2v(model_type) or hunyuan_i2v or diffusion_forcing or ltxv or recammaster or vace or ti2v_2_2) as image_prompt_column:
- if vace or infinitetalk:
- image_prompt_type_value= ui_defaults.get("image_prompt_type","")
- image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value
- image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= not image_outputs , scale= 3)
-
- image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
- image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
- video_source = gr.Video(label= "Video Source", height = gallery_height, visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None))
- model_mode = gr.Dropdown(visible = False)
- keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VLG"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" )
+ image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "")
+ model_mode_choices = model_def.get("model_modes", None)
+ with gr.Column(visible= len(image_prompt_types_allowed)> 0 or model_mode_choices is not None) as image_prompt_column:
+ image_prompt_type_value= ui_defaults.get("image_prompt_type","")
+ image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False)
+ image_prompt_type_choices = []
+ if "T" in image_prompt_types_allowed:
+ image_prompt_type_choices += [("Text Prompt Only", "")]
+ any_start_image = True
+ if "S" in image_prompt_types_allowed:
+ image_prompt_type_choices += [("Start Video with Image", "S")]
+ any_start_image = True
+ if "V" in image_prompt_types_allowed:
any_video_source = True
-
- elif diffusion_forcing or ltxv or ti2v_2_2:
- image_prompt_type_value= ui_defaults.get("image_prompt_type","T")
- # image_prompt_type = gr.Radio( [("Start Video with Image", "S"),("Start and End Video with Images", "SE"), ("Continue Video", "V"),("Text Prompt Only", "T")], value =image_prompt_type_value, label="Location", show_label= False, visible= True, scale= 3)
- image_prompt_type_choices = [("Text Prompt Only", "T"),("Start Video with Image", "S")]
- if ltxv:
- image_prompt_type_choices += [("Use both a Start and an End Image", "SE")]
- if sliding_window_enabled:
- any_video_source = True
- image_prompt_type_choices += [("Continue Video", "V")]
- image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3)
-
- image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
- image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value )
- video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
- if not diffusion_forcing:
- model_mode = gr.Dropdown(
- choices=[
- ], value=None,
- visible= False
- )
- else:
- model_mode = gr.Dropdown(
- choices=[
- ("Synchronous", 0),
- ("Asynchronous (better quality but around 50% extra steps added)", 5),
- ],
- value=ui_defaults.get("model_mode", 0),
- label="Generation Type", scale = 3,
- visible= True
- )
- keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= "V" in image_prompt_type_value, scale = 2, label= "Truncate Video beyond this number of Frames of Video (empty=Keep All)" )
- elif recammaster:
- image_prompt_type = gr.Radio(choices=[("Source Video", "V")], value="V")
- image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
- image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
- video_source = gr.Video(label= "Video Source", height = gallery_height, visible = True, value= ui_defaults.get("video_source", None),)
- model_mode = gr.Dropdown(
- choices=[
- ("Pan Right", 1),
- ("Pan Left", 2),
- ("Tilt Up", 3),
- ("Tilt Down", 4),
- ("Zoom In", 5),
- ("Zoom Out", 6),
- ("Translate Up (with rotation)", 7),
- ("Translate Down (with rotation)", 8),
- ("Arc Left (with rotation)", 9),
- ("Arc Right (with rotation)", 10),
- ],
- value=ui_defaults.get("model_mode", 1),
- label="Camera Movement Type", scale = 3,
- visible= True
- )
- keep_frames_video_source = gr.Text(visible=False)
- else:
- if test_class_i2v(model_type) or hunyuan_i2v:
- # image_prompt_type_value= ui_defaults.get("image_prompt_type","SE" if flf2v else "S" )
- image_prompt_type_value= ui_defaults.get("image_prompt_type","S" )
- image_prompt_type_choices = [("Start Video with Image", "S")]
- image_prompt_type_choices += [("Use both a Start and an End Image", "SE")]
- if not hunyuan_i2v:
- any_video_source = True
- image_prompt_type_choices += [("Continue Video", "V")]
-
- image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3)
- any_start_image = True
- any_end_image = True
- image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
- image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value )
- if hunyuan_i2v:
- video_source = gr.Video(value=None, visible=False)
+ image_prompt_type_choices += [("Continue Video", "V")]
+ if "L" in image_prompt_types_allowed:
+ any_video_source = True
+ image_prompt_type_choices += [("Continue Last Video", "L")]
+ with gr.Group(visible= len(image_prompt_types_allowed)>1) as image_prompt_type_group:
+ with gr.Row():
+ image_prompt_type_radio_allowed_values= filter_letters(image_prompt_types_allowed, "SVL")
+ if len(image_prompt_type_choices) > 0:
+ image_prompt_type_radio = gr.Radio( image_prompt_type_choices, value =filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1]), label="Location", show_label= False, visible= len(image_prompt_types_allowed)>1, scale= 3)
else:
- video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
- else:
- image_prompt_type = gr.Radio(choices=[("", "")], value="")
- image_start_row, image_start, image_start_extra = get_image_gallery(visible = False )
- image_end_row, image_end, image_end_extra = get_image_gallery(visible = False )
- video_source = gr.Video(value=None, visible=False)
+ image_prompt_type_radio = gr.Radio(choices=[("", "")], value="", visible= False)
+ if "E" in image_prompt_types_allowed:
+ image_prompt_type_endcheckbox = gr.Checkbox( value ="E" in image_prompt_type_value, label="End Image(s)", show_label= False, visible= any_letters(image_prompt_type_value, "SVL") and not image_outputs , scale= 1)
+ any_end_image = True
+ else:
+ image_prompt_type_endcheckbox = gr.Checkbox( value =False, show_label= False, visible= False , scale= 1)
+ image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new Videos in the Generation Queue", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
+ video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
+ image_end_row, image_end, image_end_extra = get_image_gallery(label= get_image_end_label(ui_defaults.get("multi_prompts_gen_type", 0)), value = ui_defaults.get("image_end", None), visible= any_letters(image_prompt_type_value, "SVL") and ("E" in image_prompt_type_value) )
+ if model_mode_choices is None:
model_mode = gr.Dropdown(value=None, visible=False)
- keep_frames_video_source = gr.Text(visible=False)
+ else:
+ model_mode = gr.Dropdown(choices=model_mode_choices["choices"], value=ui_defaults.get("model_mode", model_mode_choices["default"]), label=model_mode_choices["label"], visible=True)
+ keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VL"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" )
- with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column:
+ with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or recammaster or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column:
video_prompt_type_value= ui_defaults.get("video_prompt_type","")
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
any_control_video = True
@@ -7208,12 +7183,12 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
value=filter_letters(video_prompt_type_value, "PDSLCMUV"),
label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= True, show_label= True,
)
- elif infinitetalk:
- video_prompt_type_video_guide = gr.Dropdown(value="", choices = [("","")], visible=False)
+ elif recammaster:
+ video_prompt_type_video_guide = gr.Dropdown(value="UV", choices = [("Control Video","UV")], visible=False)
else:
any_control_video = False
any_control_image = False
- video_prompt_type_video_guide = gr.Dropdown(visible= False)
+ video_prompt_type_video_guide = gr.Dropdown(value="", choices = [("","")], visible=False)
if infinitetalk:
video_prompt_type_video_guide_alt = gr.Dropdown(
@@ -7228,6 +7203,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
value=filter_letters(video_prompt_type_value, "RGUVQKI"),
label="Video to Video", scale = 3, visible= True, show_label= False,
)
+ any_control_video = any_control_image = True
else:
video_prompt_type_video_guide_alt = gr.Dropdown(value="", choices = [("","")], visible=False)
@@ -7761,7 +7737,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False)
elif ltxv:
sliding_window_size = gr.Slider(41, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=8, label="Sliding Window Size")
- sliding_window_overlap = gr.Slider(9, 97, value=ui_defaults.get("sliding_window_overlap",9), step=8, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
+ sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",9), step=8, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_color_correction_strength = gr.Slider(0, 1, visible=False, value =0)
sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False)
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=8, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
@@ -7772,8 +7748,9 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False)
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
else: # Vace, Multitalk
+ sliding_window_defaults = model_def.get("sliding_window_defaults", {})
sliding_window_size = gr.Slider(5, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size")
- sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
+ sliding_window_overlap = gr.Slider(sliding_window_defaults.get("overlap_min", 1), sliding_window_defaults.get("overlap_max", 97), value=ui_defaults.get("sliding_window_overlap",sliding_window_defaults.get("overlap_default", 5)), step=sliding_window_defaults.get("overlap_step", 4), label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",1), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)")
sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20 if vace else 0), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = vace)
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
@@ -7790,13 +7767,13 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
multi_prompts_gen_type = gr.Dropdown(
choices=[
- ("Will create new generated Video", 0),
+ ("Will create a new generated Video added to the Generation Queue", 0),
("Will be used for a new Sliding Window of the same Video Generation", 1),
],
value=ui_defaults.get("multi_prompts_gen_type",0),
visible=True,
scale = 1,
- label="Text Prompts separated by a Carriage Return"
+ label="Images & Text Prompts separated by a Carriage Return" if (any_start_image or any_end_image) else "Text Prompts separated by a Carriage Return"
)
with gr.Tab("Misc.", visible = True) as misc_tab:
@@ -7962,7 +7939,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
hidden_countdown_state = gr.Number(value=-1, visible=False, elem_id="hidden_countdown_state_num")
single_hidden_trigger_btn = gr.Button("trigger_countdown", visible=False, elem_id="trigger_info_single_btn")
- extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column,
+ extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, image_prompt_type_group, image_prompt_type_radio, image_prompt_type_endcheckbox,
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, save_lset_prompt_drop, advanced_row, speed_tab, audio_tab, mmaudio_col, quality_tab,
sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, audio_guide_row, RIFLEx_setting_col,
video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux, audio_prompt_type_remux_row,
@@ -7981,23 +7958,24 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
target_settings = gr.Text(value = "settings", interactive= False, visible= False)
last_choice = gr.Number(value =-1, interactive= False, visible= False)
- resolution_group.input(fn=change_resolution_group, inputs=[state, resolution_group], outputs=[resolution])
+ resolution_group.input(fn=change_resolution_group, inputs=[state, resolution_group], outputs=[resolution], show_progress="hidden")
resolution.change(fn=record_last_resolution, inputs=[state, resolution])
# video_length.release(fn=refresh_video_length_label, inputs=[state, video_length ], outputs = video_length, trigger_mode="always_last" )
- gr.on(triggers=[video_length.release, force_fps.change, video_guide.change, video_source.change], fn=refresh_video_length_label, inputs=[state, video_length, force_fps, video_guide, video_source] , outputs = video_length, trigger_mode="always_last" )
+ gr.on(triggers=[video_length.release, force_fps.change, video_guide.change, video_source.change], fn=refresh_video_length_label, inputs=[state, video_length, force_fps, video_guide, video_source] , outputs = video_length, trigger_mode="always_last", show_progress="hidden" )
guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ])
audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type])
audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row])
- image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start_row, image_end_row, video_source, keep_frames_video_source] )
+ image_prompt_type_radio.change(fn=refresh_image_prompt_type_radio, inputs=[state, image_prompt_type, image_prompt_type_radio], outputs=[image_prompt_type, image_start_row, image_end_row, video_source, keep_frames_video_source, image_prompt_type_endcheckbox], show_progress="hidden" )
+ image_prompt_type_endcheckbox.change(fn=refresh_image_prompt_type_endcheckbox, inputs=[state, image_prompt_type, image_prompt_type_radio, image_prompt_type_endcheckbox], outputs=[image_prompt_type, image_end_row] )
# video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col])
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs_row, denoising_strength ])
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type])
- multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt])
+ multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt, image_end], show_progress="hidden")
video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
video_guide_outpainting_bottom.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_bottom,gr.State(1)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
video_guide_outpainting_left.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_left,gr.State(2)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
From 99fd9aea32937d039fe54542afc0d50e18239f87 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Fri, 5 Sep 2025 15:56:42 +0200
Subject: [PATCH 017/155] fixed settings update
---
models/wan/wan_handler.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index 3a5dd64a6..bc2eb396e 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -360,8 +360,8 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
})
if text_oneframe_overlap(base_model_type):
- ui_defaults.update["sliding_window_overlap"] = 1
- ui_defaults.update["color_correction_strength"]= 0
+ ui_defaults["sliding_window_overlap"] = 1
+ ui_defaults["color_correction_strength"]= 0
if test_multitalk(base_model_type):
ui_defaults["audio_guidance_scale"] = 4
From 836777eb55ca7c9a752b9a2703b0a55b75366de9 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Fri, 5 Sep 2025 20:57:09 +0200
Subject: [PATCH 018/155] fixed End Frames checkbox not always visible
---
models/wan/wan_handler.py | 2 +-
wgp.py | 5 +++--
2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index bc2eb396e..8316d6d90 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -144,7 +144,7 @@ def query_model_def(base_model_type, model_def):
if vace_class or base_model_type in ["infinitetalk"]:
image_prompt_types_allowed = "TVL"
elif base_model_type in ["ti2v_2_2"]:
- image_prompt_types_allowed = "TSEVL"
+ image_prompt_types_allowed = "TSEV"
elif i2v:
image_prompt_types_allowed = "SEVL"
else:
diff --git a/wgp.py b/wgp.py
index 396e273e2..1eb3fcae8 100644
--- a/wgp.py
+++ b/wgp.py
@@ -7105,12 +7105,13 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
with gr.Group(visible= len(image_prompt_types_allowed)>1) as image_prompt_type_group:
with gr.Row():
image_prompt_type_radio_allowed_values= filter_letters(image_prompt_types_allowed, "SVL")
+ image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1])
if len(image_prompt_type_choices) > 0:
- image_prompt_type_radio = gr.Radio( image_prompt_type_choices, value =filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1]), label="Location", show_label= False, visible= len(image_prompt_types_allowed)>1, scale= 3)
+ image_prompt_type_radio = gr.Radio( image_prompt_type_choices, value = image_prompt_type_radio_value, label="Location", show_label= False, visible= len(image_prompt_types_allowed)>1, scale= 3)
else:
image_prompt_type_radio = gr.Radio(choices=[("", "")], value="", visible= False)
if "E" in image_prompt_types_allowed:
- image_prompt_type_endcheckbox = gr.Checkbox( value ="E" in image_prompt_type_value, label="End Image(s)", show_label= False, visible= any_letters(image_prompt_type_value, "SVL") and not image_outputs , scale= 1)
+ image_prompt_type_endcheckbox = gr.Checkbox( value ="E" in image_prompt_type_value, label="End Image(s)", show_label= False, visible= any_letters(image_prompt_type_radio_value, "SVL") and not image_outputs , scale= 1)
any_end_image = True
else:
image_prompt_type_endcheckbox = gr.Checkbox( value =False, show_label= False, visible= False , scale= 1)
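The checkbox visibility above is driven by small letter-set helpers (`filter_letters`, `any_letters`) whose implementations are not shown in these hunks. The stand-ins below are guesses inferred only from their call sites, just to trace how an i2v model with `image_prompt_types_allowed = "SEVL"` and a stored value of `"SE"` ends up with radio value `"S"` and a visible End Image(s) checkbox:

```python
# Guessed stand-ins, inferred only from the call sites in these patches (the real helpers live in wgp.py).
def filter_letters(value: str, allowed: str, default: str = "") -> str:
    # Keep only the letters of `value` that appear in `allowed`; the third
    # argument looks like a fallback used when nothing survives the filter.
    kept = "".join(ch for ch in value if ch in allowed)
    return kept if kept else default

def any_letters(value: str, letters: str) -> bool:
    return any(ch in value for ch in letters)

image_prompt_types_allowed = "SEVL"                                  # i2v case from wan_handler.py
radio_allowed = filter_letters(image_prompt_types_allowed, "SVL")    # -> "SVL"
radio_value = filter_letters("SE", radio_allowed, "S")               # stored "SE" -> radio "S"
end_visible = any_letters(radio_value, "SVL") and "E" in image_prompt_types_allowed
print(radio_allowed, radio_value, end_visible)                       # SVL S True
```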
From cf65889e2e1094902e08434e87b18049b9fe76b7 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Fri, 5 Sep 2025 21:59:26 +0200
Subject: [PATCH 019/155] fixed End Frames checkbox not always visible
---
models/wan/wan_handler.py | 8 ++++++++
wgp.py | 6 +++---
2 files changed, 11 insertions(+), 3 deletions(-)
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index 8316d6d90..20c3594ce 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -296,11 +296,19 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
if text_oneframe_overlap(base_model_type):
ui_defaults["sliding_window_overlap"] = 1
+ if settings_version < 2.32:
+ image_prompt_type = ui_defaults.get("image_prompt_type", "")
+ if test_class_i2v(base_model_type) and len(image_prompt_type) == 0:
+ ui_defaults["image_prompt_type"] = "S"
+
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
"sample_solver": "unipc",
})
+ if test_class_i2v(base_model_type):
+ ui_defaults["image_prompt_type"] = "S"
+
if base_model_type in ["fantasy"]:
ui_defaults.update({
"audio_guidance_scale": 5.0,
diff --git a/wgp.py b/wgp.py
index 1eb3fcae8..124ae5e6c 100644
--- a/wgp.py
+++ b/wgp.py
@@ -61,7 +61,7 @@
target_mmgp_version = "3.6.0"
WanGP_version = "8.4"
-settings_version = 2.31
+settings_version = 2.32
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -6142,11 +6142,11 @@ def eject_video_from_gallery(state, input_file_list, choice):
return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0)
def has_video_file_extension(filename):
- extension = os.path.splitext(filename)[-1]
+ extension = os.path.splitext(filename)[-1].lower()
return extension in [".mp4"]
def has_image_file_extension(filename):
- extension = os.path.splitext(filename)[-1]
+ extension = os.path.splitext(filename)[-1].lower()
return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]
def add_videos_to_gallery(state, input_file_list, choice, files_to_load):
gen = get_gen_info(state)
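The extension helpers patched above now lower-case the suffix before matching; copied out of the hunk for a quick standalone check:

```python
import os

def has_video_file_extension(filename):
    return os.path.splitext(filename)[-1].lower() in [".mp4"]

def has_image_file_extension(filename):
    return os.path.splitext(filename)[-1].lower() in [
        ".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp",
        ".tif", ".tiff", ".jfif", ".pjpeg"]

print(has_video_file_extension("clip.MP4"))   # True with the fix, False before it
print(has_image_file_extension("photo.JPG"))  # True with the fix
```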
From 242ae50d8368ac2da00adefa3365f69ca01ee03a Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Fri, 5 Sep 2025 22:05:00 +0200
Subject: [PATCH 020/155] fixed End Frames checkbox not always visible
---
wgp.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/wgp.py b/wgp.py
index 124ae5e6c..7db02363f 100644
--- a/wgp.py
+++ b/wgp.py
@@ -7105,8 +7105,8 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
with gr.Group(visible= len(image_prompt_types_allowed)>1) as image_prompt_type_group:
with gr.Row():
image_prompt_type_radio_allowed_values= filter_letters(image_prompt_types_allowed, "SVL")
- image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1])
if len(image_prompt_type_choices) > 0:
+ image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1])
image_prompt_type_radio = gr.Radio( image_prompt_type_choices, value = image_prompt_type_radio_value, label="Location", show_label= False, visible= len(image_prompt_types_allowed)>1, scale= 3)
else:
image_prompt_type_radio = gr.Radio(choices=[("", "")], value="", visible= False)
From 0649686e9d785b734f209967128f0849e160f5f6 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Fri, 5 Sep 2025 22:14:12 +0200
Subject: [PATCH 021/155] fixed End Frames checkbox not always visible
---
wgp.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/wgp.py b/wgp.py
index 7db02363f..4f1a10dd9 100644
--- a/wgp.py
+++ b/wgp.py
@@ -7105,8 +7105,8 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
with gr.Group(visible= len(image_prompt_types_allowed)>1) as image_prompt_type_group:
with gr.Row():
image_prompt_type_radio_allowed_values= filter_letters(image_prompt_types_allowed, "SVL")
+ image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1] if len(image_prompt_type_choices) > 0 else "")
if len(image_prompt_type_choices) > 0:
- image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1])
image_prompt_type_radio = gr.Radio( image_prompt_type_choices, value = image_prompt_type_radio_value, label="Location", show_label= False, visible= len(image_prompt_types_allowed)>1, scale= 3)
else:
image_prompt_type_radio = gr.Radio(choices=[("", "")], value="", visible= False)
From 66a07db16b841a908008f953849758cf3de6cf37 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sat, 6 Sep 2025 00:08:45 +0200
Subject: [PATCH 022/155] fixed End Frames checkbox not always visible
---
models/wan/wan_handler.py | 4 +++-
wgp.py | 3 ++-
2 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index 20c3594ce..5d5c27c8e 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -144,7 +144,9 @@ def query_model_def(base_model_type, model_def):
if vace_class or base_model_type in ["infinitetalk"]:
image_prompt_types_allowed = "TVL"
elif base_model_type in ["ti2v_2_2"]:
- image_prompt_types_allowed = "TSEV"
+ image_prompt_types_allowed = "TSVL"
+ elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]:
+ image_prompt_types_allowed = "SVL"
elif i2v:
image_prompt_types_allowed = "SEVL"
else:
diff --git a/wgp.py b/wgp.py
index 4f1a10dd9..9bbf1ff89 100644
--- a/wgp.py
+++ b/wgp.py
@@ -6623,7 +6623,8 @@ def refresh_image_prompt_type_radio(state, image_prompt_type, image_prompt_type_
image_prompt_type = del_in_sequence(image_prompt_type, "VLTS")
image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio)
any_video_source = len(filter_letters(image_prompt_type, "VL"))>0
- end_visible = any_letters(image_prompt_type, "SVL")
+ model_def = get_model_def(state["model_type"])
+ end_visible = any_letters(image_prompt_type, "SVL") and "E" in model_def.get("image_prompt_types_allowed","")
return image_prompt_type, gr.update(visible = "S" in image_prompt_type ), gr.update(visible = end_visible and ("E" in image_prompt_type) ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source), gr.update(visible = end_visible)
def refresh_image_prompt_type_endcheckbox(state, image_prompt_type, image_prompt_type_radio, end_checkbox):
From f9f63cbc79c364ac146d833292f0ab89e199228c Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 9 Sep 2025 21:41:35 +0200
Subject: [PATCH 023/155] intermediate commit
---
defaults/flux_dev_kontext.json | 2 -
defaults/flux_dev_uso.json | 2 -
defaults/qwen_image_edit_20B.json | 4 +-
models/flux/flux_handler.py | 37 +-
models/flux/flux_main.py | 69 +--
models/hyvideo/hunyuan.py | 5 -
models/hyvideo/hunyuan_handler.py | 21 +
models/ltx_video/ltxv.py | 3 -
models/ltx_video/ltxv_handler.py | 9 +
models/qwen/pipeline_qwenimage.py | 65 ++-
models/qwen/qwen_handler.py | 20 +-
models/qwen/qwen_main.py | 10 +-
models/wan/any2video.py | 41 +-
models/wan/df_handler.py | 2 +-
models/wan/wan_handler.py | 72 ++-
preprocessing/matanyone/app.py | 69 ++-
requirements.txt | 2 +-
shared/gradio/gallery.py | 141 +++--
shared/utils/utils.py | 53 +-
wgp.py | 833 +++++++++++++++++-------------
20 files changed, 893 insertions(+), 567 deletions(-)
diff --git a/defaults/flux_dev_kontext.json b/defaults/flux_dev_kontext.json
index 894591880..20b6bc4d3 100644
--- a/defaults/flux_dev_kontext.json
+++ b/defaults/flux_dev_kontext.json
@@ -7,8 +7,6 @@
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors"
],
- "image_outputs": true,
- "reference_image": true,
"flux-model": "flux-dev-kontext"
},
"prompt": "add a hat",
diff --git a/defaults/flux_dev_uso.json b/defaults/flux_dev_uso.json
index 8b5dbb603..0cd7b828d 100644
--- a/defaults/flux_dev_uso.json
+++ b/defaults/flux_dev_uso.json
@@ -6,8 +6,6 @@
"modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]],
"URLs": "flux",
"loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"],
- "image_outputs": true,
- "reference_image": true,
"flux-model": "flux-dev-uso"
},
"prompt": "the man is wearing a hat",
diff --git a/defaults/qwen_image_edit_20B.json b/defaults/qwen_image_edit_20B.json
index 2b24c72da..79b8b245d 100644
--- a/defaults/qwen_image_edit_20B.json
+++ b/defaults/qwen_image_edit_20B.json
@@ -9,9 +9,7 @@
],
"attention": {
"<89": "sdpa"
- },
- "reference_image": true,
- "image_outputs": true
+ }
},
"prompt": "add a hat",
"resolution": "1280x720",
diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py
index b8b7b9b42..c468d5a9a 100644
--- a/models/flux/flux_handler.py
+++ b/models/flux/flux_handler.py
@@ -13,28 +13,41 @@ def query_model_def(base_model_type, model_def):
flux_schnell = flux_model == "flux-schnell"
flux_chroma = flux_model == "flux-chroma"
flux_uso = flux_model == "flux-dev-uso"
- model_def_output = {
+ flux_kontext = flux_model == "flux-dev-kontext"
+
+ extra_model_def = {
"image_outputs" : True,
"no_negative_prompt" : not flux_chroma,
}
if flux_chroma:
- model_def_output["guidance_max_phases"] = 1
+ extra_model_def["guidance_max_phases"] = 1
elif not flux_schnell:
- model_def_output["embedded_guidance"] = True
+ extra_model_def["embedded_guidance"] = True
if flux_uso :
- model_def_output["any_image_refs_relative_size"] = True
- model_def_output["no_background_removal"] = True
-
- model_def_output["image_ref_choices"] = {
+ extra_model_def["any_image_refs_relative_size"] = True
+ extra_model_def["no_background_removal"] = True
+ extra_model_def["image_ref_choices"] = {
"choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"),
("Up to two Images are Style Images", "KIJ")],
"default": "KI",
"letters_filter": "KIJ",
"label": "Reference Images / Style Images"
}
- model_def_output["lock_image_refs_ratios"] = True
+
+ if flux_kontext:
+ extra_model_def["image_ref_choices"] = {
+ "choices": [
+ ("None", ""),
+ ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
+ ("Conditional Images are People / Objects", "I"),
+ ],
+ "letters_filter": "KI",
+ }
+
- return model_def_output
+ extra_model_def["lock_image_refs_ratios"] = True
+
+ return extra_model_def
@staticmethod
def query_supported_types():
@@ -122,10 +135,12 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
def update_default_settings(base_model_type, model_def, ui_defaults):
flux_model = model_def.get("flux-model", "flux-dev")
flux_uso = flux_model == "flux-dev-uso"
+ flux_kontext = flux_model == "flux-dev-kontext"
ui_defaults.update({
"embedded_guidance": 2.5,
- })
- if model_def.get("reference_image", False):
+ })
+
+ if flux_kontext or flux_uso:
ui_defaults.update({
"video_prompt_type": "KI",
})
diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py
index 9bb8e7379..4d7c67d41 100644
--- a/models/flux/flux_main.py
+++ b/models/flux/flux_main.py
@@ -24,44 +24,6 @@
from PIL import Image
-def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1):
- target_height_ref1 = int(target_height_ref1 // 64 * 64)
- target_width_ref1 = int(target_width_ref1 // 64 * 64)
- h, w = image.shape[-2:]
- if h < target_height_ref1 or w < target_width_ref1:
- # 计算长宽比
- aspect_ratio = w / h
- if h < target_height_ref1:
- new_h = target_height_ref1
- new_w = new_h * aspect_ratio
- if new_w < target_width_ref1:
- new_w = target_width_ref1
- new_h = new_w / aspect_ratio
- else:
- new_w = target_width_ref1
- new_h = new_w / aspect_ratio
- if new_h < target_height_ref1:
- new_h = target_height_ref1
- new_w = new_h * aspect_ratio
- else:
- aspect_ratio = w / h
- tgt_aspect_ratio = target_width_ref1 / target_height_ref1
- if aspect_ratio > tgt_aspect_ratio:
- new_h = target_height_ref1
- new_w = new_h * aspect_ratio
- else:
- new_w = target_width_ref1
- new_h = new_w / aspect_ratio
- # 使用 TVF.resize 进行图像缩放
- image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w)))
- # 计算中心裁剪的参数
- top = (image.shape[-2] - target_height_ref1) // 2
- left = (image.shape[-1] - target_width_ref1) // 2
- # 使用 TVF.crop 进行中心裁剪
- image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1)
- return image
-
-
def stitch_images(img1, img2):
# Resize img2 to match img1's height
width1, height1 = img1.size
@@ -171,8 +133,6 @@ def generate(
device="cuda"
flux_dev_uso = self.name in ['flux-dev-uso']
image_stiching = not self.name in ['flux-dev-uso'] #and False
- # image_refs_relative_size = 100
- crop = False
input_ref_images = [] if input_ref_images is None else input_ref_images[:]
ref_style_imgs = []
if "I" in video_prompt_type and len(input_ref_images) > 0:
@@ -186,36 +146,15 @@ def generate(
if image_stiching:
# image stiching method
stiched = input_ref_images[0]
- if "K" in video_prompt_type :
- w, h = input_ref_images[0].size
- height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
- # actual rescale will happen in prepare_kontext
for new_img in input_ref_images[1:]:
stiched = stitch_images(stiched, new_img)
input_ref_images = [stiched]
else:
- first_ref = 0
- if "K" in video_prompt_type:
- # image latents tiling method
- w, h = input_ref_images[0].size
- if crop :
- img = convert_image_to_tensor(input_ref_images[0])
- img = resize_and_centercrop_image(img, height, width)
- input_ref_images[0] = convert_tensor_to_image(img)
- else:
- height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
- input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS)
- first_ref = 1
-
- for i in range(first_ref,len(input_ref_images)):
+ # latents stiching with resize
+ for i in range(len(input_ref_images)):
w, h = input_ref_images[i].size
- if crop:
- img = convert_image_to_tensor(input_ref_images[i])
- img = resize_and_centercrop_image(img, int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100))
- input_ref_images[i] = convert_tensor_to_image(img)
- else:
- image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
- input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
+ image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
+ input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
else:
input_ref_images = None
diff --git a/models/hyvideo/hunyuan.py b/models/hyvideo/hunyuan.py
index a38a7bd0d..181a9a774 100644
--- a/models/hyvideo/hunyuan.py
+++ b/models/hyvideo/hunyuan.py
@@ -861,11 +861,6 @@ def generate(
freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx)
else:
if self.avatar:
- w, h = input_ref_images.size
- target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas)
- if target_width != w or target_height != h:
- input_ref_images = input_ref_images.resize((target_width,target_height), resample=Image.Resampling.LANCZOS)
-
concat_dict = {'mode': 'timecat', 'bias': -1}
freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
else:
diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py
index 9cbaea7d1..487e76d99 100644
--- a/models/hyvideo/hunyuan_handler.py
+++ b/models/hyvideo/hunyuan_handler.py
@@ -51,6 +51,23 @@ def query_model_def(base_model_type, model_def):
extra_model_def["tea_cache"] = True
extra_model_def["mag_cache"] = True
+ if base_model_type in ["hunyuan_custom_edit"]:
+ extra_model_def["guide_preprocessing"] = {
+ "selection": ["MV", "PMV"],
+ }
+
+ extra_model_def["mask_preprocessing"] = {
+ "selection": ["A", "NA"],
+ "default" : "NA"
+ }
+
+ if base_model_type in ["hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]:
+ extra_model_def["image_ref_choices"] = {
+ "choices": [("Reference Image", "I")],
+ "letters_filter":"I",
+ "visible": False,
+ }
+
if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True
if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]:
@@ -141,6 +158,10 @@ def load_model(model_filename, model_type = None, base_model_type = None, model
return hunyuan_model, pipe
+ @staticmethod
+ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
+ pass
+
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults["embedded_guidance_scale"]= 6.0
diff --git a/models/ltx_video/ltxv.py b/models/ltx_video/ltxv.py
index e71ac4fce..db143fc6d 100644
--- a/models/ltx_video/ltxv.py
+++ b/models/ltx_video/ltxv.py
@@ -300,9 +300,6 @@ def generate(
prefix_size, height, width = input_video.shape[-3:]
else:
if image_start != None:
- frame_width, frame_height = image_start.size
- if fit_into_canvas != None:
- height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32)
conditioning_media_paths.append(image_start.unsqueeze(1))
conditioning_start_frames.append(0)
conditioning_control_frames.append(False)
diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py
index c89b69a43..e44c98399 100644
--- a/models/ltx_video/ltxv_handler.py
+++ b/models/ltx_video/ltxv_handler.py
@@ -26,6 +26,15 @@ def query_model_def(base_model_type, model_def):
extra_model_def["sliding_window"] = True
extra_model_def["image_prompt_types_allowed"] = "TSEV"
+ extra_model_def["guide_preprocessing"] = {
+ "selection": ["", "PV", "DV", "EV", "V"],
+ "labels" : { "V": "Use LTXV raw format"}
+ }
+
+ extra_model_def["mask_preprocessing"] = {
+ "selection": ["", "A", "NA", "XA", "XNA"],
+ }
+
return extra_model_def
@staticmethod
diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py
index 07bdbd404..905b86493 100644
--- a/models/qwen/pipeline_qwenimage.py
+++ b/models/qwen/pipeline_qwenimage.py
@@ -28,7 +28,7 @@
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from diffusers import FlowMatchEulerDiscreteScheduler
from PIL import Image
-from shared.utils.utils import calculate_new_dimensions
+from shared.utils.utils import calculate_new_dimensions, convert_image_to_tensor, convert_tensor_to_image
XLA_AVAILABLE = False
@@ -563,6 +563,8 @@ def __call__(
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
image = None,
+ image_mask = None,
+ denoising_strength = 0,
callback=None,
pipeline=None,
loras_slists=None,
@@ -694,14 +696,33 @@ def __call__(
image_width = image_width // multiple_of * multiple_of
image_height = image_height // multiple_of * multiple_of
ref_height, ref_width = 1568, 672
- if height * width < ref_height * ref_width: ref_height , ref_width = height , width
- if image_height * image_width > ref_height * ref_width:
- image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
- image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
+ if image_mask is None:
+ if height * width < ref_height * ref_width: ref_height , ref_width = height , width
+ if image_height * image_width > ref_height * ref_width:
+ image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
+ if (image_width,image_height) != image.size:
+ image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
+ image_mask_latents = None
+ else:
+ # _, image_width, image_height = min(
+ # (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
+ # )
+ image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of)
+ # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
+ height, width = image_height, image_width
+ image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS))
+ image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
+ image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
+ convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
+ image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)
+
prompt_image = image
- image = self.image_processor.preprocess(image, image_height, image_width)
- image = image.unsqueeze(2)
+ if image.size != (image_width, image_height):
+ image = image.resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
+
+ image.save("nnn.png")
+ image = convert_image_to_tensor(image).unsqueeze(0).unsqueeze(2)
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
@@ -744,6 +765,8 @@ def __call__(
generator,
latents,
)
+ original_image_latents = None if image_latents is None else image_latents.clone()
+
if image is not None:
img_shapes = [
[
@@ -788,6 +811,15 @@ def __call__(
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
+ morph = False
+ if image_mask_latents is not None and denoising_strength <= 1.:
+ first_step = int(len(timesteps) * (1. - denoising_strength))
+ if not morph:
+ latent_noise_factor = timesteps[first_step]/1000
+ latents = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
+ timesteps = timesteps[first_step:]
+ self.scheduler.timesteps = timesteps
+ self.scheduler.sigmas= self.scheduler.sigmas[first_step:]
# 6. Denoising loop
self.scheduler.set_begin_index(0)
@@ -797,10 +829,15 @@ def __call__(
update_loras_slists(self.transformer, loras_slists, updated_num_steps)
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
+
for i, t in enumerate(timesteps):
if self.interrupt:
continue
+ if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph:
+ latent_noise_factor = t/1000
+ latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor
+
self._current_timestep = t
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -865,6 +902,12 @@ def __call__(
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+ if image_mask_latents is not None:
+                        next_t = timesteps[i+1] if i
diff --git a/shared/gradio/gallery.py b/shared/gradio/gallery.py
@@ ... @@ def ... -> List[Any]:
def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) :
# Mirror the selected index into state and the gallery (server-side selected_index)
+
+ st = get_state(state)
+ last_time = st.get("last_time", None)
+ if last_time is not None and abs(time.time()- last_time)< 0.5: # crappy trick to detect if onselect is unwanted (buggy gallery)
+ # print(f"ignored:{time.time()}, real {st['selected']}")
+ return gr.update(selected_index=st["selected"]), st
+
idx = None
if evt is not None and hasattr(evt, "index"):
ix = evt.index
@@ -220,17 +232,28 @@ def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) :
idx = ix[0] * max(1, int(self.columns)) + ix[1]
else:
idx = ix[0]
- st = get_state(state)
n = len(get_list(gallery))
sel = idx if (idx is not None and 0 <= idx < n) else None
+ # print(f"image selected evt index:{sel}/{evt.selected}")
st["selected"] = sel
- # return gr.update(selected_index=sel), st
- # return gr.update(), st
- return st
+ return gr.update(), st
+
+ def _on_upload(self, value: List[Any], state: Dict[str, Any]) :
+ # Fires when users upload via the Gallery itself.
+ # items_filtered = self._filter_items_by_mode(list(value or []))
+ items_filtered = list(value or [])
+ st = get_state(state)
+ new_items = self._paths_from_payload(items_filtered)
+ st["items"] = new_items
+ new_sel = len(new_items) - 1
+ st["selected"] = new_sel
+ record_last_action(st,"add")
+ return gr.update(selected_index=new_sel), st
def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
# Fires when users add/drag/drop/delete via the Gallery itself.
- items_filtered = self._filter_items_by_mode(list(value or []))
+ # items_filtered = self._filter_items_by_mode(list(value or []))
+ items_filtered = list(value or [])
st = get_state(state)
st["items"] = items_filtered
# Keep selection if still valid, else default to last
@@ -240,10 +263,9 @@ def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) :
else:
new_sel = old_sel
st["selected"] = new_sel
- # return gr.update(value=items_filtered, selected_index=new_sel), st
- # return gr.update(value=items_filtered), st
-
- return gr.update(), st
+ st["last_action"] ="gallery_change"
+ # print(f"gallery change: set sel {new_sel}")
+ return gr.update(selected_index=new_sel), st
def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery):
"""
@@ -252,7 +274,8 @@ def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery):
and re-selects the last inserted item.
"""
# New items (respect image/video mode)
- new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
+ # new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload))
+ new_items = self._paths_from_payload(files_payload)
st = get_state(state)
cur: List[Any] = get_list(gallery)
@@ -298,30 +321,6 @@ def key_of(it: Any) -> Optional[str]:
if k is not None:
seen_new.add(k)
- # Remove any existing occurrences of the incoming items from current list,
- # BUT keep the currently selected item even if it's also in incoming.
- cur_clean: List[Any] = []
- # sel_item = cur[sel] if (sel is not None and 0 <= sel < len(cur)) else None
- # for idx, it in enumerate(cur):
- # k = key_of(it)
- # if it is sel_item:
- # cur_clean.append(it)
- # continue
- # if k is not None and k in seen_new:
- # continue # drop duplicate; we'll reinsert at the target spot
- # cur_clean.append(it)
-
- # # Compute insertion position: right AFTER the (possibly shifted) selected item
- # if sel_item is not None:
- # # find sel_item's new index in cur_clean
- # try:
- # pos_sel = cur_clean.index(sel_item)
- # except ValueError:
- # # Shouldn't happen, but fall back to end
- # pos_sel = len(cur_clean) - 1
- # insert_pos = pos_sel + 1
- # else:
- # insert_pos = len(cur_clean) # no selection -> append at end
insert_pos = min(sel, len(cur) -1)
cur_clean = cur
# Build final list and selection
@@ -330,6 +329,8 @@ def key_of(it: Any) -> Optional[str]:
st["items"] = merged
st["selected"] = new_sel
+ record_last_action(st,"add")
+ # print(f"gallery add: set sel {new_sel}")
return gr.update(value=merged, selected_index=new_sel), st
def _on_remove(self, state: Dict[str, Any], gallery) :
@@ -342,8 +343,9 @@ def _on_remove(self, state: Dict[str, Any], gallery) :
return gr.update(value=[], selected_index=None), st
new_sel = min(sel, len(items) - 1)
st["items"] = items; st["selected"] = new_sel
- # return gr.update(value=items, selected_index=new_sel), st
- return gr.update(value=items), st
+ record_last_action(st,"remove")
+ # print(f"gallery del: new sel {new_sel}")
+ return gr.update(value=items, selected_index=new_sel), st
def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None)
@@ -354,11 +356,15 @@ def _on_move(self, delta: int, state: Dict[str, Any], gallery) :
return gr.update(value=items, selected_index=sel), st
items[sel], items[j] = items[j], items[sel]
st["items"] = items; st["selected"] = j
+ record_last_action(st,"move")
+ # print(f"gallery move: set sel {j}")
return gr.update(value=items, selected_index=j), st
def _on_clear(self, state: Dict[str, Any]) :
st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode}
- return gr.update(value=[], selected_index=0), st
+ record_last_action(st,"clear")
+ # print(f"Clear all")
+ return gr.update(value=[], selected_index=None), st
def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) :
st = get_state(state); st["single"] = bool(to_single)
@@ -382,30 +388,38 @@ def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) :
def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False):
if parent is not None:
with parent:
- col = self._build_ui()
+ col = self._build_ui(update_form)
else:
- col = self._build_ui()
+ col = self._build_ui(update_form)
if not update_form:
self._wire_events()
return col
- def _build_ui(self) -> gr.Column:
+ def _build_ui(self, update = False) -> gr.Column:
with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col:
self.container = col
self.state = gr.State(dict(self._initial_state))
- self.gallery = gr.Gallery(
- label=self.label,
- value=self._initial_state["items"],
- height=self.height,
- columns=self.columns,
- show_label=self.show_label,
- preview= True,
- # type="pil",
- file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS),
- selected_index=self._initial_state["selected"], # server-side selection
- )
+ if update:
+ self.gallery = gr.update(
+ value=self._initial_state["items"],
+ selected_index=self._initial_state["selected"], # server-side selection
+ label=self.label,
+ show_label=self.show_label,
+ )
+ else:
+ self.gallery = gr.Gallery(
+ value=self._initial_state["items"],
+ label=self.label,
+ height=self.height,
+ columns=self.columns,
+ show_label=self.show_label,
+ preview= True,
+ # type="pil", # very slow
+ file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS),
+ selected_index=self._initial_state["selected"], # server-side selection
+ )
# One-line controls
exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None
@@ -418,10 +432,10 @@ def _build_ui(self) -> gr.Column:
size="sm",
min_width=1,
)
- self.btn_remove = gr.Button("Remove", size="sm", min_width=1)
+ self.btn_remove = gr.Button(" Remove ", size="sm", min_width=1)
self.btn_left = gr.Button("◀ Left", size="sm", visible=not self._initial_state["single"], min_width=1)
self.btn_right = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1)
- self.btn_clear = gr.Button("Clear", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)
+ self.btn_clear = gr.Button(" Clear ", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1)
return col
@@ -430,14 +444,24 @@ def _wire_events(self):
self.gallery.select(
self._on_select,
inputs=[self.state, self.gallery],
- outputs=[self.state],
+ outputs=[self.gallery, self.state],
+ trigger_mode="always_last",
+ )
+
+ # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
+ self.gallery.upload(
+ self._on_upload,
+ inputs=[self.gallery, self.state],
+ outputs=[self.gallery, self.state],
+ trigger_mode="always_last",
)
# Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.)
- self.gallery.change(
+ self.gallery.upload(
self._on_gallery_change,
inputs=[self.gallery, self.state],
outputs=[self.gallery, self.state],
+ trigger_mode="always_last",
)
# Add via UploadButton
@@ -445,6 +469,7 @@ def _wire_events(self):
self._on_add,
inputs=[self.upload_btn, self.state, self.gallery],
outputs=[self.gallery, self.state],
+ trigger_mode="always_last",
)
# Remove selected
@@ -452,6 +477,7 @@ def _wire_events(self):
self._on_remove,
inputs=[self.state, self.gallery],
outputs=[self.gallery, self.state],
+ trigger_mode="always_last",
)
# Reorder using selected index, keep same item selected
@@ -459,11 +485,13 @@ def _wire_events(self):
lambda st, gallery: self._on_move(-1, st, gallery),
inputs=[self.state, self.gallery],
outputs=[self.gallery, self.state],
+ trigger_mode="always_last",
)
self.btn_right.click(
lambda st, gallery: self._on_move(+1, st, gallery),
inputs=[self.state, self.gallery],
outputs=[self.gallery, self.state],
+ trigger_mode="always_last",
)
# Clear all
@@ -471,6 +499,7 @@ def _wire_events(self):
self._on_clear,
inputs=[self.state],
outputs=[self.gallery, self.state],
+ trigger_mode="always_last",
)
# ---------------- public API ----------------
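The `_on_select` guard added near the top of the gallery.py hunks ignores select events that arrive within 0.5 s of the last server-side update. A standalone sketch of that guard (how `last_time` gets stamped is assumed here; in the patch it is kept in the component state):

```python
import time

def should_ignore_select(st: dict, window_s: float = 0.5) -> bool:
    # Select events that follow a server-side gallery update too closely are
    # treated as spurious and the previous selection is kept.
    last_time = st.get("last_time", None)
    return last_time is not None and abs(time.time() - last_time) < window_s

st = {"selected": 2, "last_time": time.time()}  # component state as a plain dict, as in the patch
print(should_ignore_select(st))   # True: too soon after a server update, keep st["selected"]
time.sleep(0.6)
print(should_ignore_select(st))   # False: outside the window, process the event normally
```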
diff --git a/shared/utils/utils.py b/shared/utils/utils.py
index 7ddf1eb6e..5dbd0af17 100644
--- a/shared/utils/utils.py
+++ b/shared/utils/utils.py
@@ -19,6 +19,7 @@
import subprocess
import json
from functools import lru_cache
+os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
from PIL import Image
@@ -207,30 +208,62 @@ def get_outpainting_frame_location(final_height, final_width, outpainting_dims
if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width
return height, width, margin_top, margin_left
-def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
- if fit_into_canvas == None:
+def rescale_and_crop(img, w, h):
+ ow, oh = img.size
+ target_ratio = w / h
+ orig_ratio = ow / oh
+
+ if orig_ratio > target_ratio:
+ # Crop width first
+ nw = int(oh * target_ratio)
+ img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh))
+ else:
+ # Crop height first
+ nh = int(ow / target_ratio)
+ img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2))
+
+ return img.resize((w, h), Image.LANCZOS)
+
+def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16):
+ if fit_into_canvas == None or fit_into_canvas == 2:
# return image_height, image_width
return canvas_height, canvas_width
- if fit_into_canvas:
+ if fit_into_canvas == 1:
scale1 = min(canvas_height / image_height, canvas_width / image_width)
scale2 = min(canvas_width / image_height, canvas_height / image_width)
scale = max(scale1, scale2)
- else:
+ else: #0 or #2 (crop)
scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2)
new_height = round( image_height * scale / block_size) * block_size
new_width = round( image_width * scale / block_size) * block_size
return new_height, new_width
-def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ):
+def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fit_into_canvas, fit_crop, block_size = 16):
+ if fit_crop:
+ image = rescale_and_crop(image, canvas_width, canvas_height)
+ new_width, new_height = image.size
+ else:
+ image_width, image_height = image.size
+ new_height, new_width = calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = block_size )
+ image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
+ return image, new_height, new_width
+
+def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ):
if rm_background:
session = new_session()
output_list =[]
for i, img in enumerate(img_list):
width, height = img.size
-
- if fit_into_canvas:
+ if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
+ if outpainting_dims is not None:
+ resized_image =img
+ elif img.size != (budget_width, budget_height):
+ resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS)
+ else:
+ resized_image =img
+ elif fit_into_canvas == 1:
white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255
scale = min(budget_height / height, budget_width / width)
new_height = int(height * scale)
@@ -242,10 +275,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
resized_image = Image.fromarray(white_canvas)
else:
scale = (budget_height * budget_width / (height * width))**(1/2)
- new_height = int( round(height * scale / 16) * 16)
- new_width = int( round(width * scale / 16) * 16)
+ new_height = int( round(height * scale / block_size) * block_size)
+ new_width = int( round(width * scale / block_size) * block_size)
resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
- if rm_background and not (ignore_first and i == 0) :
+ if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) :
# resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
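The utils.py hunk rewrites `calculate_new_dimensions` around three `fit_into_canvas` modes: `None`/`2` keep the canvas size, `1` scales against the canvas edges, and `0` keeps roughly the canvas pixel budget while preserving the source ratio. Restating the patched function verbatim makes a worked example easy:

```python
# Re-stated from the patched calculate_new_dimensions above, with a worked example of the three modes.
def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size=16):
    if fit_into_canvas is None or fit_into_canvas == 2:
        return canvas_height, canvas_width
    if fit_into_canvas == 1:
        scale1 = min(canvas_height / image_height, canvas_width / image_width)
        scale2 = min(canvas_width / image_height, canvas_height / image_width)
        scale = max(scale1, scale2)
    else:  # 0: keep roughly the same pixel budget as the canvas
        scale = (canvas_height * canvas_width / (image_height * image_width)) ** 0.5
    new_height = round(image_height * scale / block_size) * block_size
    new_width = round(image_width * scale / block_size) * block_size
    return new_height, new_width

# 720x1280 canvas, 1000x1500 source image:
print(calculate_new_dimensions(720, 1280, 1000, 1500, 0))     # -> (784, 1168): same pixel budget, source ratio kept
print(calculate_new_dimensions(720, 1280, 1000, 1500, 1))     # -> (720, 1088): scaled against the canvas edges
print(calculate_new_dimensions(720, 1280, 1000, 1500, None))  # -> (720, 1280): canvas dimensions used as-is
```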
diff --git a/wgp.py b/wgp.py
index 396e273e2..45cb5cfec 100644
--- a/wgp.py
+++ b/wgp.py
@@ -1,4 +1,5 @@
import os
+os.environ["GRADIO_LANG"] = "en"
# # os.environ.pop("TORCH_LOGS", None) # make sure no env var is suppressing/overriding
# os.environ["TORCH_LOGS"]= "recompiles"
import torch._logging as tlog
@@ -21,7 +22,7 @@
import importlib
from shared.utils import notification_sound
from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers
-from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, get_video_frame
+from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image
from shared.utils.audio_video import save_image_metadata, read_image_metadata
from shared.match_archi import match_nvidia_architecture
@@ -61,7 +62,7 @@
target_mmgp_version = "3.6.0"
WanGP_version = "8.4"
-settings_version = 2.31
+settings_version = 2.33
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -220,7 +221,7 @@ def process_prompt_and_add_tasks(state, model_choice):
return get_queue_table(queue)
model_def = get_model_def(model_type)
model_handler = get_model_handler(model_type)
- image_outputs = inputs["image_mode"] == 1
+ image_outputs = inputs["image_mode"] > 0
any_steps_skipping = model_def.get("tea_cache", False) or model_def.get("mag_cache", False)
model_type = get_base_model_type(model_type)
inputs["model_filename"] = model_filename
@@ -370,7 +371,7 @@ def process_prompt_and_add_tasks(state, model_choice):
gr.Info("Mag Cache maximum number of steps is 50")
return
- if image_mode == 1:
+ if image_mode > 0:
audio_prompt_type = ""
if "B" in audio_prompt_type or "X" in audio_prompt_type:
@@ -477,7 +478,8 @@ def process_prompt_and_add_tasks(state, model_choice):
image_mask = None
if "G" in video_prompt_type:
- gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(num_inference_steps * (1. - denoising_strength))} ")
+ if image_mode == 0:
+ gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(num_inference_steps * (1. - denoising_strength))} ")
else:
denoising_strength = 1.0
if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]:
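The `gr.Info` message above states the relationship between Denoising Strength and the first denoising step; a quick standalone calculation with illustrative numbers (the same `first_step` formula appears in the Qwen pipeline hunk earlier in this patch):

```python
# Illustrative numbers only: how the starting step is derived from denoising strength.
num_inference_steps = 30
denoising_strength = 0.6

first_step = int(num_inference_steps * (1. - denoising_strength))
print(f"With Denoising Strength {denoising_strength:.1f}, "
      f"denoising will start at Step no {first_step}")   # -> Step no 12

# In the Qwen image-edit path, the latents for that step are a blend of the clean
# image latents and fresh noise, weighted by timesteps[first_step] / 1000.
```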
@@ -1334,9 +1336,11 @@ def update_queue_data(queue):
data = get_queue_table(queue)
if len(data) == 0:
- return gr.DataFrame(visible=False)
+ return gr.DataFrame(value=[], max_height=1)
+ elif len(data) == 1:
+ return gr.DataFrame(value=data, max_height= 83)
else:
- return gr.DataFrame(value=data, visible= True)
+ return gr.DataFrame(value=data, max_height= 1000)
def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True):
bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom"
@@ -1371,6 +1375,12 @@ def _parse_args():
help="save proprocessed audio track with extract speakers for debugging or editing"
)
+ parser.add_argument(
+ "--debug-gen-form",
+ action="store_true",
+ help="View form generation / refresh time"
+ )
+
parser.add_argument(
"--vram-safety-coefficient",
type=float,
@@ -2070,8 +2080,6 @@ def fix_settings(model_type, ui_defaults):
if image_prompt_type != None :
if not isinstance(image_prompt_type, str):
image_prompt_type = "S" if image_prompt_type == 0 else "SE"
- # if model_type == "flf2v_720p" and not "E" in image_prompt_type:
- # image_prompt_type = "SE"
if settings_version <= 2:
image_prompt_type = image_prompt_type.replace("G","")
ui_defaults["image_prompt_type"] = image_prompt_type
@@ -2091,10 +2099,7 @@ def fix_settings(model_type, ui_defaults):
video_prompt_type = ui_defaults.get("video_prompt_type", "")
- any_reference_image = model_def.get("reference_image", False)
- if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B"] or any_reference_image:
- if not "I" in video_prompt_type: # workaround for settings corruption
- video_prompt_type += "I"
+
if base_model_type in ["hunyuan"]:
video_prompt_type = video_prompt_type.replace("I", "")
@@ -2133,10 +2138,27 @@ def fix_settings(model_type, ui_defaults):
del ui_defaults["tea_cache_start_step_perc"]
ui_defaults["skip_steps_start_step_perc"] = tea_cache_start_step_perc
+ image_prompt_type = ui_defaults.get("image_prompt_type", "")
+ if len(image_prompt_type) > 0:
+ image_prompt_types_allowed = model_def.get("image_prompt_types_allowed","")
+ image_prompt_type = filter_letters(image_prompt_type, image_prompt_types_allowed)
+ ui_defaults["image_prompt_type"] = image_prompt_type
+
+ video_prompt_type = ui_defaults.get("video_prompt_type", "")
+ image_ref_choices_list = model_def.get("image_ref_choices", {}).get("choices", [])
+ if len(image_ref_choices_list)==0:
+ video_prompt_type = del_in_sequence(video_prompt_type, "IK")
+ else:
+ first_choice = image_ref_choices_list[0][1]
+ if "I" in first_choice and not "I" in video_prompt_type: video_prompt_type += "I"
+ if len(image_ref_choices_list)==1 and "K" in first_choice and not "K" in video_prompt_type: video_prompt_type += "K"
+ ui_defaults["video_prompt_type"] = video_prompt_type
+
model_handler = get_model_handler(base_model_type)
if hasattr(model_handler, "fix_settings"):
model_handler.fix_settings(base_model_type, settings_version, model_def, ui_defaults)
+
def get_default_settings(model_type):
def get_default_prompt(i2v):
if i2v:
@@ -3323,8 +3345,8 @@ def check(src, cond):
if not all_letters(src, pos): return False
if neg is not None and any_letters(src, neg): return False
return True
- image_outputs = configs.get("image_mode",0) == 1
- map_video_prompt = {"V" : "Control Video", ("VA", "U") : "Mask Video", "I" : "Reference Images"}
+ image_outputs = configs.get("image_mode",0) > 0
+ map_video_prompt = {"V" : "Control Image" if image_outputs else "Control Video", ("VA", "U") : "Mask Image" if image_outputs else "Mask Video", "I" : "Reference Images"}
map_image_prompt = {"V" : "Source Video", "L" : "Last Video", "S" : "Start Image", "E" : "End Image"}
map_audio_prompt = {"A" : "Audio Source", "B" : "Audio Source #2"}
video_other_prompts = [ v for s,v in map_image_prompt.items() if all_letters(video_image_prompt_type,s)] \
@@ -3571,9 +3593,27 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis
end_time = time.time()
# print(f"duration:{end_time-start_time:.1f}")
- return results
+ return results
-def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
+def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, block_size= 16, expand_scale = 2):
+ frame_width, frame_height = input_image.size
+
+ if fit_canvas != None:
+ height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size)
+ input_image = input_image.resize((width, height), resample=Image.Resampling.LANCZOS)
+ if input_mask is not None:
+ input_mask = input_mask.resize((width, height), resample=Image.Resampling.LANCZOS)
+
+ if expand_scale != 0 and input_mask is not None:
+ kernel_size = abs(expand_scale)
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
+ op_expand = cv2.dilate if expand_scale > 0 else cv2.erode
+ input_mask = np.array(input_mask)
+ input_mask = op_expand(input_mask, kernel, iterations=3)
+ input_mask = Image.fromarray(input_mask)
+ return input_image, input_mask
+
+def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
def mask_to_xyxy_box(mask):
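`preprocess_image_with_mask` above grows or shrinks the mask with an elliptical kernel: a positive `expand_scale` dilates, a negative one erodes, three iterations either way. A toy standalone demo (synthetic mask, not project data):

```python
import numpy as np
import cv2

expand_scale = 2                       # positive -> dilate, negative -> erode (as in the patch)
mask = np.zeros((9, 9), dtype=np.uint8)
mask[4, 4] = 255                       # a single "on" pixel as a toy mask

kernel_size = abs(expand_scale)
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
op_expand = cv2.dilate if expand_scale > 0 else cv2.erode
grown = op_expand(mask, kernel, iterations=3)

print(int((mask > 0).sum()), "->", int((grown > 0).sum()))  # the masked area grows
```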
@@ -3615,6 +3655,9 @@ def mask_to_xyxy_box(mask):
if len(video) == 0 or any_mask and len(mask_video) == 0:
return None, None
+ if fit_crop and outpainting_dims != None:
+ fit_crop = False
+ fit_canvas = 0 if fit_canvas is not None else None
frame_height, frame_width, _ = video[0].shape
@@ -3629,7 +3672,7 @@ def mask_to_xyxy_box(mask):
if outpainting_dims != None:
final_height, final_width = height, width
- height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8)
+ height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
if any_mask:
num_frames = min(len(video), len(mask_video))
@@ -3646,14 +3689,20 @@ def mask_to_xyxy_box(mask):
# for frame_idx in range(num_frames):
def prep_prephase(frame_idx):
frame = Image.fromarray(video[frame_idx].cpu().numpy()) #.asnumpy()
- frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS)
+ if fit_crop:
+ frame = rescale_and_crop(frame, width, height)
+ else:
+ frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS)
frame = np.array(frame)
if any_mask:
if any_identity_mask:
mask = np.full( (height, width, 3), 0, dtype= np.uint8)
else:
mask = Image.fromarray(mask_video[frame_idx].cpu().numpy()) #.asnumpy()
- mask = mask.resize((width, height), resample=Image.Resampling.LANCZOS)
+ if fit_crop:
+ mask = rescale_and_crop(mask, width, height)
+ else:
+ mask = mask.resize((width, height), resample=Image.Resampling.LANCZOS)
mask = np.array(mask)
if len(mask.shape) == 3 and mask.shape[2] == 3:
@@ -3750,14 +3799,14 @@ def prep_prephase(frame_idx):
return torch.stack(masked_frames), torch.stack(masks) if any_mask else None
-def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, target_fps = 16, block_size = 16):
+def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16):
frames_list = get_resampled_video(video_in, start_frame, max_frames, target_fps)
if len(frames_list) == 0:
return None
- if fit_canvas == None:
+ if fit_canvas == None or fit_crop:
new_height = height
new_width = width
else:
@@ -3775,7 +3824,10 @@ def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_can
processed_frames_list = []
for frame in frames_list:
frame = Image.fromarray(np.clip(frame.cpu().numpy(), 0, 255).astype(np.uint8))
- frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
+ if fit_crop:
+ frame = rescale_and_crop(frame, new_width, new_height)
+ else:
+ frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
processed_frames_list.append(frame)
np_frames = [np.array(frame) for frame in processed_frames_list]
@@ -4115,9 +4167,10 @@ def process_prompt_enhancer(prompt_enhancer, original_prompts, image_start, ori
prompt_images = []
if "I" in prompt_enhancer:
if image_start != None:
- prompt_images.append(image_start)
+ prompt_images += image_start
if original_image_refs != None:
- prompt_images += original_image_refs[:1]
+ prompt_images += original_image_refs[:1]
+ prompt_images = [Image.open(img) if isinstance(img,str) else img for img in prompt_images]
if len(original_prompts) == 0 and not "T" in prompt_enhancer:
return None
else:
@@ -4223,7 +4276,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri
original_image_refs = inputs["image_refs"]
if original_image_refs is not None:
original_image_refs = [ convert_image(tup[0]) for tup in original_image_refs ]
- is_image = inputs["image_mode"] == 1
+ is_image = inputs["image_mode"] > 0
seed = inputs["seed"]
seed = set_seed(seed)
enhanced_prompts = []
@@ -4367,7 +4420,7 @@ def remove_temp_filenames(temp_filenames_list):
model_def = get_model_def(model_type)
- is_image = image_mode == 1
+ is_image = image_mode > 0
if is_image:
if min_frames_if_references >= 1000:
video_length = min_frames_if_references - 1000
@@ -4377,19 +4430,22 @@ def remove_temp_filenames(temp_filenames_list):
batch_size = 1
temp_filenames_list = []
- if image_guide is not None and isinstance(image_guide, Image.Image):
- video_guide = convert_image_to_video(image_guide)
- temp_filenames_list.append(video_guide)
- image_guide = None
+ convert_image_guide_to_video = model_def.get("convert_image_guide_to_video", False)
+ if convert_image_guide_to_video:
+ if image_guide is not None and isinstance(image_guide, Image.Image):
+ video_guide = convert_image_to_video(image_guide)
+ temp_filenames_list.append(video_guide)
+ image_guide = None
- if image_mask is not None and isinstance(image_mask, Image.Image):
- video_mask = convert_image_to_video(image_mask)
- temp_filenames_list.append(video_mask)
- image_mask = None
+ if image_mask is not None and isinstance(image_mask, Image.Image):
+ video_mask = convert_image_to_video(image_mask)
+ temp_filenames_list.append(video_mask)
+ image_mask = None
+ if model_def.get("no_background_removal", False): remove_background_images_ref = 0
+
base_model_type = get_base_model_type(model_type)
model_family = get_model_family(base_model_type)
- fit_canvas = server_config.get("fit_canvas", 0)
model_handler = get_model_handler(base_model_type)
block_size = model_handler.get_vae_block_size(base_model_type) if hasattr(model_handler, "get_vae_block_size") else 16
@@ -4415,7 +4471,7 @@ def remove_temp_filenames(temp_filenames_list):
return
width, height = resolution.split("x")
- width, height = int(width), int(height)
+ width, height = int(width) // block_size * block_size, int(height) // block_size * block_size
default_image_size = (height, width)
if slg_switch == 0:
@@ -4530,39 +4586,25 @@ def remove_temp_filenames(temp_filenames_list):
original_image_refs = image_refs
# image_refs = None
# nb_frames_positions= 0
- frames_to_inject = []
- any_background_ref = False
- outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
# Output Video Ratio Priorities:
# Source Video or Start Image > Control Video > Image Ref (background or positioned frames only) > UI Width, Height
# Image Ref (non background and non positioned frames) are boxed in a white canvas in order to keep their own width/height ratio
-
- if image_refs is not None and len(image_refs) > 0:
+ frames_to_inject = []
+ if image_refs is not None:
frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions is not None and len(frames_positions)> 0 else []
frames_positions_list = frames_positions_list[:len(image_refs)]
nb_frames_positions = len(frames_positions_list)
- if nb_frames_positions > 0:
- frames_to_inject = [None] * (max(frames_positions_list) + 1)
- for i, pos in enumerate(frames_positions_list):
- frames_to_inject[pos] = image_refs[i]
- if video_guide == None and video_source == None and not "L" in image_prompt_type and (nb_frames_positions > 0 or "K" in video_prompt_type) :
- from shared.utils.utils import get_outpainting_full_area_dimensions
- w, h = image_refs[0].size
- if outpainting_dims != None:
- h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims)
- default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas)
- fit_canvas = None
- # if there is a source video and a background image ref, the height/width ratio will need to be processed later by the code for the model (we dont know the source video dimensions at this point)
- if len(image_refs) > nb_frames_positions:
- any_background_ref = "K" in video_prompt_type
- if remove_background_images_ref > 0:
- send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
- os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg")
- from shared.utils.utils import resize_and_remove_background
- # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested
- image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (any_background_ref or model_def.get("lock_image_refs_ratios", False)) ) # no fit for vace ref images as it is done later
- update_task_thumbnails(task, locals())
- send_cmd("output")
+ any_background_ref = 0
+ if "K" in video_prompt_type:
+ any_background_ref = 2 if model_def.get("all_image_refs_are_background_ref", False) else 1
+
+ outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")]
+ fit_canvas = server_config.get("fit_canvas", 0)
+ fit_crop = fit_canvas == 2
+ if fit_crop and outpainting_dims is not None:
+ fit_crop = False
+ fit_canvas = 0
+
joint_pass = boost ==1 #and profile != 1 and profile != 3
skip_steps_cache = None if len(skip_steps_cache_type) == 0 else DynamicClass(cache_type = skip_steps_cache_type)
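Note on the new `fit_crop` flag: `fit_canvas` now distinguishes three server-configured behaviours — 0 keeps the pixel budget and reallocates it to the input aspect ratio, 1 fits the input inside an outer box, and 2 crops the input to the exact output size (and, as above, cropping is disabled whenever outpainting is requested). A rough sketch of how those modes could map a source aspect ratio onto output dimensions; the repo's actual `calculate_new_dimensions` / `rescale_and_crop` helpers may differ in detail:

```python
import math

def fit_dimensions(req_w, req_h, src_w, src_h, fit_canvas, block=16):
    """Illustrative approximation of the three canvas modes; not the repo's implementation."""
    ratio = src_w / src_h
    if fit_canvas == 1:    # outer box: shrink one side so the source ratio fits inside req_w x req_h
        w, h = (req_w, req_w / ratio) if ratio > req_w / req_h else (req_h * ratio, req_h)
    elif fit_canvas == 2:  # crop: keep the requested size, the caller center-crops the source
        w, h = req_w, req_h
    else:                  # pixel budget: keep ~req_w*req_h pixels, redistributed to the source ratio
        h = math.sqrt(req_w * req_h / ratio)
        w = h * ratio
    return int(w) // block * block, int(h) // block * block

print(fit_dimensions(832, 480, 1920, 1080, fit_canvas=0))  # 16:9 source under an 832x480 budget
```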
@@ -4632,6 +4674,9 @@ def remove_temp_filenames(temp_filenames_list):
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
current_video_length = min(current_video_length, length)
+ if image_guide is not None:
+ image_guide, image_mask = preprocess_image_with_mask(image_guide, image_mask, height, width, fit_canvas = None, block_size= block_size, expand_scale = mask_expand)
+
seed = set_seed(seed)
torch.set_grad_enabled(False)
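The `preprocess_image_with_mask` call introduced here is the single-image counterpart of the video preprocessing further down: it has to bring the control image and its mask to block-aligned dimensions and grow or shrink the mask by `mask_expand`. A hypothetical sketch of that kind of helper — the name, signature and the dilate/erode choice are assumptions for illustration, not the repo's actual code:

```python
from PIL import Image, ImageFilter

def prepare_image_and_mask(image, mask, height, width, expand_scale=0, block_size=16):
    """Resize a control image and its B/W mask to block-aligned dimensions,
    then dilate (expand_scale > 0) or erode (expand_scale < 0) the mask."""
    w = width // block_size * block_size
    h = height // block_size * block_size
    image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
    mask = mask.convert("L").resize((w, h), resample=Image.Resampling.NEAREST)
    if expand_scale != 0:
        kernel = 2 * abs(int(expand_scale)) + 1  # PIL rank filters need an odd kernel size
        filt = ImageFilter.MaxFilter(kernel) if expand_scale > 0 else ImageFilter.MinFilter(kernel)
        mask = mask.filter(filt)
    return image, mask
```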
@@ -4723,22 +4768,22 @@ def remove_temp_filenames(temp_filenames_list):
return_latent_slice = None
if reuse_frames > 0:
return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) )
- refresh_preview = {"image_guide" : None, "image_mask" : None}
+ refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {}
src_ref_images = image_refs
image_start_tensor = image_end_tensor = None
if window_no == 1 and (video_source is not None or image_start is not None):
if image_start is not None:
- new_height, new_width = calculate_new_dimensions(height, width, image_start.height, image_start.width, sample_fit_canvas, block_size = block_size)
- image_start_tensor = image_start.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
+ image_start_tensor, new_height, new_width = calculate_dimensions_and_resize_image(image_start, height, width, sample_fit_canvas, fit_crop, block_size = block_size)
+ if fit_crop: refresh_preview["image_start"] = image_start_tensor
image_start_tensor = convert_image_to_tensor(image_start_tensor)
pre_video_guide = prefix_video = image_start_tensor.unsqueeze(1)
else:
- if "L" in image_prompt_type:
- refresh_preview["video_source"] = get_video_frame(video_source, 0)
- prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, target_fps = fps, block_size = block_size )
+ prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size = block_size )
prefix_video = prefix_video.permute(3, 0, 1, 2)
prefix_video = prefix_video.float().div_(127.5).sub_(1.) # c, f, h, w
+ if fit_crop or "L" in image_prompt_type: refresh_preview["video_source"] = convert_tensor_to_image(prefix_video, 0)
+
new_height, new_width = prefix_video.shape[-2:]
pre_video_guide = prefix_video[:, -reuse_frames:]
pre_video_frame = convert_tensor_to_image(prefix_video[:, -1])
@@ -4752,10 +4797,11 @@ def remove_temp_filenames(temp_filenames_list):
image_end_list= image_end if isinstance(image_end, list) else [image_end]
if len(image_end_list) >= window_no:
new_height, new_width = image_size
- image_end_tensor =image_end_list[window_no-1].resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
+ image_end_tensor, _, _ = calculate_dimensions_and_resize_image(image_end_list[window_no-1], new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size)
+ # image_end_tensor =image_end_list[window_no-1].resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
+ refresh_preview["image_end"] = image_end_tensor
image_end_tensor = convert_image_to_tensor(image_end_tensor)
image_end_list= None
-
window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count)
guide_end_frame = guide_start_frame + current_video_length - (source_video_overlap_frames_count if window_no == 1 else reuse_frames)
alignment_shift = source_video_frames_count if reset_control_aligment else 0
@@ -4768,7 +4814,8 @@ def remove_temp_filenames(temp_filenames_list):
from models.wan.multitalk.multitalk import get_window_audio_embeddings
# special treatment for start frame pos when alignment to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding)
audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length)
-
+ if vace:
+ video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
if video_guide is not None:
keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
@@ -4776,12 +4823,44 @@ def remove_temp_filenames(temp_filenames_list):
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ]
- if ltxv:
+ if vace:
+ context_scale = [ control_net_weight]
+ if "V" in video_prompt_type:
+ process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
+ preprocess_type, preprocess_type2 = "raw", None
+ for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PDSLCMU")):
+ if process_num == 0:
+ preprocess_type = process_map_video_guide.get(process_letter, "raw")
+ else:
+ preprocess_type2 = process_map_video_guide.get(process_letter, None)
+ status_info = "Extracting " + processes_names[preprocess_type]
+ extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
+ if len(extra_process_list) == 1:
+ status_info += " and " + processes_names[extra_process_list[0]]
+ elif len(extra_process_list) == 2:
+ status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]]
+ if preprocess_type2 is not None:
+ context_scale = [ control_net_weight /2, control_net_weight2 /2]
+ send_cmd("progress", [0, get_latest_status(state, status_info)])
+ video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1 )
+ if preprocess_type2 != None:
+ video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
+
+ if video_guide_processed != None:
+ if sample_fit_canvas != None:
+ image_size = video_guide_processed.shape[-3: -1]
+ sample_fit_canvas = None
+ refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy())
+ if video_guide_processed2 != None:
+ refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())]
+ if video_mask_processed != None:
+ refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy())
+ elif ltxv:
preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
status_info = "Extracting " + processes_names[preprocess_type]
send_cmd("progress", [0, get_latest_status(state, status_info)])
# start one frame earlier to facilitate latents merging later
- src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
+ src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
if src_video != None:
src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ]
refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
@@ -4798,15 +4877,14 @@ def remove_temp_filenames(temp_filenames_list):
progress_args = [0, get_latest_status(state,"Extracting Video and Mask")]
send_cmd("progress", progress_args)
- src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
+ src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
if src_mask != None:
refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
elif "R" in video_prompt_type: # sparse video to video
src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True)
- new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size)
- src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
+ src_image, _, _ = calculate_dimensions_and_resize_image(src_image, new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size)
refresh_preview["video_guide"] = src_image
src_video = convert_image_to_tensor(src_image).unsqueeze(1)
if sample_fit_canvas != None:
@@ -4814,7 +4892,7 @@ def remove_temp_filenames(temp_filenames_list):
sample_fit_canvas = None
else: # video to video
- video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps)
+ video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps)
if video_guide_processed is None:
src_video = pre_video_guide
else:
@@ -4824,42 +4902,54 @@ def remove_temp_filenames(temp_filenames_list):
src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
if pre_video_guide != None:
src_video = torch.cat( [pre_video_guide, src_video], dim=1)
-
- if vace :
- image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
- context_scale = [ control_net_weight]
- video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
- if "V" in video_prompt_type:
- process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
- preprocess_type, preprocess_type2 = "raw", None
- for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PDSLCMU")):
- if process_num == 0:
- preprocess_type = process_map_video_guide.get(process_letter, "raw")
+ elif image_guide is not None:
+ image_guide, new_height, new_width = calculate_dimensions_and_resize_image(image_guide, height, width, sample_fit_canvas, fit_crop, block_size = block_size)
+ image_size = (new_height, new_width)
+ refresh_preview["image_guide"] = image_guide
+ sample_fit_canvas = None
+ if image_mask is not None:
+ image_mask, _, _ = calculate_dimensions_and_resize_image(image_mask, new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size)
+ refresh_preview["image_mask"] = image_mask
+
+ if window_no == 1 and image_refs is not None and len(image_refs) > 0:
+ if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) :
+ from shared.utils.utils import get_outpainting_full_area_dimensions
+ w, h = image_refs[0].size
+ if outpainting_dims != None:
+ h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims)
+ image_size = calculate_new_dimensions(height, width, h, w, fit_canvas)
+ sample_fit_canvas = None
+ if repeat_no == 1:
+ if fit_crop:
+ if any_background_ref == 2:
+ end_ref_position = len(image_refs)
+ elif any_background_ref == 1:
+ end_ref_position = nb_frames_positions + 1
else:
- preprocess_type2 = process_map_video_guide.get(process_letter, None)
- status_info = "Extracting " + processes_names[preprocess_type]
- extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
- if len(extra_process_list) == 1:
- status_info += " and " + processes_names[extra_process_list[0]]
- elif len(extra_process_list) == 2:
- status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]]
- if preprocess_type2 is not None:
- context_scale = [ control_net_weight /2, control_net_weight2 /2]
- send_cmd("progress", [0, get_latest_status(state, status_info)])
- video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1 )
- if preprocess_type2 != None:
- video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
+ end_ref_position = nb_frames_positions
+ for i, img in enumerate(image_refs[:end_ref_position]):
+ image_refs[i] = rescale_and_crop(img, default_image_size[1], default_image_size[0])
+ refresh_preview["image_refs"] = image_refs
+
+ if len(image_refs) > nb_frames_positions:
+ if remove_background_images_ref > 0:
+ send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
+ # keep image ratios if there is a background image ref (the model preprocessor will decide what to do), but remove the background if requested
+ image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
+ remove_background_images_ref > 0, any_background_ref,
+ fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1,
+ block_size=block_size,
+ outpainting_dims =outpainting_dims )
+ refresh_preview["image_refs"] = image_refs
+
+ if nb_frames_positions > 0:
+ frames_to_inject = [None] * (max(frames_positions_list) + 1)
+ for i, pos in enumerate(frames_positions_list):
+ frames_to_inject[pos] = image_refs[i]
- if video_guide_processed != None:
- if sample_fit_canvas != None:
- image_size = video_guide_processed.shape[-3: -1]
- sample_fit_canvas = None
- refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy())
- if video_guide_processed2 != None:
- refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())]
- if video_mask_processed != None:
- refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy())
+ if vace :
frames_to_inject_parsed = frames_to_inject[aligned_guide_start_frame: aligned_guide_end_frame]
+ image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2],
[video_mask_processed] if video_guide_processed2 == None else [video_mask_processed, video_mask_processed2],
@@ -4868,7 +4958,6 @@ def remove_temp_filenames(temp_filenames_list):
keep_video_guide_frames=keep_frames_parsed,
start_frame = aligned_guide_start_frame,
pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide],
- fit_into_canvas = sample_fit_canvas,
inject_frames= frames_to_inject_parsed,
outpainting_dims = outpainting_dims,
any_background_ref = any_background_ref
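The relocated `frames_to_inject` logic builds a sparse list indexed by absolute frame position, which each sliding window then slices with `[aligned_guide_start_frame: aligned_guide_end_frame]` as seen just above. A tiny self-contained version of the construction step (positions are 0-based after the `int(pos)-1` parsing earlier in the function):

```python
def build_frames_to_inject(image_refs, frames_positions_list):
    """Sparse list: None everywhere except at the positions where a reference image is injected."""
    if not frames_positions_list:
        return []
    frames_to_inject = [None] * (max(frames_positions_list) + 1)
    for ref, pos in zip(image_refs, frames_positions_list):
        frames_to_inject[pos] = ref
    return frames_to_inject

print(build_frames_to_inject(["refA", "refB"], [0, 4]))  # ['refA', None, None, None, 'refB']
```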
@@ -4931,9 +5020,9 @@ def set_header_text(txt):
prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames,
frame_num= (current_video_length // latent_size)* latent_size + 1,
batch_size = batch_size,
- height = height,
- width = width,
- fit_into_canvas = fit_canvas == 1,
+ height = image_size[0],
+ width = image_size[1],
+ fit_into_canvas = fit_canvas,
shift=flow_shift,
sample_solver=sample_solver,
sampling_steps=num_inference_steps,
@@ -4990,6 +5079,8 @@ def set_header_text(txt):
pre_video_frame = pre_video_frame,
original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [],
image_refs_relative_size = image_refs_relative_size,
+ image_guide= image_guide,
+ image_mask= image_mask,
)
except Exception as e:
if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0:
@@ -5931,7 +6022,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
if target == "settings":
return inputs
- image_outputs = inputs.get("image_mode",0) == 1
+ image_outputs = inputs.get("image_mode",0) > 0
pop=[]
if "force_fps" in inputs and len(inputs["force_fps"])== 0:
@@ -5947,13 +6038,13 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
pop += ["MMAudio_setting", "MMAudio_prompt", "MMAudio_neg_prompt"]
video_prompt_type = inputs["video_prompt_type"]
- if not base_model_type in ["t2v"]:
+ if not "G" in video_prompt_type:
pop += ["denoising_strength"]
if not (server_config.get("enhancer_enabled", 0) > 0 and server_config.get("enhancer_mode", 0) == 0):
pop += ["prompt_enhancer"]
- if not recammaster and not diffusion_forcing and not flux:
+ if model_def.get("model_modes", None) is None:
pop += ["model_mode"]
if not vace and not phantom and not hunyuan_video_custom:
@@ -6075,6 +6166,18 @@ def image_to_ref_image_set(state, input_file_list, choice, target, target_name):
gr.Info(f"Selected Image was copied to {target_name}")
return file_list[choice]
+def image_to_ref_image_guide(state, input_file_list, choice):
+ file_list, file_settings_list = get_file_list(state, input_file_list)
+ if len(file_list) == 0 or choice is None or choice < 0 or choice >= len(file_list): return gr.update(), gr.update()
+ ui_settings = get_current_model_settings(state)
+ gr.Info(f"Selected Image was copied to Control Image")
+ new_image = file_list[choice]
+ if ui_settings["image_mode"]==2:
+ return new_image, new_image
+ else:
+ return new_image, None
+
+
def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation):
gen = get_gen_info(state)
@@ -6142,11 +6245,11 @@ def eject_video_from_gallery(state, input_file_list, choice):
return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0)
def has_video_file_extension(filename):
- extension = os.path.splitext(filename)[-1]
+ extension = os.path.splitext(filename)[-1].lower()
return extension in [".mp4"]
def has_image_file_extension(filename):
- extension = os.path.splitext(filename)[-1]
+ extension = os.path.splitext(filename)[-1].lower()
return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]
def add_videos_to_gallery(state, input_file_list, choice, files_to_load):
gen = get_gen_info(state)
@@ -6308,7 +6411,7 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw
return configs, any_image_or_video
def record_image_mode_tab(state, evt:gr.SelectData):
- state["image_mode_tab"] = 0 if evt.index ==0 else 1
+ state["image_mode_tab"] = evt.index
def switch_image_mode(state):
image_mode = state.get("image_mode_tab", 0)
@@ -6316,7 +6419,18 @@ def switch_image_mode(state):
ui_defaults = get_model_settings(state, model_type)
ui_defaults["image_mode"] = image_mode
-
+ video_prompt_type = ui_defaults.get("video_prompt_type", "")
+ model_def = get_model_def( model_type)
+ inpaint_support = model_def.get("inpaint_support", False)
+ if inpaint_support:
+ if image_mode == 1:
+ video_prompt_type = del_in_sequence(video_prompt_type, "VAG")
+ video_prompt_type = add_to_sequence(video_prompt_type, "KI")
+ elif image_mode == 2:
+ video_prompt_type = add_to_sequence(video_prompt_type, "VAG")
+ video_prompt_type = del_in_sequence(video_prompt_type, "KI")
+ ui_defaults["video_prompt_type"] = video_prompt_type
+
return str(time.time())
def load_settings_from_file(state, file_path):
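`switch_image_mode` above, like several handlers later in this patch, treats `video_prompt_type` as an order-insensitive bag of letter flags edited through `del_in_sequence` / `add_to_sequence`. Plausible minimal versions of those helpers for readers who do not have the rest of wgp.py open (the repo's own definitions may differ in details):

```python
def del_in_sequence(sequence: str, letters: str) -> str:
    """Drop every letter of `letters` from `sequence`."""
    return "".join(ch for ch in sequence if ch not in letters)

def add_to_sequence(sequence: str, letters: str) -> str:
    """Append the letters of `letters` that are not already present."""
    return sequence + "".join(ch for ch in letters if ch not in sequence)

vpt = "KIV"
vpt = del_in_sequence(vpt, "KI")   # -> "V"
vpt = add_to_sequence(vpt, "VAG")  # -> "VAG", matching the inpaint branch above
print(vpt)
```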
@@ -6349,6 +6463,7 @@ def load_settings_from_file(state, file_path):
def save_inputs(
target,
+ image_mask_guide,
lset_name,
image_mode,
prompt,
@@ -6434,13 +6549,18 @@ def save_inputs(
state,
):
-
- # if state.get("validate_success",0) != 1:
- # return
+
model_filename = state["model_filename"]
model_type = state["model_type"]
+ if image_mask_guide is not None and image_mode == 2:
+ if "background" in image_mask_guide:
+ image_guide = image_mask_guide["background"]
+ if "layers" in image_mask_guide and len(image_mask_guide["layers"])>0:
+ image_mask = image_mask_guide["layers"][0]
+ image_mask_guide = None
inputs = get_function_arguments(save_inputs, locals())
inputs.pop("target")
+ inputs.pop("image_mask_guide")
cleaned_inputs = prepare_inputs_dict(target, inputs)
if target == "settings":
defaults_filename = get_settings_file_name(model_type)
@@ -6544,11 +6664,16 @@ def change_model(state, model_choice):
return header
-def fill_inputs(state):
+def get_current_model_settings(state):
model_type = state["model_type"]
- ui_defaults = get_model_settings(state, model_type)
+ ui_defaults = get_model_settings(state, model_type)
if ui_defaults == None:
ui_defaults = get_default_settings(model_type)
+ set_model_settings(state, model_type, ui_defaults)
+ return ui_defaults
+
+def fill_inputs(state):
+ ui_defaults = get_current_model_settings(state)
return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults)
@@ -6623,7 +6748,9 @@ def refresh_image_prompt_type_radio(state, image_prompt_type, image_prompt_type_
image_prompt_type = del_in_sequence(image_prompt_type, "VLTS")
image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio)
any_video_source = len(filter_letters(image_prompt_type, "VL"))>0
- end_visible = any_letters(image_prompt_type, "SVL")
+ model_def = get_model_def(state["model_type"])
+ image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "")
+ end_visible = "E" in image_prompt_types_allowed and any_letters(image_prompt_type, "SVL")
return image_prompt_type, gr.update(visible = "S" in image_prompt_type ), gr.update(visible = end_visible and ("E" in image_prompt_type) ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source), gr.update(visible = end_visible)
def refresh_image_prompt_type_endcheckbox(state, image_prompt_type, image_prompt_type_radio, end_checkbox):
@@ -6654,7 +6781,7 @@ def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_
visible= "A" in video_prompt_type
model_type = state["model_type"]
model_def = get_model_def(model_type)
- image_outputs = image_mode == 1
+ image_outputs = image_mode > 0
return video_prompt_type, gr.update(visible= visible and not image_outputs), gr.update(visible= visible and image_outputs), gr.update(visible= visible )
def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_type_video_guide):
@@ -6663,20 +6790,22 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t
return video_prompt_type
def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode):
- video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMGUV")
+ video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUV")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
visible = "V" in video_prompt_type
model_type = state["model_type"]
base_model_type = get_base_model_type(model_type)
mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
model_def = get_model_def(model_type)
- image_outputs = image_mode == 1
+ image_outputs = image_mode > 0
vace= test_vace_module(model_type)
keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False)
return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible)
def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt):
- video_prompt_type = del_in_sequence(video_prompt_type, "RGUVQKI")
+ model_def = get_model_def(state["model_type"])
+ guide_custom_choices = model_def.get("guide_custom_choices",{})
+ video_prompt_type = del_in_sequence(video_prompt_type, guide_custom_choices.get("letters_filter",""))
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt)
control_video_visible = "V" in video_prompt_type
ref_images_visible = "I" in video_prompt_type
@@ -6711,7 +6840,7 @@ def get_image_end_label(multi_prompts_gen_type):
return "Images as ending points for new Videos in the Generation Queue" if multi_prompts_gen_type == 0 else "Images as ending points for each new Window of the same Video Generation"
def refresh_prompt_labels(multi_prompts_gen_type, image_mode):
- prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type, image_mode == 1)
+ prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type, image_mode > 0)
return gr.update(label=prompt_label), gr.update(label = wizard_prompt_label), gr.update(label=get_image_end_label(multi_prompts_gen_type))
def show_preview_column_modal(state, column_no):
@@ -7032,7 +7161,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
diffusion_forcing = "diffusion_forcing" in model_filename
ltxv = "ltxv" in model_filename
lock_inference_steps = model_def.get("lock_inference_steps", False)
- model_reference_image = model_def.get("reference_image", False)
any_tea_cache = model_def.get("tea_cache", False)
any_mag_cache = model_def.get("mag_cache", False)
recammaster = base_model_type in ["recam_1.3B"]
@@ -7075,18 +7203,22 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
if not v2i_switch_supported and not image_outputs:
image_mode_value = 0
else:
- image_outputs = image_mode_value == 1
+ image_outputs = image_mode_value > 0
+ inpaint_support = model_def.get("inpaint_support", False)
image_mode = gr.Number(value =image_mode_value, visible = False)
-
- with gr.Tabs(visible = v2i_switch_supported, selected= "t2i" if image_mode_value == 1 else "t2v" ) as image_mode_tabs:
- with gr.Tab("Text to Video", id = "t2v", elem_classes="compact_tab"):
+ image_mode_tab_selected= "t2i" if image_mode_value == 1 else ("inpaint" if image_mode_value == 2 else "t2v")
+ with gr.Tabs(visible = v2i_switch_supported or inpaint_support, selected= image_mode_tab_selected ) as image_mode_tabs:
+ with gr.Tab("Text to Video", id = "t2v", elem_classes="compact_tab", visible = v2i_switch_supported) as tab_t2v:
pass
with gr.Tab("Text to Image", id = "t2i", elem_classes="compact_tab"):
pass
+ with gr.Tab("Image Inpainting", id = "inpaint", elem_classes="compact_tab", visible=inpaint_support) as tab_inpaint:
+ pass
image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "")
model_mode_choices = model_def.get("model_modes", None)
with gr.Column(visible= len(image_prompt_types_allowed)> 0 or model_mode_choices is not None) as image_prompt_column:
+ # Video Continue / Start Frame / End Frame
image_prompt_type_value= ui_defaults.get("image_prompt_type","")
image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False)
image_prompt_type_choices = []
@@ -7123,192 +7255,167 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
model_mode = gr.Dropdown(choices=model_mode_choices["choices"], value=ui_defaults.get("model_mode", model_mode_choices["default"]), label=model_mode_choices["label"], visible=True)
keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VL"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" )
- with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or recammaster or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column:
+ any_control_video = any_control_image = False
+ guide_preprocessing = model_def.get("guide_preprocessing", None)
+ mask_preprocessing = model_def.get("mask_preprocessing", None)
+ guide_custom_choices = model_def.get("guide_custom_choices", None)
+ image_ref_choices = model_def.get("image_ref_choices", None)
+
+ # with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or recammaster or (flux or qwen ) and model_reference_image and image_mode_value >=1) as video_prompt_column:
+ with gr.Column(visible= guide_preprocessing is not None or mask_preprocessing is not None or guide_custom_choices is not None or image_ref_choices is not None) as video_prompt_column:
video_prompt_type_value= ui_defaults.get("video_prompt_type","")
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
- any_control_video = True
- any_control_image = image_outputs
- with gr.Row():
- if t2v:
- video_prompt_type_video_guide = gr.Dropdown(
- choices=[
- ("Use Text Prompt Only", ""),
- ("Image to Image guided by Text Prompt" if image_outputs else "Video to Video guided by Text Prompt", "GUV"),
- ],
- value=filter_letters(video_prompt_type_value, "GUV"),
- label="Video to Video", scale = 2, show_label= False, visible= True
- )
- elif vace :
+ with gr.Row(visible = image_mode_value!=2) as guide_selection_row:
+ # Control Video Preprocessing
+ if guide_preprocessing is None:
+ video_prompt_type_video_guide = gr.Dropdown(choices=[("","")], value="", label="Control Video", scale = 2, visible= False, show_label= True, )
+ else:
pose_label = "Pose" if image_outputs else "Motion"
+ guide_preprocessing_labels_all = {
+ "": "No Control Video",
+ "UV": "Keep Control Video Unchanged",
+ "PV": f"Transfer Human {pose_label}",
+ "PMV": f"Transfer Human {pose_label}",
+ "DV": "Transfer Depth",
+ "EV": "Transfer Canny Edges",
+ "SV": "Transfer Shapes",
+ "LV": "Transfer Flow",
+ "CV": "Recolorize",
+ "MV": "Perform Inpainting",
+ "V": "Use Vace raw format",
+ "PDV": f"Transfer Human {pose_label} & Depth",
+ "PSV": f"Transfer Human {pose_label} & Shapes",
+ "PLV": f"Transfer Human {pose_label} & Flow" ,
+ "DSV": "Transfer Depth & Shapes",
+ "DLV": "Transfer Depth & Flow",
+ "SLV": "Transfer Shapes & Flow",
+ }
+ guide_preprocessing_choices = []
+ guide_preprocessing_labels = guide_preprocessing.get("labels", {})
+ for process_type in guide_preprocessing["selection"]:
+ process_label = guide_preprocessing_labels.get(process_type, None)
+ process_label = guide_preprocessing_labels_all.get(process_type,process_type) if process_label is None else process_label
+ if image_outputs: process_label = process_label.replace("Video", "Image")
+ guide_preprocessing_choices.append( (process_label, process_type) )
+
+ video_prompt_type_video_guide_label = guide_preprocessing.get("label", "Control Video Process")
+ if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image")
video_prompt_type_video_guide = gr.Dropdown(
- choices=[
- ("No Control Image" if image_outputs else "No Control Video", ""),
- ("Keep Control Image Unchanged" if image_outputs else "Keep Control Video Unchanged", "UV"),
- (f"Transfer Human {pose_label}" , "PV"),
- ("Transfer Depth", "DV"),
- ("Transfer Shapes", "SV"),
- ("Transfer Flow", "LV"),
- ("Recolorize", "CV"),
- ("Perform Inpainting", "MV"),
- ("Use Vace raw format", "V"),
- (f"Transfer Human {pose_label} & Depth", "PDV"),
- (f"Transfer Human {pose_label} & Shapes", "PSV"),
- (f"Transfer Human {pose_label} & Flow", "PLV"),
- ("Transfer Depth & Shapes", "DSV"),
- ("Transfer Depth & Flow", "DLV"),
- ("Transfer Shapes & Flow", "SLV"),
- ],
- value=filter_letters(video_prompt_type_value, "PDSLCMGUV"),
- label="Control Image Process" if image_outputs else "Control Video Process", scale = 2, visible= True, show_label= True,
- )
- elif ltxv:
- video_prompt_type_video_guide = gr.Dropdown(
- choices=[
- ("No Control Video", ""),
- ("Transfer Human Motion", "PV"),
- ("Transfer Depth", "DV"),
- ("Transfer Canny Edges", "EV"),
- ("Use LTXV raw format", "V"),
- ],
- value=filter_letters(video_prompt_type_value, "PDEV"),
- label="Control Video Process", scale = 2, visible= True, show_label= True,
+ guide_preprocessing_choices,
+ value=filter_letters(video_prompt_type_value, "PDESLCMUV", guide_preprocessing.get("default", "") ),
+ label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True,
)
+ any_control_video = True
+ any_control_image = image_outputs
- elif hunyuan_video_custom_edit:
- video_prompt_type_video_guide = gr.Dropdown(
- choices=[
- ("Inpaint Control Image" if image_outputs else "Inpaint Control Video", "MV"),
- ("Transfer Human Motion", "PMV"),
- ],
- value=filter_letters(video_prompt_type_value, "PDSLCMUV"),
- label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= True, show_label= True,
- )
- elif recammaster:
- video_prompt_type_video_guide = gr.Dropdown(value="UV", choices = [("Control Video","UV")], visible=False)
+ # Alternate Control Video Preprocessing / Options
+ if guide_custom_choices is None:
+ video_prompt_type_video_guide_alt = gr.Dropdown(choices=[("","")], value="", label="Control Video", visible= False, scale = 2 )
else:
- any_control_video = False
- any_control_image = False
- video_prompt_type_video_guide = gr.Dropdown(value="", choices = [("","")], visible=False)
-
- if infinitetalk:
+ video_prompt_type_video_guide_alt_label = guide_custom_choices.get("label", "Control Video Process")
+ if image_outputs: video_prompt_type_video_guide_alt_label = video_prompt_type_video_guide_alt_label.replace("Video", "Image")
+ video_prompt_type_video_guide_alt_choices = [(label.replace("Video", "Image") if image_outputs else label, value) for label,value in guide_custom_choices["choices"] ]
video_prompt_type_video_guide_alt = gr.Dropdown(
- choices=[
- ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
- ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
- ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"),
- ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"),
- ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"),
- ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"),
- ],
- value=filter_letters(video_prompt_type_value, "RGUVQKI"),
- label="Video to Video", scale = 3, visible= True, show_label= False,
- )
- any_control_video = any_control_image = True
- else:
- video_prompt_type_video_guide_alt = gr.Dropdown(value="", choices = [("","")], visible=False)
-
- # video_prompt_video_guide_trigger = gr.Text(visible=False, value="")
- if t2v:
- video_prompt_type_video_mask = gr.Dropdown(value = "", choices = [""], visible = False)
- elif hunyuan_video_custom_edit:
- video_prompt_type_video_mask = gr.Dropdown(
- choices=[
- ("Masked Area", "A"),
- ("Non Masked Area", "NA"),
- ],
- value= filter_letters(video_prompt_type_value, "NA"),
- visible= "V" in video_prompt_type_value,
- label="Area Processed", scale = 2, show_label= True,
- )
- elif ltxv:
- video_prompt_type_video_mask = gr.Dropdown(
- choices=[
- ("Whole Frame", ""),
- ("Masked Area", "A"),
- ("Non Masked Area", "NA"),
- ("Masked Area, rest Inpainted", "XA"),
- ("Non Masked Area, rest Inpainted", "XNA"),
- ],
- value= filter_letters(video_prompt_type_value, "XNA"),
- visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value,
- label="Area Processed", scale = 2, show_label= True,
+ choices= video_prompt_type_video_guide_alt_choices,
+ value=filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"], guide_custom_choices.get("default", "") ),
+ visible = guide_custom_choices.get("visible", True),
+ label= video_prompt_type_video_guide_alt_label, show_label= guide_custom_choices.get("show_label", True), scale = 2
)
+ any_control_video = True
+ any_control_image = image_outputs
+
+ # Control Mask Preprocessing
+ if mask_preprocessing is None:
+ video_prompt_type_video_mask = gr.Dropdown(choices=[("","")], value="", label="Video Mask", scale = 2, visible= False, show_label= True, )
else:
+ mask_preprocessing_labels_all = {
+ "": "Whole Frame",
+ "A": "Masked Area",
+ "NA": "Non Masked Area",
+ "XA": "Masked Area, rest Inpainted",
+ "XNA": "Non Masked Area, rest Inpainted",
+ "YA": "Masked Area, rest Depth",
+ "YNA": "Non Masked Area, rest Depth",
+ "WA": "Masked Area, rest Shapes",
+ "WNA": "Non Masked Area, rest Shapes",
+ "ZA": "Masked Area, rest Flow",
+ "ZNA": "Non Masked Area, rest Flow"
+ }
+
+ mask_preprocessing_choices = []
+ mask_preprocessing_labels = mask_preprocessing.get("labels", {})
+ for process_type in mask_preprocessing["selection"]:
+ process_label = mask_preprocessing_labels.get(process_type, None)
+ process_label = mask_preprocessing_labels_all.get(process_type, process_type) if process_label is None else process_label
+ mask_preprocessing_choices.append( (process_label, process_type) )
+
+ video_prompt_type_video_mask_label = mask_preprocessing.get("label", "Area Processed")
video_prompt_type_video_mask = gr.Dropdown(
- choices=[
- ("Whole Frame", ""),
- ("Masked Area", "A"),
- ("Non Masked Area", "NA"),
- ("Masked Area, rest Inpainted", "XA"),
- ("Non Masked Area, rest Inpainted", "XNA"),
- ("Masked Area, rest Depth", "YA"),
- ("Non Masked Area, rest Depth", "YNA"),
- ("Masked Area, rest Shapes", "WA"),
- ("Non Masked Area, rest Shapes", "WNA"),
- ("Masked Area, rest Flow", "ZA"),
- ("Non Masked Area, rest Flow", "ZNA"),
- ],
- value= filter_letters(video_prompt_type_value, "XYZWNA"),
- visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom and not ltxv,
- label="Area Processed", scale = 2, show_label= True,
- )
- image_ref_choices = model_def.get("image_ref_choices", None)
- if image_ref_choices is not None:
- video_prompt_type_image_refs = gr.Dropdown(
- choices= image_ref_choices["choices"],
- value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]),
- visible = True,
- label=image_ref_choices["label"], show_label= True, scale = 2
- )
- elif t2v:
- video_prompt_type_image_refs = gr.Dropdown(value="", label="Ref Image", choices=[""], visible =False)
- elif vace:
- video_prompt_type_image_refs = gr.Dropdown(
- choices=[
- ("None", ""),
- ("Inject only People / Objects", "I"),
- ("Inject Landscape and then People / Objects", "KI"),
- ("Inject Frames and then People / Objects", "FI"),
- ],
- value=filter_letters(video_prompt_type_value, "KFI"),
- visible = True,
- label="Reference Images", show_label= True, scale = 2
- )
- elif standin: # and not vace
- video_prompt_type_image_refs = gr.Dropdown(
- choices=[
- ("No Reference Image", ""),
- ("Reference Image is a Person Face", "I"),
- ],
- value=filter_letters(video_prompt_type_value, "I"),
- visible = True,
- show_label=False,
- label="Reference Image", scale = 2
+ mask_preprocessing_choices,
+ value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")),
+ label= video_prompt_type_video_mask_label , scale = 2, visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and mask_preprocessing.get("visible", True),
+ show_label= True,
)
- elif (flux or qwen) and model_reference_image:
+
+ # Image Refs Selection
+ if image_ref_choices is None:
video_prompt_type_image_refs = gr.Dropdown(
- choices=[
- ("None", ""),
- ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"),
- ("Conditional Images are People / Objects", "I"),
- ],
- value=filter_letters(video_prompt_type_value, "KI"),
- visible = True,
- show_label=False,
- label="Reference Images Combination Method", scale = 2
+ # choices=[ ("None", ""),("Start", "KI"),("Ref Image", "I")],
+ choices=[ ("None", ""),],
+ value=filter_letters(video_prompt_type_value, ""),
+ visible = False,
+ label="Start / Reference Images", scale = 2
)
+ any_reference_image = False
else:
+ any_reference_image = True
video_prompt_type_image_refs = gr.Dropdown(
- choices=[ ("None", ""),("Start", "KI"),("Ref Image", "I")],
- value=filter_letters(video_prompt_type_value, "KI"),
- visible = False,
- label="Start / Reference Images", scale = 2
+ choices= image_ref_choices["choices"],
+ value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]),
+ visible = image_ref_choices.get("visible", True),
+ label=image_ref_choices.get("label", "Ref. Images Type"), show_label= True, scale = 2
)
- image_guide = gr.Image(label= "Control Image", height = gallery_height, type ="pil", visible= image_outputs and "V" in video_prompt_type_value, value= ui_defaults.get("image_guide", None))
+
+ image_guide = gr.Image(label= "Control Image", height = gallery_height, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value, value= ui_defaults.get("image_guide", None))
video_guide = gr.Video(label= "Control Video", height = gallery_height, visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None))
+ if image_mode_value == 2 and inpaint_support:
+ image_guide_value = ui_defaults.get("image_guide", None)
+ image_mask_value = ui_defaults.get("image_mask", None)
+ if image_guide_value is None:
+ image_mask_guide_value = None
+ else:
+ def rgb_bw_to_rgba_mask(img, thresh=127):
+ a = img.convert('L').point(lambda p: 255 if p > thresh else 0) # alpha
+ out = Image.new('RGBA', img.size, (255, 255, 255, 0)) # white, transparent
+ out.putalpha(a) # white where alpha=255
+ return out
+
+ image_mask_value = rgb_bw_to_rgba_mask(image_mask_value) if image_mask_value is not None else None
+ image_mask_guide_value = { "background" : image_guide_value, "composite" : None, "layers": [] if image_mask_value is None else [image_mask_value] }
+
+ image_mask_guide = gr.ImageEditor(
+ label="Control Image to be Inpainted",
+ value = image_mask_guide_value,
+ type='pil',
+ sources=["upload", "webcam"],
+ image_mode='RGB',
+ layers=False,
+ brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
+ # fixed_canvas= True,
+ width=800,
+ height=800,
+ # transforms=None,
+ # interactive=True,
+ elem_id="img_editor",
+ visible= True
+ )
+ any_control_image = True
+ else:
+ image_mask_guide = gr.ImageEditor(value = None, visible = False, elem_id="img_editor")
- denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label="Denoising Strength (the Lower the Closer to the Control Video)", visible = "G" in video_prompt_type_value, show_reset_button= False)
+
+ denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label=f"Denoising Strength (the Lower the Closer to the Control {'Image' if image_outputs else 'Video'})", visible = "G" in video_prompt_type_value, show_reset_button= False)
keep_frames_video_guide_visible = not image_outputs and "V" in video_prompt_type_value and not model_def.get("keep_frames_video_guide_not_supported", False)
keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= keep_frames_video_guide_visible , scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
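The inpainting tab above hands `gr.ImageEditor` a white-on-transparent RGBA layer built by `rgb_bw_to_rgba_mask`, and `save_inputs` earlier in this patch takes `layers[0]` back as `image_mask`. A round-trip sketch; the reverse direction (reducing the painted layer back to a black/white mask) is an assumption about downstream use, not code from this patch:

```python
from PIL import Image

def bw_to_rgba_layer(mask: Image.Image, thresh: int = 127) -> Image.Image:
    """B/W mask -> white RGBA layer that is opaque only where the mask is white."""
    alpha = mask.convert("L").point(lambda p: 255 if p > thresh else 0)
    layer = Image.new("RGBA", mask.size, (255, 255, 255, 0))
    layer.putalpha(alpha)
    return layer

def rgba_layer_to_bw(layer: Image.Image) -> Image.Image:
    """Painted (opaque) pixels become white, everything else black."""
    return layer.getchannel("A").point(lambda p: 255 if p > 0 else 0)

bw = Image.new("L", (64, 64), 0)
bw.paste(255, (16, 16, 48, 48))
assert rgba_layer_to_bw(bw_to_rgba_layer(bw)).getpixel((32, 32)) == 255
```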
@@ -7325,11 +7432,10 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
video_guide_outpainting_left = gr.Slider(0, 100, value= video_guide_outpainting_list[2], step=5, label="Left %", show_reset_button= False)
video_guide_outpainting_right = gr.Slider(0, 100, value= video_guide_outpainting_list[3], step=5, label="Right %", show_reset_button= False)
any_image_mask = image_outputs and vace
- image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_outputs and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("image_mask", None))
+ image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("image_mask", None))
video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("video_mask", None))
mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
- any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or infinitetalk or (flux or qwen) and model_reference_image
image_refs_single_image_mode = model_def.get("one_image_ref_needed", False)
image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "")
@@ -7424,10 +7530,13 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
visible= True, show_label= not on_demand_prompt_enhancer,
)
with gr.Row():
- if server_config.get("fit_canvas", 0) == 1:
- label = "Max Resolution (As it maybe less depending on video width / height ratio)"
+ fit_canvas = server_config.get("fit_canvas", 0)
+ if fit_canvas == 1:
+ label = "Outer Box Resolution (one dimension may be less to preserve video W/H ratio)"
+ elif fit_canvas == 2:
+ label = "Output Resolution (Input Images wil be Cropped if the W/H ratio is different)"
else:
- label = "Max Resolution (Pixels will be reallocated depending on the output width / height ratio)"
+ label = "Resolution Budget (Pixels will be reallocated to preserve Inputs W/H ratio)"
current_resolution_choice = ui_defaults.get("resolution","832x480") if update_form or last_resolution is None else last_resolution
model_resolutions = model_def.get("resolutions", None)
resolution_choices, current_resolution_choice = get_resolution_choices(current_resolution_choice, model_resolutions)
@@ -7751,7 +7860,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
sliding_window_defaults = model_def.get("sliding_window_defaults", {})
sliding_window_size = gr.Slider(5, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size")
sliding_window_overlap = gr.Slider(sliding_window_defaults.get("overlap_min", 1), sliding_window_defaults.get("overlap_max", 97), value=ui_defaults.get("sliding_window_overlap",sliding_window_defaults.get("overlap_default", 5)), step=sliding_window_defaults.get("overlap_step", 4), label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
- sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",1), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)")
+ sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",1), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)", visible = True)
sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20 if vace else 0), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = vace)
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
@@ -7902,7 +8011,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
add_to_queue_trigger = gr.Text(visible = False)
with gr.Column(visible= False) as current_gen_column:
- with gr.Accordion("Preview", open=False) as queue_accordion:
+ with gr.Accordion("Preview", open=False):
preview = gr.Image(label="Preview", height=200, show_label= False)
preview_trigger = gr.Text(visible= False)
gen_info = gr.HTML(visible=False, min_height=1)
@@ -7947,8 +8056,8 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row,
video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row,
video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn,
- NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs,
- min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] + image_start_extra + image_end_extra + image_refs_extra # presets_column,
+ NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, guide_selection_row, image_mode_tabs,
+ min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn, tab_inpaint, tab_t2v] + image_start_extra + image_end_extra + image_refs_extra # presets_column,
if update_form:
locals_dict = locals()
gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs
@@ -7970,10 +8079,10 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
image_prompt_type_radio.change(fn=refresh_image_prompt_type_radio, inputs=[state, image_prompt_type, image_prompt_type_radio], outputs=[image_prompt_type, image_start_row, image_end_row, video_source, keep_frames_video_source, image_prompt_type_endcheckbox], show_progress="hidden" )
image_prompt_type_endcheckbox.change(fn=refresh_image_prompt_type_endcheckbox, inputs=[state, image_prompt_type, image_prompt_type_radio, image_prompt_type_endcheckbox], outputs=[image_prompt_type, image_end_row] )
# video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
- video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col])
- video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand])
- video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs_row, denoising_strength ])
- video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand])
+ video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col], show_progress="hidden")
+ video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand], show_progress="hidden")
+ video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs_row, denoising_strength ], show_progress="hidden")
+ video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand], show_progress="hidden")
video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type])
multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt, image_end], show_progress="hidden")
video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
@@ -7984,8 +8093,8 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name]).then(
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
- gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab])
- preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview])
+ gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab], show_progress="hidden")
+ preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview], show_progress="hidden")
PP_MMAudio_setting.change(fn = lambda value : [gr.update(visible = value == 1), gr.update(visible = value == 0)] , inputs = [PP_MMAudio_setting], outputs = [PP_MMAudio_row, PP_custom_audio_row] )
def refresh_status_async(state, progress=gr.Progress()):
gen = get_gen_info(state)
@@ -8017,7 +8126,9 @@ def activate_status(state):
output_trigger.change(refresh_gallery,
inputs = [state],
- outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, current_gen_buttons_row, queue_df, abort_btn, onemorewindow_btn])
+ outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, current_gen_buttons_row, queue_df, abort_btn, onemorewindow_btn],
+ show_progress="hidden"
+ )
preview_column_no.input(show_preview_column_modal, inputs=[state, preview_column_no], outputs=[preview_column_no, modal_image_display, modal_container])
@@ -8033,7 +8144,8 @@ def activate_status(state):
gr.on( triggers=[video_info_extract_settings_btn.click, video_info_extract_image_settings_btn.click], fn=validate_wizard_prompt,
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
+ outputs= [prompt],
+ show_progress="hidden",
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
@@ -8042,7 +8154,8 @@ def activate_status(state):
prompt_enhancer_btn.click(fn=validate_wizard_prompt,
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
+ outputs= [prompt],
+ show_progress="hidden",
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
@@ -8050,7 +8163,8 @@ def activate_status(state):
saveform_trigger.change(fn=validate_wizard_prompt,
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
+ outputs= [prompt],
+ show_progress="hidden",
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
@@ -8065,14 +8179,14 @@ def activate_status(state):
video_info_to_video_source_btn.click(fn=video_to_source_video, inputs =[state, output, last_choice], outputs = [video_source] )
video_info_to_start_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_start, gr.State("Start Image")], outputs = [image_start] )
video_info_to_end_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_end, gr.State("End Image")], outputs = [image_end] )
- video_info_to_image_guide_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_guide, gr.State("Control Image")], outputs = [image_guide] )
+ video_info_to_image_guide_btn.click(fn=image_to_ref_image_guide, inputs =[state, output, last_choice], outputs = [image_guide, image_mask_guide] )
video_info_to_image_mask_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_mask, gr.State("Image Mask")], outputs = [image_mask] )
video_info_to_reference_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_refs, gr.State("Ref Image")], outputs = [image_refs] )
video_info_postprocessing_btn.click(fn=apply_post_processing, inputs =[state, output, last_choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation], outputs = [mode, generate_trigger, add_to_queue_trigger ] )
video_info_remux_audio_btn.click(fn=remux_audio, inputs =[state, output, last_choice, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation, PP_custom_audio], outputs = [mode, generate_trigger, add_to_queue_trigger ] )
save_lset_btn.click(validate_save_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
delete_lset_btn.click(validate_delete_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
- confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
+ confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt], show_progress="hidden",).then(
fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None).then(
@@ -8087,7 +8201,8 @@ def activate_status(state):
lset_name.select(fn=update_lset_type, inputs=[state, lset_name], outputs=save_lset_prompt_drop)
export_settings_from_file_btn.click(fn=validate_wizard_prompt,
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
+ outputs= [prompt],
+ show_progress="hidden",
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
@@ -8104,7 +8219,8 @@ def activate_status(state):
image_mode_tabs.select(fn=record_image_mode_tab, inputs=[state], outputs= None
).then(fn=validate_wizard_prompt,
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
+ outputs= [prompt],
+ show_progress="hidden",
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
@@ -8112,7 +8228,8 @@ def activate_status(state):
settings_file.upload(fn=validate_wizard_prompt,
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
+ outputs= [prompt],
+ show_progress="hidden",
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
@@ -8126,17 +8243,20 @@ def activate_status(state):
refresh_form_trigger.change(fn= fill_inputs,
inputs=[state],
- outputs=gen_inputs + extra_inputs
+ outputs=gen_inputs + extra_inputs,
+ show_progress= "full" if args.debug_gen_form else "hidden",
).then(fn=validate_wizard_prompt,
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars],
- outputs= [prompt]
+ outputs= [prompt],
+ show_progress="hidden",
)
model_family.input(fn=change_model_family, inputs=[state, model_family], outputs= [model_choice])
model_choice.change(fn=validate_wizard_prompt,
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
+ outputs= [prompt],
+ show_progress="hidden",
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
@@ -8145,7 +8265,8 @@ def activate_status(state):
outputs= [header]
).then(fn= fill_inputs,
inputs=[state],
- outputs=gen_inputs + extra_inputs
+ outputs=gen_inputs + extra_inputs,
+ show_progress="full" if args.debug_gen_form else "hidden",
).then(fn= preload_model_when_switching,
inputs=[state],
outputs=[gen_status])
@@ -8154,13 +8275,15 @@ def activate_status(state):
generate_trigger.change(fn=validate_wizard_prompt,
inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
+ outputs= [prompt],
+ show_progress="hidden",
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
).then(fn=process_prompt_and_add_tasks,
inputs = [state, model_choice],
- outputs= queue_df
+ outputs= queue_df,
+ show_progress="hidden",
).then(fn=prepare_generate_video,
inputs= [state],
outputs= [generate_btn, add_to_queue_btn, current_gen_column, current_gen_buttons_row]
@@ -8170,10 +8293,12 @@ def activate_status(state):
).then(
fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(),
inputs=[state],
- outputs=[queue_accordion]
+ outputs=[queue_accordion],
+ show_progress="hidden",
).then(fn=process_tasks,
inputs= [state],
outputs= [preview_trigger, output_trigger],
+ show_progress="hidden",
).then(finalize_generation,
inputs= [state],
outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
@@ -8280,17 +8405,20 @@ def activate_status(state):
# gr.on(triggers=[add_to_queue_btn.click, add_to_queue_trigger.change],fn=validate_wizard_prompt,
add_to_queue_trigger.change(fn=validate_wizard_prompt,
inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
- outputs= [prompt]
+ outputs= [prompt],
+ show_progress="hidden",
).then(fn=save_inputs,
inputs =[target_state] + gen_inputs,
outputs= None
).then(fn=process_prompt_and_add_tasks,
inputs = [state, model_choice],
- outputs=queue_df
+ outputs=queue_df,
+ show_progress="hidden",
).then(
fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(),
inputs=[state],
- outputs=[queue_accordion]
+ outputs=[queue_accordion],
+ show_progress="hidden",
).then(
fn=update_status,
inputs = [state],
@@ -8302,8 +8430,8 @@ def activate_status(state):
outputs=[modal_container]
)
- return ( state, loras_choices, lset_name, resolution,
- video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger
+ return ( state, loras_choices, lset_name, resolution, refresh_form_trigger,
+ # video_guide, image_guide, video_mask, image_mask, image_refs,
)
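Most of the hunks above do one thing: they pass `show_progress="hidden"` to Gradio event listeners so that fast bookkeeping callbacks (prompt validation, saving inputs, refreshing labels) no longer flash the progress overlay, while the long-running steps chained with `.then()` keep theirs. A minimal, standalone sketch of that pattern (not WanGP code; the components and callbacks are made up):

```python
import gradio as gr

def validate(text):
    # fast bookkeeping step: no progress overlay wanted here
    return text.strip()

def generate(text):
    # the long-running step that should still report progress
    return f"result for: {text}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    result = gr.Textbox(label="Result")
    btn = gr.Button("Generate")
    btn.click(
        fn=validate, inputs=[prompt], outputs=[prompt], show_progress="hidden"
    ).then(
        fn=generate, inputs=[prompt], outputs=[result]
    )

# demo.launch()
```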
@@ -8339,8 +8467,9 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice
fit_canvas_choice = gr.Dropdown(
choices=[
- ("Dimensions correspond to the Pixels Budget (as the Prompt Image/Video will be resized to match this pixels budget, output video height or width may exceed the requested dimensions )", 0),
- ("Dimensions correspond to the Maximum Width and Height (as the Prompt Image/Video will be resized to fit into these dimensions, the output video may be smaller)", 1),
+ ("Dimensions correspond to the Pixels Budget (as the Prompt Image/Video will be Resized to match this pixels Budget, output video height or width may exceed the requested dimensions )", 0),
+ ("Dimensions correspond to the Maximum Width and Height (as the Prompt Image/Video will be Resized to fit into these dimensions, the output video may be smaller)", 1),
+ ("Dimensions correspond to the Output Width and Height (as the Prompt Image/Video will be Cropped to fit exactly these dimensions)", 2),
],
value= server_config.get("fit_canvas", 0),
label="Generated Video Dimensions when Prompt contains an Image or a Video",
@@ -9231,9 +9360,17 @@ def create_ui():
console.log('Events dispatched for column:', index);
}
};
-
console.log('sendColIndex function attached to window');
- }
+
+ // cancel wheel usage inside image editor
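+    // capture-phase listener: keeps page-level wheel handlers from hijacking scroll/zoom while the pointer is over the editor (or anything tagged "wheel-pass"); passive:true leaves native scrolling untouched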
+ const hit = n => n?.id === "img_editor" || n?.classList?.contains("wheel-pass");
+ addEventListener("wheel", e => {
+ const path = e.composedPath?.() || (() => { let a=[],n=e.target; for(;n;n=n.parentNode||n.host) a.push(n); return a; })();
+ if (path.some(hit)) e.stopImmediatePropagation();
+ }, { capture: true, passive: true });
+
+ }
+
"""
if server_config.get("display_stats", 0) == 1:
from shared.utils.stats import SystemStatsApp
@@ -9264,13 +9401,13 @@ def create_ui():
stats_element = stats_app.get_gradio_element()
with gr.Row():
- ( state, loras_choices, lset_name, resolution,
- video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger
+ ( state, loras_choices, lset_name, resolution, refresh_form_trigger
+ # video_guide, image_guide, video_mask, image_mask, image_refs,
) = generate_video_tab(model_family=model_family, model_choice=model_choice, header=header, main = main, main_tabs =main_tabs)
with gr.Tab("Guides", id="info") as info_tab:
generate_info_tab()
with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator:
- matanyone_app.display(main_tabs, tab_state, server_config, video_guide, image_guide, video_mask, image_mask, image_refs)
+ matanyone_app.display(main_tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings) #, video_guide, image_guide, video_mask, image_mask, image_refs)
if not args.lock_config:
with gr.Tab("Downloads", id="downloads") as downloads_tab:
generate_download_tab(lset_name, loras_choices, state)
From e7c08d12c846ddf4817013e1b112b1ecf183ada3 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Wed, 10 Sep 2025 01:47:55 +0200
Subject: [PATCH 024/155] fixed unwanted discontinuity at the end of the first
 sliding window with InfiniteTalk
---
models/wan/any2video.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index bb91dc66b..1e62c3511 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -539,7 +539,7 @@ def generate(self,
new_shot = "Q" in video_prompt_type
else:
if pre_video_frame is None:
- new_shot = True
+ new_shot = "Q" in video_prompt_type
else:
if input_ref_images is None:
input_ref_images, new_shot = [pre_video_frame], False
From 9fa267087b2dfdba651fd173325537f031edf91d Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 11 Sep 2025 21:23:05 +0200
Subject: [PATCH 025/155] Flux Festival
---
README.md | 9 +++-
defaults/flux_dev_umo.json | 24 ++++++++++
defaults/flux_dev_uso.json | 2 +-
defaults/flux_srpo.json | 15 ++++++
defaults/flux_srpo_uso.json | 17 +++++++
models/flux/flux_handler.py | 22 +++++++++
models/flux/flux_main.py | 77 +++++++++++++++++++++++++------
models/flux/model.py | 15 ++++++
models/flux/sampling.py | 57 +++++++++++++++++++++--
models/flux/util.py | 32 +++++++++++++
models/qwen/pipeline_qwenimage.py | 19 +++++---
models/qwen/qwen_handler.py | 1 +
models/wan/any2video.py | 2 +-
models/wan/wan_handler.py | 1 +
wgp.py | 62 ++++++++++++++++---------
15 files changed, 305 insertions(+), 50 deletions(-)
create mode 100644 defaults/flux_dev_umo.json
create mode 100644 defaults/flux_srpo.json
create mode 100644 defaults/flux_srpo_uso.json
diff --git a/README.md b/README.md
index fc3d76c84..d33b6dc4b 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
-### September 5 2025: WanGP v8.5 - Wanna be a Cropper or a Painter ?
+### September 11 2025: WanGP v8.5/8.55 - Wanna be a Cropper or a Painter ?
I have done some intensive internal refactoring of the generation pipeline to make it easier to support existing models and add new ones. Nothing really visible, but this makes WanGP a little more future proof.
@@ -38,6 +38,13 @@ Doing more sophisticated thing Vace Image Editor works very well too: try Image
For the best quality I recommend setting in the *Quality Tab* the option: "*Generate a 9 Frames Long video...*"
+**update 8.55**: Flux Festival
+- **Inpainting Mode** also added for *Flux Kontext*
+- **Flux SRPO**: a new finetune with 3x better quality than Flux Dev according to its authors. I have also created a *Flux SRPO USO* finetune which is certainly the best open source *Style Transfer* tool available.
+- **Flux UMO**: a model specialized in combining multiple reference objects / people together. Works quite well at 768x768.
+
+Good luck finding your way through all the Flux model names!
+
### September 5 2025: WanGP v8.4 - Take me to Outer Space
You have probably seen these short AI-generated movies created using *Nano Banana* and the *First Frame - Last Frame* feature of *Kling 2.0*. The idea is to generate an image, modify a part of it with Nano Banana and give these two images to Kling, which will generate the Video between them, then use the previous Last Frame as the new First Frame, rinse and repeat, and you get a full movie.
diff --git a/defaults/flux_dev_umo.json b/defaults/flux_dev_umo.json
new file mode 100644
index 000000000..57164bb4c
--- /dev/null
+++ b/defaults/flux_dev_umo.json
@@ -0,0 +1,24 @@
+{
+ "model": {
+ "name": "Flux 1 Dev UMO 12B",
+ "architecture": "flux",
+ "description": "FLUX.1 Dev UMO is a model that can Edit Images with a specialization in combining multiple image references (resized internally at 512x512 max) to produce an Image output. Best Image preservation at 768x768 Resolution Output.",
+ "URLs": "flux",
+ "flux-model": "flux-dev-umo",
+ "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-UMO_dit_lora_bf16.safetensors"],
+ "resolutions": [ ["1024x1024 (1:1)", "1024x1024"],
+ ["768x1024 (3:4)", "768x1024"],
+ ["1024x768 (4:3)", "1024x768"],
+ ["512x1024 (1:2)", "512x1024"],
+ ["1024x512 (2:1)", "1024x512"],
+ ["768x768 (1:1)", "768x768"],
+ ["768x512 (3:2)", "768x512"],
+ ["512x768 (2:3)", "512x768"]]
+ },
+ "prompt": "the man is wearing a hat",
+ "embedded_guidance_scale": 4,
+ "resolution": "768x768",
+ "batch_size": 1
+}
+
+
\ No newline at end of file
diff --git a/defaults/flux_dev_uso.json b/defaults/flux_dev_uso.json
index 0cd7b828d..806dd7e02 100644
--- a/defaults/flux_dev_uso.json
+++ b/defaults/flux_dev_uso.json
@@ -2,7 +2,7 @@
"model": {
"name": "Flux 1 Dev USO 12B",
"architecture": "flux",
- "description": "FLUX.1 Dev USO is a model specialized to Edit Images with a specialization in Style Transfers (up to two).",
+ "description": "FLUX.1 Dev USO is a model that can Edit Images with a specialization in Style Transfers (up to two).",
"modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]],
"URLs": "flux",
"loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"],
diff --git a/defaults/flux_srpo.json b/defaults/flux_srpo.json
new file mode 100644
index 000000000..59f07c617
--- /dev/null
+++ b/defaults/flux_srpo.json
@@ -0,0 +1,15 @@
+{
+ "model": {
+ "name": "Flux 1 SRPO Dev 12B",
+ "architecture": "flux",
+ "description": "By fine-tuning the FLUX.1.dev model with optimized denoising and online reward adjustment, SRPO improves its human-evaluated realism and aesthetic quality by over 3x.",
+ "URLs": [
+ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_bf16.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_quanto_bf16_int8.safetensors"
+ ],
+ "flux-model": "flux-dev"
+ },
+ "prompt": "draw a hat",
+ "resolution": "1024x1024",
+ "batch_size": 1
+}
\ No newline at end of file
diff --git a/defaults/flux_srpo_uso.json b/defaults/flux_srpo_uso.json
new file mode 100644
index 000000000..ddfe50d63
--- /dev/null
+++ b/defaults/flux_srpo_uso.json
@@ -0,0 +1,17 @@
+{
+ "model": {
+ "name": "Flux 1 SRPO USO 12B",
+ "architecture": "flux",
+ "description": "FLUX.1 SRPO USO is a model that can Edit Images with a specialization in Style Transfers (up to two). It leverages the improved Image quality brought by the SRPO process",
+ "modules": [ "flux_dev_uso"],
+ "URLs": "flux_srpo",
+ "loras": "flux_dev_uso",
+ "flux-model": "flux-dev-uso"
+ },
+ "prompt": "the man is wearing a hat",
+ "embedded_guidance_scale": 4,
+ "resolution": "1024x1024",
+ "batch_size": 1
+}
+
+
\ No newline at end of file
diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py
index c468d5a9a..808369ff6 100644
--- a/models/flux/flux_handler.py
+++ b/models/flux/flux_handler.py
@@ -13,6 +13,7 @@ def query_model_def(base_model_type, model_def):
flux_schnell = flux_model == "flux-schnell"
flux_chroma = flux_model == "flux-chroma"
flux_uso = flux_model == "flux-dev-uso"
+ flux_umo = flux_model == "flux-dev-umo"
flux_kontext = flux_model == "flux-dev-kontext"
extra_model_def = {
@@ -35,6 +36,7 @@ def query_model_def(base_model_type, model_def):
}
if flux_kontext:
+ extra_model_def["inpaint_support"] = True
extra_model_def["image_ref_choices"] = {
"choices": [
("None", ""),
@@ -43,6 +45,15 @@ def query_model_def(base_model_type, model_def):
],
"letters_filter": "KI",
}
+ extra_model_def["background_removal_label"]= "Remove Backgrounds only behind People / Objects except main Subject / Landscape"
+ elif flux_umo:
+ extra_model_def["image_ref_choices"] = {
+ "choices": [
+ ("Conditional Images are People / Objects", "I"),
+ ],
+ "letters_filter": "I",
+ "visible": False
+ }
extra_model_def["lock_image_refs_ratios"] = True
@@ -131,10 +142,14 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
video_prompt_type = video_prompt_type.replace("I", "KI")
ui_defaults["video_prompt_type"] = video_prompt_type
+ if settings_version < 2.34:
+ ui_defaults["denoising_strength"] = 1.
+
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
flux_model = model_def.get("flux-model", "flux-dev")
flux_uso = flux_model == "flux-dev-uso"
+ flux_umo = flux_model == "flux-dev-umo"
flux_kontext = flux_model == "flux-dev-kontext"
ui_defaults.update({
"embedded_guidance": 2.5,
@@ -143,5 +158,12 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
if flux_kontext or flux_uso:
ui_defaults.update({
"video_prompt_type": "KI",
+ "denoising_strength": 1.,
+ })
+ elif flux_umo:
+ ui_defaults.update({
+ "video_prompt_type": "I",
+ "remove_background_images_ref": 0,
})
+
diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py
index 4d7c67d41..686371129 100644
--- a/models/flux/flux_main.py
+++ b/models/flux/flux_main.py
@@ -23,6 +23,35 @@
)
from PIL import Image
+def preprocess_ref(raw_image: Image.Image, long_size: int = 512):
+ # Get the width and height of the original image
+ image_w, image_h = raw_image.size
+
+ # 计算长边和短边
+ if image_w >= image_h:
+ new_w = long_size
+ new_h = int((long_size / image_w) * image_h)
+ else:
+ new_h = long_size
+ new_w = int((long_size / image_h) * image_w)
+
+ # Resize proportionally to the new width and height
+ raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS)
+ target_w = new_w // 16 * 16
+ target_h = new_h // 16 * 16
+
+ # Compute the crop box coordinates for a center crop
+ left = (new_w - target_w) // 2
+ top = (new_h - target_h) // 2
+ right = left + target_w
+ bottom = top + target_h
+
+ # Apply the center crop
+ raw_image = raw_image.crop((left, top, right, bottom))
+
+ # Convert to RGB mode
+ raw_image = raw_image.convert("RGB")
+ return raw_image
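An illustrative call of the helper added above (sizes picked arbitrarily): a 1600x900 reference gets its long side scaled down to 512, then is center-cropped to multiples of 16.

```python
from PIL import Image

ref = Image.new("RGB", (1600, 900))     # stand-in reference image
out = preprocess_ref(ref, long_size=512)
print(out.size)                         # (512, 288): both sides already multiples of 16, so no crop
```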
def stitch_images(img1, img2):
# Resize img2 to match img1's height
@@ -67,7 +96,7 @@ def __init__(
# self.name= "flux-schnell"
source = model_def.get("source", None)
self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device)
-
+ self.model_def = model_def
self.vae = load_ae(self.name, device=torch_device)
siglip_processor = siglip_model = feature_embedder = None
@@ -109,10 +138,12 @@ def __init__(
def generate(
self,
seed: int | None = None,
- input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
+ input_prompt: str = "replace the logo with the text 'Black Forest Labs'",
n_prompt: str = None,
sampling_steps: int = 20,
input_ref_images = None,
+ image_guide= None,
+ image_mask= None,
width= 832,
height=480,
embedded_guidance_scale: float = 2.5,
@@ -123,7 +154,8 @@ def generate(
batch_size = 1,
video_prompt_type = "",
joint_pass = False,
- image_refs_relative_size = 100,
+ image_refs_relative_size = 100,
+ denoising_strength = 1.,
**bbargs
):
if self._interrupt:
@@ -132,8 +164,16 @@ def generate(
if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
device="cuda"
flux_dev_uso = self.name in ['flux-dev-uso']
- image_stiching = not self.name in ['flux-dev-uso'] #and False
+ flux_dev_umo = self.name in ['flux-dev-umo']
+ latent_stiching = self.name in ['flux-dev-uso', 'flux-dev-umo']
+
+ lock_dimensions= False
+
input_ref_images = [] if input_ref_images is None else input_ref_images[:]
+ if flux_dev_umo:
+ ref_long_side = 512 if len(input_ref_images) <= 1 else 320
+ input_ref_images = [preprocess_ref(img, ref_long_side) for img in input_ref_images]
+ lock_dimensions = True
ref_style_imgs = []
if "I" in video_prompt_type and len(input_ref_images) > 0:
if flux_dev_uso :
@@ -143,22 +183,26 @@ def generate(
elif len(input_ref_images) > 1 :
ref_style_imgs = input_ref_images[-1:]
input_ref_images = input_ref_images[:-1]
- if image_stiching:
+
+ if latent_stiching:
+ # latent stitching with resize
+ if not lock_dimensions :
+ for i in range(len(input_ref_images)):
+ w, h = input_ref_images[i].size
+ image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, 0)
+ input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
+ else:
# image stitching method
stiched = input_ref_images[0]
for new_img in input_ref_images[1:]:
stiched = stitch_images(stiched, new_img)
input_ref_images = [stiched]
- else:
- # latents stiching with resize
- for i in range(len(input_ref_images)):
- w, h = input_ref_images[i].size
- image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas)
- input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
+ elif image_guide is not None:
+ input_ref_images = [image_guide]
else:
input_ref_images = None
- if flux_dev_uso :
+ if self.name in ['flux-dev-uso', 'flux-dev-umo'] :
inp, height, width = prepare_multi_ip(
ae=self.vae,
img_cond_list=input_ref_images,
@@ -177,6 +221,7 @@ def generate(
bs=batch_size,
seed=seed,
device=device,
+ img_mask=image_mask,
)
inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt))
@@ -198,13 +243,19 @@ def unpack_latent(x):
return unpack(x.float(), height, width)
# denoise initial noise
- x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass)
+ x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass, denoising_strength = denoising_strength)
if x==None: return None
# decode latents to pixel space
x = unpack_latent(x)
with torch.autocast(device_type=device, dtype=torch.bfloat16):
x = self.vae.decode(x)
+ if image_mask is not None:
+ from shared.utils.utils import convert_image_to_tensor
+ img_msk_rebuilt = inp["img_msk_rebuilt"]
+ img= convert_image_to_tensor(image_guide)
+ x = img.squeeze(2) * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt
+
x = x.clamp(-1, 1)
x = x.transpose(0, 1)
return x
diff --git a/models/flux/model.py b/models/flux/model.py
index c4642d0bc..c5f7a24b4 100644
--- a/models/flux/model.py
+++ b/models/flux/model.py
@@ -190,6 +190,21 @@ def swap_scale_shift(weight):
v = swap_scale_shift(v)
k = k.replace("norm_out.linear", "final_layer.adaLN_modulation.1")
new_sd[k] = v
+ # elif not first_key.startswith("diffusion_model.") and not first_key.startswith("transformer."):
+ # for k,v in sd.items():
+ # if "double" in k:
+ # k = k.replace(".processor.proj_lora1.", ".img_attn.proj.lora_")
+ # k = k.replace(".processor.proj_lora2.", ".txt_attn.proj.lora_")
+ # k = k.replace(".processor.qkv_lora1.", ".img_attn.qkv.lora_")
+ # k = k.replace(".processor.qkv_lora2.", ".txt_attn.qkv.lora_")
+ # else:
+ # k = k.replace(".processor.qkv_lora.", ".linear1_qkv.lora_")
+ # k = k.replace(".processor.proj_lora.", ".linear2.lora_")
+
+ # k = "diffusion_model." + k
+ # new_sd[k] = v
+ # from mmgp import safetensors2
+ # safetensors2.torch_write_file(new_sd, "fff.safetensors")
else:
new_sd = sd
return new_sd
diff --git a/models/flux/sampling.py b/models/flux/sampling.py
index f43ae15e0..1b4813a5c 100644
--- a/models/flux/sampling.py
+++ b/models/flux/sampling.py
@@ -138,10 +138,12 @@ def prepare_kontext(
target_width: int | None = None,
target_height: int | None = None,
bs: int = 1,
-
+ img_mask = None,
) -> tuple[dict[str, Tensor], int, int]:
# load and encode the conditioning image
+ res_match_output = img_mask is not None
+
img_cond_seq = None
img_cond_seq_ids = None
if img_cond_list == None: img_cond_list = []
@@ -150,9 +152,11 @@ def prepare_kontext(
for cond_no, img_cond in enumerate(img_cond_list):
width, height = img_cond.size
aspect_ratio = width / height
-
- # Kontext is trained on specific resolutions, using one of them is recommended
- _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
+ if res_match_output:
+ width, height = target_width, target_height
+ else:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
width = 2 * int(width / 16)
height = 2 * int(height / 16)
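For the non-mask path kept above, the conditioning image is snapped to the trained Kontext resolution with the closest aspect ratio. A runnable sketch of that selection (the resolution list below is an illustrative subset; the real one is `PREFERED_KONTEXT_RESOLUTIONS` in models/flux/sampling.py):

```python
KONTEXT_RESOLUTIONS = [(672, 1568), (880, 1184), (1024, 1024), (1184, 880), (1568, 672)]

def snap_to_preferred(width, height):
    aspect_ratio = width / height
    # tuple comparison: the smallest aspect-ratio difference wins
    _, w, h = min((abs(aspect_ratio - w / h), w, h) for w, h in KONTEXT_RESOLUTIONS)
    return w, h

print(snap_to_preferred(1920, 1080))    # the trained resolution closest to 16:9
```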
@@ -193,6 +197,19 @@ def prepare_kontext(
"img_cond_seq": img_cond_seq,
"img_cond_seq_ids": img_cond_seq_ids,
}
+ if img_mask is not None:
+ from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image
+ # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
+ image_mask_latents = convert_image_to_tensor(img_mask.resize((target_width // 16, target_height // 16), resample=Image.Resampling.LANCZOS))
+ image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
+ image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
+ convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
+ image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)
+ return_dict.update({
+ "img_msk_latents": image_mask_latents,
+ "img_msk_rebuilt": image_mask_rebuilt,
+ })
+
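A standalone sketch of the mask handling added above, with numpy standing in for the repo's `convert_image_to_tensor` helper (assumed here to map pixels to [-1, 1]): the mask is downsampled to the 16x16-pixel latent grid, binarized, flattened to one value per image token, and also expanded back to pixel resolution for the final recomposition.

```python
import numpy as np
import torch
from PIL import Image

def mask_to_latent_and_pixel(img_mask: Image.Image, width: int, height: int):
    # One mask value per 16x16 pixel patch (the latent token grid).
    small = img_mask.convert("RGB").resize((width // 16, height // 16), Image.Resampling.LANCZOS)
    t = torch.from_numpy(np.array(small)).permute(2, 0, 1).float() / 127.5 - 1.0
    latent_mask = torch.where(t > -0.5, 1.0, 0.0)[0:1]             # 1 x H/16 x W/16
    # Expand back so the mask lines up with pixels when blending the decoded image.
    pixel_mask = latent_mask.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2)
    return latent_mask.reshape(1, -1, 1), pixel_mask.unsqueeze(0)
```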
img = get_noise(
bs,
target_height,
@@ -264,6 +281,9 @@ def denoise(
loras_slists=None,
unpack_latent = None,
joint_pass= False,
+ img_msk_latents = None,
+ img_msk_rebuilt = None,
+ denoising_strength = 1,
):
kwargs = {'pipeline': pipeline, 'callback': callback, "img_len" : img.shape[1], "siglip_embedding": siglip_embedding, "siglip_embedding_ids": siglip_embedding_ids}
@@ -271,6 +291,21 @@ def denoise(
if callback != None:
callback(-1, None, True)
+ original_image_latents = None if img_cond_seq is None else img_cond_seq.clone()
+
+ morph, first_step = False, 0
+ if img_msk_latents is not None:
+ randn = torch.randn_like(original_image_latents)
+ if denoising_strength < 1.:
+ first_step = int(len(timesteps) * (1. - denoising_strength))
+ if not morph:
+ latent_noise_factor = timesteps[first_step]
+ latents = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
+ img = latents.to(img)
+ latents = None
+ timesteps = timesteps[first_step:]
+
+
updated_num_steps= len(timesteps) -1
if callback != None:
from shared.utils.loras_mutipliers import update_loras_slists
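The `denoising_strength` logic added above is the usual image-to-image trick: skip the first fraction of the schedule and start from the reference latents noised to the level of the first kept timestep. A toy numeric sketch (the schedule and tensor shapes are stand-ins; the flux schedule runs from 1 down to 0):

```python
import torch

timesteps = torch.linspace(1.0, 0.0, steps=30)     # stand-in schedule
denoising_strength = 0.6
first_step = int(len(timesteps) * (1.0 - denoising_strength))    # 12 steps skipped

image_latents = torch.zeros(1, 16, 4)              # stand-in "clean" image latents
randn = torch.randn_like(image_latents)
noise_level = timesteps[first_step]                # noise level of the first kept step
img = image_latents * (1.0 - noise_level) + randn * noise_level
timesteps = timesteps[first_step:]                 # denoising resumes from here
print(first_step, len(timesteps))                  # 12 18
```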
@@ -280,10 +315,14 @@ def denoise(
# this is ignored for schnell
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
- offload.set_step_no_for_lora(model, i)
+ offload.set_step_no_for_lora(model, first_step + i)
if pipeline._interrupt:
return None
+ if img_msk_latents is not None and denoising_strength <1. and i == first_step and morph:
+ latent_noise_factor = t_curr/1000
+ img = original_image_latents * (1.0 - latent_noise_factor) + img * latent_noise_factor
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
img_input = img
img_input_ids = img_ids
@@ -333,6 +372,14 @@ def denoise(
pred = neg_pred + real_guidance_scale * (pred - neg_pred)
img += (t_prev - t_curr) * pred
+
+ if img_msk_latents is not None:
+ latent_noise_factor = t_prev
+ # noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
+ noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
+ img = noisy_image * (1-img_msk_latents) + img_msk_latents * img
+ noisy_image = None
+
if callback is not None:
preview = unpack_latent(img).transpose(0,1)
callback(i, preview, False)
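The per-step blend added above is what confines generation to the masked region: after every step, the area outside the mask is overwritten with a copy of the original latents re-noised to the next timestep's level, so only the masked area actually evolves. The operation in isolation:

```python
import torch

def inpaint_blend(img, original_latents, randn, mask, t_prev):
    # Re-noise the original image to the noise level expected at the next step...
    noisy_original = original_latents * (1.0 - t_prev) + randn * t_prev
    # ...keep it outside the mask, keep the generated latents inside the mask.
    return noisy_original * (1 - mask) + mask * img
```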
diff --git a/models/flux/util.py b/models/flux/util.py
index 0f96103de..af75f6269 100644
--- a/models/flux/util.py
+++ b/models/flux/util.py
@@ -640,6 +640,38 @@ class ModelSpec:
shift_factor=0.1159,
),
),
+ "flux-dev-umo": ModelSpec(
+ repo_id="",
+ repo_flow="",
+ repo_ae="ckpts/flux_vae.safetensors",
+ params=FluxParams(
+ in_channels=64,
+ out_channels=64,
+ vec_in_dim=768,
+ context_in_dim=4096,
+ hidden_size=3072,
+ mlp_ratio=4.0,
+ num_heads=24,
+ depth=19,
+ depth_single_blocks=38,
+ axes_dim=[16, 56, 56],
+ theta=10_000,
+ qkv_bias=True,
+ guidance_embed=True,
+ eso= True,
+ ),
+ ae_params=AutoEncoderParams(
+ resolution=256,
+ in_channels=3,
+ ch=128,
+ out_ch=3,
+ ch_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ z_channels=16,
+ scale_factor=0.3611,
+ shift_factor=0.1159,
+ ),
+ ),
}
diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py
index 20838f595..0897ee419 100644
--- a/models/qwen/pipeline_qwenimage.py
+++ b/models/qwen/pipeline_qwenimage.py
@@ -714,14 +714,14 @@ def __call__(
image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS))
image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
- convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
+ # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)
prompt_image = image
if image.size != (image_width, image_height):
image = image.resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
- image.save("nnn.png")
+ # image.save("nnn.png")
image = convert_image_to_tensor(image).unsqueeze(0).unsqueeze(2)
has_neg_prompt = negative_prompt is not None or (
@@ -811,12 +811,15 @@ def __call__(
negative_txt_seq_lens = (
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
- morph = False
- if image_mask_latents is not None and denoising_strength <= 1.:
- first_step = int(len(timesteps) * (1. - denoising_strength))
+ morph, first_step = False, 0
+ if image_mask_latents is not None:
+ randn = torch.randn_like(original_image_latents)
+ if denoising_strength < 1.:
+ first_step = int(len(timesteps) * (1. - denoising_strength))
if not morph:
latent_noise_factor = timesteps[first_step]/1000
- latents = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
+ # latents = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor
+ latents = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor
timesteps = timesteps[first_step:]
self.scheduler.timesteps = timesteps
self.scheduler.sigmas= self.scheduler.sigmas[first_step:]
@@ -831,6 +834,7 @@ def __call__(
for i, t in enumerate(timesteps):
+ offload.set_step_no_for_lora(self.transformer, first_step + i)
if self.interrupt:
continue
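The `first_step + i` offset above matters once steps are skipped for partial denoising: anything keyed to the global step number, such as the per-step Lora multiplier schedules set up by `update_loras_slists` in shared.utils.loras_mutipliers, must keep using the original step index rather than the loop index. A tiny sketch with a made-up schedule:

```python
schedule = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.5, 0.5]   # hypothetical 8-step lora multipliers
first_step = 3                                        # steps skipped by denoising_strength
timesteps = list(range(8))[first_step:]

for i, t in enumerate(timesteps):
    step_no = first_step + i          # global step index, not the loop index
    multiplier = schedule[step_no]    # same entry a full 8-step run would have used
    print(step_no, t, multiplier)
```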
@@ -905,7 +909,8 @@ def __call__(
if image_mask_latents is not None:
next_t = timesteps[i+1] if i sliding_window_size:
+ if model_type in ["t2v"] and not "G" in video_prompt_type :
+ gr.Info(f"You have requested to Generate Sliding Windows with a Text to Video model. Unless you use the Video to Video feature this is useless as a t2v model doesn't see past frames and it will generate the same video in each new window.")
+ return
full_video_length = video_length if video_source is None else video_length + sliding_window_overlap -1
extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation"
no_windows = compute_sliding_window_no(full_video_length, sliding_window_size, sliding_window_discard_last_frames, sliding_window_overlap)
gr.Info(f"The Number of Frames to generate ({video_length}{extra}) is greater than the Sliding Window Size ({sliding_window_size}), {no_windows} Windows will be generated")
-
if "recam" in model_filename:
if video_guide == None:
gr.Info("You must provide a Control Video")
@@ -7019,28 +7020,38 @@ def categorize_resolution(resolution_str):
return group
return "1440p"
-def group_resolutions(resolutions, selected_resolution):
-
- grouped_resolutions = {}
- for resolution in resolutions:
- group = categorize_resolution(resolution[1])
- if group not in grouped_resolutions:
- grouped_resolutions[group] = []
- grouped_resolutions[group].append(resolution)
-
- available_groups = [group for group in group_thresholds if group in grouped_resolutions]
+def group_resolutions(model_def, resolutions, selected_resolution):
+
+ model_resolutions = model_def.get("resolutions", None)
+ if model_resolutions is not None:
+ selected_group ="Locked"
+ available_groups = [selected_group ]
+ selected_group_resolutions = model_resolutions
+ else:
+ grouped_resolutions = {}
+ for resolution in resolutions:
+ group = categorize_resolution(resolution[1])
+ if group not in grouped_resolutions:
+ grouped_resolutions[group] = []
+ grouped_resolutions[group].append(resolution)
+
+ available_groups = [group for group in group_thresholds if group in grouped_resolutions]
- selected_group = categorize_resolution(selected_resolution)
- selected_group_resolutions = grouped_resolutions.get(selected_group, [])
- available_groups.reverse()
+ selected_group = categorize_resolution(selected_resolution)
+ selected_group_resolutions = grouped_resolutions.get(selected_group, [])
+ available_groups.reverse()
return available_groups, selected_group_resolutions, selected_group
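The refactor above keeps the bucket-based grouping for normal models and short-circuits to a single "Locked" group when the model definition pins its own resolution list. A simplified, runnable sketch of the bucket path (the thresholds are made up; the real code drives this with `group_thresholds` and `categorize_resolution`):

```python
def categorize_resolution(res):              # toy rule: bucket by pixel count
    w, h = map(int, res.split("x"))
    pixels = w * h
    if pixels <= 832 * 480: return "480p"
    if pixels <= 1280 * 720: return "720p"
    return "1080p"

def group_resolutions(resolutions, selected_resolution):
    grouped = {}
    for label, res in resolutions:
        grouped.setdefault(categorize_resolution(res), []).append((label, res))
    selected_group = categorize_resolution(selected_resolution)
    return list(grouped.keys()), grouped.get(selected_group, []), selected_group

groups, choices, current = group_resolutions(
    [("832x480 (16:9)", "832x480"), ("1280x720 (16:9)", "1280x720")], "832x480")
print(groups, current)                       # ['480p', '720p'] 480p
```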
def change_resolution_group(state, selected_group):
model_type = state["model_type"]
model_def = get_model_def(model_type)
model_resolutions = model_def.get("resolutions", None)
- resolution_choices, _ = get_resolution_choices(None, model_resolutions)
- group_resolution_choices = [ resolution for resolution in resolution_choices if categorize_resolution(resolution[1]) == selected_group ]
+ resolution_choices, _ = get_resolution_choices(None, model_resolutions)
+ if model_resolutions is None:
+ group_resolution_choices = [ resolution for resolution in resolution_choices if categorize_resolution(resolution[1]) == selected_group ]
+ else:
+ last_resolution = group_resolution_choices[0][1]
+ return gr.update(choices= group_resolution_choices, value= last_resolution)
last_resolution_per_group = state["last_resolution_per_group"]
last_resolution = last_resolution_per_group.get(selected_group, "")
@@ -7051,6 +7062,11 @@ def change_resolution_group(state, selected_group):
def record_last_resolution(state, resolution):
+
+ model_type = state["model_type"]
+ model_def = get_model_def(model_type)
+ model_resolutions = model_def.get("resolutions", None)
+ if model_resolutions is not None: return
server_config["last_resolution_choice"] = resolution
selected_group = categorize_resolution(resolution)
last_resolution_per_group = state["last_resolution_per_group"]
@@ -7482,11 +7498,13 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" )
image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs)
- no_background_removal = model_def.get("no_background_removal", False)
+ no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None
+ background_removal_label = model_def.get("background_removal_label", "Remove Backgrounds behind People / Objects")
+
remove_background_images_ref = gr.Dropdown(
choices=[
("Keep Backgrounds behind all Reference Images", 0),
- ("Remove Backgrounds only behind People / Objects except main Subject / Landscape" if (flux or qwen) else ("Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames" if vace else "Remove Backgrounds behind People / Objects") , 1),
+ (background_removal_label, 1),
],
value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1),
label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal
@@ -7578,7 +7596,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
current_resolution_choice = ui_defaults.get("resolution","832x480") if update_form or last_resolution is None else last_resolution
model_resolutions = model_def.get("resolutions", None)
resolution_choices, current_resolution_choice = get_resolution_choices(current_resolution_choice, model_resolutions)
- available_groups, selected_group_resolutions, selected_group = group_resolutions(resolution_choices, current_resolution_choice)
+ available_groups, selected_group_resolutions, selected_group = group_resolutions(model_def,resolution_choices, current_resolution_choice)
resolution_group = gr.Dropdown(
choices = available_groups,
value= selected_group,
From 7bcd7246217954956536515d8b220f349cb5f288 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Mon, 15 Sep 2025 19:28:18 +0200
Subject: [PATCH 026/155] attack of the clones
---
README.md | 228 +---------
defaults/qwen_image_edit_20B.json | 1 +
defaults/vace_fun_14B_2_2.json | 24 ++
defaults/vace_fun_14B_cocktail_2_2.json | 28 ++
docs/CHANGELOG.md | 144 ++++++-
models/flux/flux_handler.py | 2 +-
models/flux/sampling.py | 2 +-
models/hyvideo/hunyuan_handler.py | 10 +-
models/ltx_video/ltxv.py | 4 +-
models/ltx_video/ltxv_handler.py | 2 +-
models/qwen/pipeline_qwenimage.py | 16 +-
models/qwen/qwen_handler.py | 27 +-
models/qwen/qwen_main.py | 14 +-
models/wan/any2video.py | 86 ++--
models/wan/df_handler.py | 2 +-
models/wan/multitalk/multitalk.py | 18 +-
models/wan/wan_handler.py | 9 +-
preprocessing/extract_vocals.py | 69 ++++
preprocessing/speakers_separator.py | 12 +-
requirements.txt | 5 +-
wgp.py | 526 +++++++++++++++---------
21 files changed, 755 insertions(+), 474 deletions(-)
create mode 100644 defaults/vace_fun_14B_2_2.json
create mode 100644 defaults/vace_fun_14B_cocktail_2_2.json
create mode 100644 preprocessing/extract_vocals.py
diff --git a/README.md b/README.md
index d33b6dc4b..2411cea3d 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,19 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
+### September 15 2025: WanGP v8.6 - Attack of the Clones
+
+- The long-awaited **Vace for Wan 2.2** is at last here, or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fun Cocktail**).
+
+- **First Frame / Last Frame for Vace**: Vace models are so powerful that they could do *First Frame / Last Frame* since day one using the *Injected Frames* feature. However, this required computing by hand the location of each end frame, since this feature expects frame positions. I made it easier to compute these locations with the "L" alias:
+
+For a video generated from scratch, *"1 L L L"* means the 4 Injected Frames will be placed like this: frame no 1 at the first position, the next frame at the end of the first window, then the following frame at the end of the next window, and so on (see the illustrative sketch after this list).
+If you *Continue a Video*, you just need *"L L L"* since the first frame is the last frame of the *Source Video*. In any case remember that numeric frame positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab / Control Video, Injected Frames Alignment*.
+
+- **Qwen Inpainting** now exists in two versions: the original version from the previous release and a Lora-based version. Each version has its pros and cons. For instance, the Lora version also supports **Outpainting**! However, it tends to slightly change the original image even outside the outpainted area.
+
+- **Better Lipsync with all the Audio to Video models**: you probably noticed that *Multitalk*, *InfiniteTalk* or *Hunyuan Avatar* had so-so lipsync when the provided audio contained background music. The problem should now be solved thanks to automated background music removal, all done by AI. Don't worry, you will still hear the music as it is added back into the generated Video.
+
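A toy illustration of the "L" alias mentioned above (the window arithmetic is made up for the example; real positions depend on the sliding-window size, overlap and discarded frames configured in the UI):

```python
def expand_positions(spec, first_window_end=81, frames_per_new_window=72):
    # Hypothetical helper, not WanGP code: numeric tokens are absolute frame
    # positions, each "L" stands for the last frame of the next sliding window.
    positions, window_end = [], first_window_end
    for token in spec.split():
        if token.upper() == "L":
            positions.append(window_end)
            window_end += frames_per_new_window
        else:
            positions.append(int(token))
    return positions

print(expand_positions("1 L L L"))   # [1, 81, 153, 225]
```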
### September 11 2025: WanGP v8.5/8.55 - Wanna be a Cropper or a Painter ?
I have done some intensive internal refactoring of the generation pipeline to make it easier to support existing models and add new ones. Nothing really visible, but this makes WanGP a little more future proof.
@@ -74,221 +87,6 @@ You will find below a 33s movie I have created using these two methods. Quality
*update 8.31: one shouldn't talk about bugs if one doesn't want to attract bugs*
-### August 29 2025: WanGP v8.21 - Here Goes Your Weekend
-
-- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is internally only image is used by Sliding Window. However thanks to the new *Smooth Transition* mode, each new clip is connected to the previous and all the camera work is done by InfiniteTalk. If you dont get any transition, increase the number of frames of a Sliding Window (81 frames recommended)
-
-- **StandIn**: very light model specialized in Identity Transfer. I have provided two versions of Standin: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will be also used for StandIn
-
-- **Flux ESO**: a new Flux dervied *Image Editing tool*, but this one is specialized both in *Identity Transfer* and *Style Transfer*. Style has to be understood in its wide meaning: give a reference picture of a person and another one of Sushis and you will turn this person into Sushis
-
-### August 24 2025: WanGP v8.1 - the RAM Liberator
-
-- **Reserved RAM entirely freed when switching models**, you should get much less out of memory related to RAM. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP
-- **InfiniteTalk** support: improved version of Multitalk that supposedly supports very long video generations based on an audio track. Exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesnt seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated to the same audio: each Reference frame you provide you will be associated to a new Sliding Window. If only Reference frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\
-If you are not into audio, you can use still this model to generate infinite long image2video, just select "no speaker". Last but not least, Infinitetalk works works with all the Loras accelerators.
-- **Flux Chroma 1 HD** support: uncensored flux based model and lighter than Flux (8.9B versus 12B) and can fit entirely in VRAM with only 16 GB of VRAM. Unfortunalely it is not distilled and you will need CFG at minimum 20 steps
-
-### August 21 2025: WanGP v8.01 - the killer of seven
-
-- **Qwen Image Edit** : Flux Kontext challenger (prompt driven image edition). Best results (including Identity preservation) will be obtained at 720p. Beyond you may get image outpainting and / or lose identity preservation. Below 720p prompt adherence will be worse. Qwen Image Edit works with Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions but identity preservation won't be as good.
-- **On demand Prompt Enhancer** (needs to be enabled in Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt.
-- Choice of a **Non censored Prompt Enhancer**. Beware this is one is VRAM hungry and will require 12 GB of VRAM to work
-- **Memory Profile customizable per model** : useful to set for instance Profile 3 (preload the model entirely in VRAM) with only Image Generation models, if you have 24 GB of VRAM. In that case Generation will be much faster because with Image generators (contrary to Video generators) as a lot of time is wasted in offloading
-- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Ligthning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is a 8 steps process although the lora lightning is 4 steps. This expert guidance mode is also available with Wan 2.1.
-
-*WanGP 8.01 update, improved Qwen Image Edit Identity Preservation*
-### August 12 2025: WanGP v7.7777 - Lucky Day(s)
-
-This is your lucky day ! thanks to new configuration options that will let you store generated Videos and Images in lossless compressed formats, you will find they in fact they look two times better without doing anything !
-
-Just kidding, they will be only marginally better, but at least this opens the way to professionnal editing.
-
-Support:
-- Video: x264, x264 lossless, x265
-- Images: jpeg, png, webp, wbp lossless
-Generation Settings are stored in each of the above regardless of the format (that was the hard part).
-
-Also you can now choose different output directories for images and videos.
-
-unexpected luck: fixed lightning 8 steps for Qwen, and lightning 4 steps for Wan 2.2, now you just need 1x multiplier no weird numbers.
-*update 7.777 : oops got a crash a with FastWan ? Luck comes and goes, try a new update, maybe you will have a better chance this time*
-*update 7.7777 : Sometime good luck seems to last forever. For instance what if Qwen Lightning 4 steps could also work with WanGP ?*
-- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps)
-- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps)
-
-
-### August 10 2025: WanGP v7.76 - Faster than the VAE ...
-We have a funny one here today: FastWan 2.2 5B, the Fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that VAE is twice as slow...
-Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune.
-
-*WanGP 7.76: fixed the messed up I did to i2v models (loras path was wrong for Wan2.2 and Clip broken)*
-
-### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2
-Added support for Qwen Lightning lora for a 8 steps generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). Lora is not normalized and you can use a multiplier around 0.1.
-
-Mag Cache support for all the Wan2.2 models Don't forget to set guidance to 1 and 8 denoising steps , your gen will be 7x faster !
-
-### August 8 2025: WanGP v7.73 - Qwen Rebirth
-Ever wondered what impact not using Guidance has on a model that expects it ? Just look at Qween Image in WanGP 7.71 whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt. And in WanGP 7.72 there is at last one for him.
-
-As Qwen is not so picky after all I have added also quantized text encoder which reduces the RAM requirements of Qwen by 10 GB (the text encoder quantized version produced garbage before)
-
-Unfortunately still the Sage bug for older GPU architectures. Added Sdpa fallback for these architectures.
-
-*7.73 update: still Sage / Sage2 bug for GPUs before RTX40xx. I have added a detection mechanism that forces Sdpa attention if that's the case*
-
-
-### August 6 2025: WanGP v7.71 - Picky, picky
-
-This release comes with two new models :
-- Qwen Image: a Commercial grade Image generator capable to inject full sentences in the generated Image while still offering incredible visuals
-- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 needed if you want to complete your Wan 2.2 collection (loras for this folder can be stored in "\loras\5B" )
-
-There is catch though, they are very picky if you want to get good generations: first they both need lots of steps (50 ?) to show what they have to offer. Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else, you will get garbage. Likewise Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.
-
-*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, to keep VRAM low during the whole gen.*
-
-
-### August 4 2025: WanGP v7.6 - Remuxed
-
-With this new version you won't have any excuse if there is no sound in your video.
-
-*Continue Video* now works with any video that already has some sound (hint: Multitalk).
-
-Also, on top of MMAudio and the various sound-driven models, I have added the ability to use your own soundtrack.
-
-As a result you can apply a different sound source on each new video segment when doing a *Continue Video*.
-
-For instance:
-- first video part: use Multitalk with two people speaking
-- second video part: apply your own soundtrack, which will gently follow the Multitalk conversation
-- third video part: use a Vace effect, and its corresponding control audio will be concatenated to the rest of the audio
-
-To multiply the combinations I have also implemented *Continue Video* with the various image2video models.
-
-Also:
-- End Frame support added for LTX Video models
-- Loras can now be targeted specifically at the High noise or Low noise models with Wan 2.2, check the Loras and Finetune guides
-- Flux Krea Dev support
-
-### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2
-Here now is Wan 2.2 image2video, a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go...
-
-Please note that although it is an image2video model it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). Given that Wan 2.1 image2video loras don't work too well (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder.
-
-I have also optimized RAM management with Wan 2.2 so that loras and modules will be loaded only once in RAM and Reserved RAM; this saves up to 5 GB of RAM, which can make a difference...
-
-And this time I really removed Vace Cocktail Light, which produced blurry output.
-
-### July 29 2025: WanGP v7.4: Just another release ... Wan 2.2 Preview
-Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to leverage this new model fully, since it has twice as many parameters.
-
-So here is a preview version of Wan 2.2, without the 5B model and Wan 2.2 image to video for the moment.
-
-However, as I felt bad delivering only half of the wares, I gave you instead **Wan 2.2 Vace Experimental Cocktail**!
-
-A very good surprise indeed: the loras and Vace partially work with Wan 2.2. We will need to wait for the official Vace 2.2 release, since some Vace features, like identity preservation, are broken.
-
-Bonus zone: Flux multi image conditions have been added, or maybe not, if I broke everything while being distracted by Wan...
-
-7.4 update: I forgot to update the version number. I also removed Vace Cocktail Light, which didn't work well.
-
-### July 27 2025: WanGP v7.3 : Interlude
-While waiting for Wan 2.2, you will appreciate the model selection hierarchy, which is very useful for collecting even more models. You will also appreciate that WanGP remembers which model you used last in each model family.
-
-### July 26 2025: WanGP v7.2 : Ode to Vace
-I am really convinced that Vace can do everything the other models can do, and in a better way, especially as Vace can be combined with Multitalk.
-
-Here are some new Vace improvements:
-- I have provided a default finetune named *Vace Cocktail*, a model created on the fly using the Wan text2video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition from *defaults/vace_14B_cocktail.json* into the *finetunes/* folder to change the Cocktail composition. Cocktail already contains some Lora accelerators, so there is no need to add an AccVid, CausVid or FusioniX Lora on top. The whole point of Cocktail is to let you build your own FusioniX (which originally is a combination of 4 loras) but without the inconveniences of FusioniX.
-- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video, which is a shame for our Vace photoshop. But there is a solution: I have added an Advanced Quality option that tells WanGP to generate a little more than a frame (it will still keep only the first frame). It will be a little slower, but you will be amazed how well Vace Cocktail combined with this option preserves identities (bye bye *Phantom*).
-- As I have observed that in practice one switches frequently between *Vace text2video* and *Vace text2image*, I have put them in the same place; they are now just one tab away, no need to reload the model. Likewise *Wan text2video* and *Wan text2image* have been merged.
-- Color fixing when using Sliding Windows: a new postprocessing step, *Color Correction*, applied automatically by default (you can disable it in the *Advanced tab Sliding Window*), will try to match the colors of the new window with those of the previous window. It doesn't fix all the unwanted artifacts of the new window, but at least it makes the transition smoother. Thanks to the Multitalk team for the original code.
-
-Also you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ...). Many thanks to **Redtash1** for providing the framework for this new feature! You need to go to the Config tab to enable real time stats.
-
-
-### July 21 2025: WanGP v7.12
-- Flux Family Reunion : *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added.
-
-- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video!) in one go without a sliding window. With the distilled model it will take only 5 minutes on an RTX 4090 (you will need 22 GB of VRAM though). I have added options to select higher frame counts if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App)
-
-- LTX Video ControlNet: a Control Net that allows you, for instance, to transfer Human motion or Depth from a control video. It is not as powerful as Vace, but can produce interesting things, especially now that you can quickly generate a 1 min video. Under the hood, IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them.
-
-- LTX IC-Lora support: these are special Loras that consume a conditional image or video.
-Besides the pose, depth and canny IC-Loras loaded transparently, there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8), which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as the control net choice to use it.
-
-- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update)
-
-- Easier way to select video resolution
-
-
-### July 15 2025: WanGP v7.0 is an AI Powered Photoshop
-This release turns the Wan models into Image Generators. This goes way beyond allowing you to generate a video made of a single frame:
-- Multiple Images generated at the same time so that you can choose the one you like best. It is highly VRAM optimized so that you can generate, for instance, 4 720p Images at the same time with less than 10 GB
-- With the *image2image* the original text2video WanGP becomes an image upsampler / restorer
-- *Vace image2image* comes out of the box with image outpainting, person / object replacement, ...
-- You can use a newly generated Image in one click as Start Image or Reference Image for a Video generation
-
-And to complete the full suite of AI Image Generators, Ladies and Gentlemen please welcome for the first time in WanGP : **Flux Kontext**.\
-As a reminder, Flux Kontext is an image editor: give it an image and a prompt and it will make the change for you.\
-This highly optimized version of Flux Kontext will make you feel that you have been cheated all this time as WanGP Flux Kontext requires only 8 GB of VRAM to generate 4 images at the same time with no need for quantization.
-
-WanGP v7 comes with *Image2image* vanilla and *Vace FusioniX*. However, you can build your own finetune where you combine a text2video or Vace model with any combination of Loras.
-
-Also in the news:
-- You can now enter the *Bbox* for each speaker in *Multitalk* to precisely locate who is speaking. And to save some headaches the *Image Mask generator* will give you the *Bbox* coordinates of an area you have selected.
-- *Film Grain* post processing to add a vintage look to your video
-- The *First Last Frame to Video* model should work much better now, as I recently discovered its implementation was not complete
-- More power for the finetuners: you can now embed Loras directly in the finetune definition. You can also override the default models (titles, visibility, ...) with your own finetunes. Check the doc, which has been updated.
-
-
-### July 10 2025: WanGP v6.7, is NAG a game changer? You tell me
-Maybe you knew that already, but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that is, *CFG* is set to 1). This helps to get much faster generations, but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase.
-
-So WanGP 6.7 gives you NAG, but not just any NAG: a *Low VRAM* implementation, since the default one ends up being VRAM greedy. You will find NAG in the *General* advanced tab for most Wan models.
-
-Use NAG especially when Guidance is set to 1. To turn it on, set the **NAG scale** to something around 10. There are other NAG parameters, **NAG tau** and **NAG alpha**, which I recommend changing only if you don't get good results by just playing with the NAG scale. Don't hesitate to share on this discord server the best combinations for these 3 parameters.
-
-The authors of NAG claim that NAG can also be used with Guidance (CFG > 1) to improve prompt adherence.
-
-### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** :
-**Vace**, our beloved super Control Net, has been combined with **Multitalk**, the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model, and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for a very long time (which is an **Infinite** amount of time in the field of video generation).
-
-Of course you will also get *Multitalk* vanilla and *Multitalk 720p* as a bonus.
-
-And since I am mister nice guy, I have enclosed as an exclusivity an *Audio Separator* that will save you time isolating each voice when using Multitalk with two people.
-
-As I feel like resting a bit, I haven't yet produced a nice sample Video to illustrate all these new capabilities. But here is the thing: I am sure you will publish your *Master Pieces* in the *Share Your Best Video* channel. The best ones will be added to the *Announcements Channel* and will bring eternal fame to their authors.
-
-But wait, there is more:
-- Sliding Windows support has been added everywhere with Wan models, so with text2video, recently upgraded in 6.5 into a video2video, you can now upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you)
-- I have also added the capability to transfer the audio of the original control video (Misc. advanced tab) and an option to preserve the fps in the generated video, so from now on you will be able to upsample / restore your old family videos and keep the audio at its original pace. Be aware that the duration will be limited to 1000 frames, as I still need to add streaming support for unlimited video sizes.
-
-Also of interest:
-- Extract video info from Videos that have not been generated by WanGP; even better, you can also apply post processing (Upsampling / MMAudio) to non-WanGP videos
-- Force the generated video fps to your liking; works very well with Vace when using a Control Video
-- Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time)
-
-### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features:
-- View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations
-- In one click use the newly generated video as a Control Video or Source Video to be continued
-- Manage multiple settings for the same model and switch between them using a dropdown box
-- WanGP will keep the last generated videos in the Gallery and will remember the last model you used if you restart the app but kept the Web page open
-- Custom resolutions : add a file in the WanGP folder with the list of resolutions you want to see in WanGP (look at the instruction readme in this folder)
-
-Taking care of your life is not enough, you want new stuff to play with ?
-- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way, it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go to the *Extensions* tab of the WanGP *Configuration* to enable MMAudio
-- Forgot to upsample your video during the generation? Want to try another MMAudio variation? Fear not, you can also apply upsampling or add an MMAudio track once the video generation is done. Even better, you can ask WanGP for multiple variations of MMAudio to pick the one you like best
-- MagCache support: a new step skipping approach, supposed to be better than TeaCache. Makes a difference if you usually generate with a high number of steps
-- SageAttention2++ support : not just the compatibility but also a slightly reduced VRAM usage
-- Video2Video in Wan Text2Video : this is the paradox, a text2video can become a video2video if you start the denoising process later on an existing video
-- FusioniX upsampler: this is an illustration of Video2Video in Text2Video. Use the FusioniX text2video model with an output resolution of 1080p and a denoising strength of 0.25 and you will get one of the best upsamplers (in only 2/3 steps; you will need lots of VRAM though). Increase the denoising strength and you will get one of the best Video Restorers
-- Choice of Wan Samplers / Schedulers
-- More Lora formats support
-
-**If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one**
See full changelog: **[Changelog](docs/CHANGELOG.md)**
diff --git a/defaults/qwen_image_edit_20B.json b/defaults/qwen_image_edit_20B.json
index 79b8b245d..04fc573fb 100644
--- a/defaults/qwen_image_edit_20B.json
+++ b/defaults/qwen_image_edit_20B.json
@@ -7,6 +7,7 @@
"https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_quanto_bf16_int8.safetensors"
],
+ "preload_URLs": ["https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_inpainting.safetensors"],
"attention": {
"<89": "sdpa"
}
diff --git a/defaults/vace_fun_14B_2_2.json b/defaults/vace_fun_14B_2_2.json
new file mode 100644
index 000000000..8f22d3456
--- /dev/null
+++ b/defaults/vace_fun_14B_2_2.json
@@ -0,0 +1,24 @@
+{
+ "model": {
+ "name": "Wan2.2 Vace Fun 14B",
+ "architecture": "vace_14B",
+ "description": "This is the Fun Vace 2.2 version, that is not the official Vace 2.2",
+ "URLs": [
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_mbf16.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_quanto_mbf16_int8.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_quanto_mfp16_int8.safetensors"
+ ],
+ "URLs2": [
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_mbf16.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_quanto_mbf16_int8.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_quanto_mfp16_int8.safetensors"
+ ],
+ "group": "wan2_2"
+ },
+ "guidance_phases": 2,
+ "num_inference_steps": 30,
+ "guidance_scale": 1,
+ "guidance2_scale": 1,
+ "flow_shift": 2,
+ "switch_threshold": 875
+}
\ No newline at end of file
diff --git a/defaults/vace_fun_14B_cocktail_2_2.json b/defaults/vace_fun_14B_cocktail_2_2.json
new file mode 100644
index 000000000..c587abdd4
--- /dev/null
+++ b/defaults/vace_fun_14B_cocktail_2_2.json
@@ -0,0 +1,28 @@
+{
+ "model": {
+ "name": "Wan2.2 Vace Fun Cocktail 14B",
+ "architecture": "vace_14B",
+ "description": "This model has been created on the fly using the Wan text 2.2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. This is the Fun Vace 2.2, that is not the official Vace 2.2",
+ "URLs": "vace_fun_14B_2_2",
+ "URLs2": "vace_fun_14B_2_2",
+ "loras": [
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors"
+ ],
+ "loras_multipliers": [
+ 1,
+ 0.2,
+ 0.5,
+ 0.5
+ ],
+ "group": "wan2_2"
+ },
+ "guidance_phases": 2,
+ "num_inference_steps": 10,
+ "guidance_scale": 1,
+ "guidance2_scale": 1,
+ "flow_shift": 2,
+ "switch_threshold": 875
+}
\ No newline at end of file
diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md
index 5a89d9389..b0eeae34b 100644
--- a/docs/CHANGELOG.md
+++ b/docs/CHANGELOG.md
@@ -1,20 +1,154 @@
# Changelog
## 🔥 Latest News
-### July 21 2025: WanGP v7.1
+### August 29 2025: WanGP v8.21 - Here Goes Your Weekend
+
+- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*: internally only one image is used per Sliding Window. However, thanks to the new *Smooth Transition* mode, each new clip is connected to the previous one and all the camera work is done by InfiniteTalk. If you don't get any transition, increase the number of frames of a Sliding Window (81 frames recommended)
+
+- **StandIn**: a very light model specialized in Identity Transfer. I have provided two versions of StandIn: a basic one derived from the text2video model and another based on Vace. If used with Vace, the last reference frame given to Vace will also be used for StandIn
+
+- **Flux ESO**: a new Flux-derived *Image Editing tool*, but this one is specialized in both *Identity Transfer* and *Style Transfer*. Style has to be understood in its widest meaning: give it a reference picture of a person and another one of Sushis and you will turn this person into Sushis
+
+### August 24 2025: WanGP v8.1 - the RAM Liberator
+
+- **Reserved RAM entirely freed when switching models**: you should get far fewer RAM-related out-of-memory errors. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP
+- **InfiniteTalk** support: an improved version of Multitalk that supposedly supports very long video generations based on an audio track. It exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesn't seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated with the same audio: each Reference frame you provide will be associated with a new Sliding Window. If only one Reference frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\
+If you are not into audio, you can still use this model to generate infinitely long image2video, just select "no speaker". Last but not least, InfiniteTalk works with all the Loras accelerators.
+- **Flux Chroma 1 HD** support: an uncensored Flux-based model, lighter than Flux (8.9B versus 12B), that can fit entirely in VRAM with only 16 GB of VRAM. Unfortunately it is not distilled, so you will need CFG and at least 20 steps
+
+### August 21 2025: WanGP v8.01 - the killer of seven
+
+- **Qwen Image Edit** : a Flux Kontext challenger (prompt-driven image editing). Best results (including Identity preservation) will be obtained at 720p; beyond that you may get image outpainting and / or lose identity preservation, and below 720p prompt adherence will be worse. Qwen Image Edit works with the Qwen Lightning 4 steps Lora. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions, but identity preservation won't be as good.
+- **On demand Prompt Enhancer** (needs to be enabled in Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt.
+- Choice of a **Non censored Prompt Enhancer**. Beware, this one is VRAM hungry and will require 12 GB of VRAM to work
+- **Memory Profile customizable per model** : useful, for instance, to set Profile 3 (preload the model entirely in VRAM) only for Image Generation models if you have 24 GB of VRAM. In that case Generation will be much faster, because with Image generators (contrary to Video generators) a lot of time is otherwise wasted in offloading
+- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Lightning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is an 8-step process although the Lightning lora is 4 steps. This expert guidance mode is also available with Wan 2.1.
+
+*WanGP 8.01 update, improved Qwen Image Edit Identity Preservation*
+### August 12 2025: WanGP v7.7777 - Lucky Day(s)
+
+This is your lucky day! Thanks to new configuration options that let you store generated Videos and Images in lossless compressed formats, you will find that they look twice as good without you doing anything!
+
+Just kidding, they will only be marginally better, but at least this opens the way to professional editing.
+
+Support:
+- Video: x264, x264 lossless, x265
+- Images: jpeg, png, webp, webp lossless
+Generation Settings are stored in each of the above regardless of the format (that was the hard part).
+
+Also you can now choose different output directories for images and videos.
+
+Unexpected luck: fixed Lightning 8 steps for Qwen and Lightning 4 steps for Wan 2.2; now you just need a 1x multiplier, no weird numbers.
+*update 7.777: oops, got a crash with FastWan? Luck comes and goes, try the new update, maybe you will have better luck this time*
+*update 7.7777: Sometimes good luck seems to last forever. For instance, what if Qwen Lightning 4 steps could also work with WanGP?*
+- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps)
+- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps)
+
+
+### August 10 2025: WanGP v7.76 - Faster than the VAE ...
+We have a funny one here today: FastWan 2.2 5B, the Fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that the VAE is twice as slow...
+Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune.
+
+*WanGP 7.76: fixed the mess I made of the i2v models (the loras path was wrong for Wan 2.2 and Clip was broken)*
+
+### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2
+Added support for the Qwen Lightning lora for 8-step generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). The lora is not normalized, so use a multiplier around 0.1.
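+
+As a rough sketch only (not an official recipe), such a lora can also be embedded in a finetune definition, since finetunes may declare Loras directly (see the v7.0 notes below). The `loras` / `loras_multipliers` keys mirror the finetune definitions shipped elsewhere in this patch; the model name, architecture, description and base `URLs` reference here are placeholders:
+
+```json
+{
+  "model": {
+    "name": "Qwen Image Lightning 8 steps (example)",
+    "architecture": "qwen_image_20B",
+    "description": "Hypothetical finetune: base Qwen Image plus the Lightning 8-step lora at a low multiplier",
+    "URLs": "qwen_image_20B",
+    "loras": [
+      "https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.0.safetensors"
+    ],
+    "loras_multipliers": [0.1]
+  },
+  "num_inference_steps": 8,
+  "guidance_scale": 1
+}
+```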
+
+Mag Cache support for all the Wan 2.2 models. Don't forget to set guidance to 1 and 8 denoising steps; your gen will be 7x faster!
+
+### August 8 2025: WanGP v7.73 - Qwen Rebirth
+Ever wondered what impact not using Guidance has on a model that expects it? Just look at Qwen Image in WanGP 7.71, whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact, Qwen was dying for a negative prompt, and in WanGP 7.72 there is at last one for him.
+
+As Qwen is not so picky after all, I have also added a quantized text encoder, which reduces Qwen's RAM requirements by 10 GB (the quantized text encoder produced garbage before).
+
+Unfortunately the Sage bug is still there for older GPU architectures. An Sdpa fallback has been added for these architectures.
+
+*7.73 update: the Sage / Sage2 bug is still present on GPUs older than the RTX 40xx series. I have added a detection mechanism that forces Sdpa attention in that case*
+
+
+### August 6 2025: WanGP v7.71 - Picky, picky
+
+This release comes with two new models:
+- Qwen Image: a Commercial grade Image generator capable of injecting full sentences into the generated Image while still offering incredible visuals
+- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 model needed if you want to complete your Wan 2.2 collection (loras for this model can be stored in the "\loras\5B" folder)
+
+There is a catch though: they are very picky if you want to get good generations. First, they both need lots of steps (50?) to show what they have to offer. Then, for Qwen Image I had to hardcode the supported resolutions, because if you try anything else you will get garbage. Likewise, Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p.
+
+*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, to keep VRAM low during the whole gen.*
+
+
+### August 4 2025: WanGP v7.6 - Remuxed
+
+With this new version you won't have any excuse if there is no sound in your video.
+
+*Continue Video* now works with any video that already has some sound (hint: Multitalk).
+
+Also, on top of MMAudio and the various sound-driven models, I have added the ability to use your own soundtrack.
+
+As a result you can apply a different sound source on each new video segment when doing a *Continue Video*.
+
+For instance:
+- first video part: use Multitalk with two people speaking
+- second video part: apply your own soundtrack, which will gently follow the Multitalk conversation
+- third video part: use a Vace effect, and its corresponding control audio will be concatenated to the rest of the audio
+
+To multiply the combinations I have also implemented *Continue Video* with the various image2video models.
+
+Also:
+- End Frame support added for LTX Video models
+- Loras can now be targeted specifically at the High noise or Low noise models with Wan 2.2, check the Loras and Finetune guides
+- Flux Krea Dev support
+
+### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2
+Here now is Wan 2.2 image2video, a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go...
+
+Please note that although it is an image2video model it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). Given that Wan 2.1 image2video loras don't work too well (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder.
+
+I have also optimized RAM management with Wan 2.2 so that loras and modules will be loaded only once in RAM and Reserved RAM; this saves up to 5 GB of RAM, which can make a difference...
+
+And this time I really removed Vace Cocktail Light, which produced blurry output.
+
+### July 29 2025: WanGP v7.4: Just another release ... Wan 2.2 Preview
+Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to leverage this new model fully, since it has twice as many parameters.
+
+So here is a preview version of Wan 2.2, without the 5B model and Wan 2.2 image to video for the moment.
+
+However, as I felt bad delivering only half of the wares, I gave you instead **Wan 2.2 Vace Experimental Cocktail**!
+
+A very good surprise indeed: the loras and Vace partially work with Wan 2.2. We will need to wait for the official Vace 2.2 release, since some Vace features, like identity preservation, are broken.
+
+Bonus zone: Flux multi image conditions have been added, or maybe not, if I broke everything while being distracted by Wan...
+
+7.4 update: I forgot to update the version number. I also removed Vace Cocktail Light, which didn't work well.
+
+### July 27 2025: WanGP v7.3 : Interlude
+While waiting for Wan 2.2, you will appreciate the model selection hierarchy, which is very useful for collecting even more models. You will also appreciate that WanGP remembers which model you used last in each model family.
+
+### July 26 2025: WanGP v7.2 : Ode to Vace
+I am really convinced that Vace can do everything the other models can do, and in a better way, especially as Vace can be combined with Multitalk.
+
+Here are some new Vace improvements:
+- I have provided a default finetune named *Vace Cocktail*, a model created on the fly using the Wan text2video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition from *defaults/vace_14B_cocktail.json* into the *finetunes/* folder to change the Cocktail composition (see the JSON sketch after this list). Cocktail already contains some Lora accelerators, so there is no need to add an AccVid, CausVid or FusioniX Lora on top. The whole point of Cocktail is to let you build your own FusioniX (which originally is a combination of 4 loras) but without the inconveniences of FusioniX.
+- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video, which is a shame for our Vace photoshop. But there is a solution: I have added an Advanced Quality option that tells WanGP to generate a little more than a frame (it will still keep only the first frame). It will be a little slower, but you will be amazed how well Vace Cocktail combined with this option preserves identities (bye bye *Phantom*).
+- As I have observed that in practice one switches frequently between *Vace text2video* and *Vace text2image*, I have put them in the same place; they are now just one tab away, no need to reload the model. Likewise *Wan text2video* and *Wan text2image* have been merged.
+- Color fixing when using Sliding Windows: a new postprocessing step, *Color Correction*, applied automatically by default (you can disable it in the *Advanced tab Sliding Window*), will try to match the colors of the new window with those of the previous window. It doesn't fix all the unwanted artifacts of the new window, but at least it makes the transition smoother. Thanks to the Multitalk team for the original code.
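+
+As mentioned in the first bullet above, the Cocktail composition is just a small JSON finetune definition. The sketch below is only an illustration of what a tweaked copy placed in *finetunes/* could look like: the `loras` / `loras_multipliers` keys and the lora URLs are borrowed from the Wan 2.2 Fun Vace cocktail definition added elsewhere in this patch, while the name, description and base `URLs` reference are placeholders, and the real *defaults/vace_14B_cocktail.json* may differ:
+
+```json
+{
+  "model": {
+    "name": "My Vace Cocktail (example)",
+    "architecture": "vace_14B",
+    "description": "Hypothetical Cocktail variant with a lighter Detail Enhancer lora",
+    "URLs": "vace_14B",
+    "loras": [
+      "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
+      "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors"
+    ],
+    "loras_multipliers": [1, 0.1]
+  }
+}
+```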
+
+Also you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ...). Many thanks to **Redtash1** for providing the framework for this new feature! You need to go to the Config tab to enable real time stats.
+
+
+### July 21 2025: WanGP v7.12
- Flux Family Reunion : *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added.
-- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video !) in one go without a sliding window. With the distilled model it will take only 5 minutes with a RTX 4090 (you will need 22 GB of VRAM though). I have added options to select higher humber frames if you want to experiment
+- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video!) in one go without a sliding window. With the distilled model it will take only 5 minutes on an RTX 4090 (you will need 22 GB of VRAM though). I have added options to select higher frame counts if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App)
- LTX Video ControlNet : it is a Control Net that allows you for instance to transfer a Human motion or Depth from a control video. It is not as powerful as Vace but can produce interesting things especially as now you can generate quickly a 1 min video. Under the scene IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them.
- LTX IC-Lora support: these are special Loras that consumes a conditional image or video
Beside the pose, depth and canny IC-Loras transparently loaded there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as control net choice to use it.
-And Also:
-- easier way to select video resolution
-- started to optimize Matanyone to reduce VRAM requirements
+- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update)
+- Easier way to select video resolution
### July 15 2025: WanGP v7.0 is an AI Powered Photoshop
This release turns the Wan models into Image Generators. This goes way more than allowing to generate a video made of single frame :
diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py
index 808369ff6..d16888122 100644
--- a/models/flux/flux_handler.py
+++ b/models/flux/flux_handler.py
@@ -107,7 +107,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder
]
@staticmethod
- def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
+ def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
from .flux_main import model_factory
flux_model = model_factory(
diff --git a/models/flux/sampling.py b/models/flux/sampling.py
index 1b4813a5c..92ee59099 100644
--- a/models/flux/sampling.py
+++ b/models/flux/sampling.py
@@ -203,7 +203,7 @@ def prepare_kontext(
image_mask_latents = convert_image_to_tensor(img_mask.resize((target_width // 16, target_height // 16), resample=Image.Resampling.LANCZOS))
image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
- convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
+ # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)
return_dict.update({
"img_msk_latents": image_mask_latents,
diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py
index 67e9e99da..435234102 100644
--- a/models/hyvideo/hunyuan_handler.py
+++ b/models/hyvideo/hunyuan_handler.py
@@ -68,7 +68,13 @@ def query_model_def(base_model_type, model_def):
"visible": False,
}
- if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True
+ if base_model_type in ["hunyuan_avatar"]:
+ extra_model_def["image_ref_choices"] = {
+ "choices": [("Start Image", "KI")],
+ "letters_filter":"KI",
+ "visible": False,
+ }
+ extra_model_def["no_background_removal"] = True
if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]:
extra_model_def["one_image_ref_needed"] = True
@@ -123,7 +129,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder
}
@staticmethod
- def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
+ def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
from .hunyuan import HunyuanVideoSampler
from mmgp import offload
diff --git a/models/ltx_video/ltxv.py b/models/ltx_video/ltxv.py
index db143fc6d..080860c14 100644
--- a/models/ltx_video/ltxv.py
+++ b/models/ltx_video/ltxv.py
@@ -476,14 +476,14 @@ def generate(
images = images.sub_(0.5).mul_(2).squeeze(0)
return images
- def get_loras_transformer(self, get_model_recursive_prop, video_prompt_type, **kwargs):
+ def get_loras_transformer(self, get_model_recursive_prop, model_type, video_prompt_type, **kwargs):
map = {
"P" : "pose",
"D" : "depth",
"E" : "canny",
}
loras = []
- preloadURLs = get_model_recursive_prop(self.model_type, "preload_URLs")
+ preloadURLs = get_model_recursive_prop(model_type, "preload_URLs")
lora_file_name = ""
for letter, signature in map.items():
if letter in video_prompt_type:
diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py
index e44c98399..2845fdb0b 100644
--- a/models/ltx_video/ltxv_handler.py
+++ b/models/ltx_video/ltxv_handler.py
@@ -74,7 +74,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder
@staticmethod
- def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
+ def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
from .ltxv import LTXV
ltxv_model = LTXV(
diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py
index 0897ee419..be982aabc 100644
--- a/models/qwen/pipeline_qwenimage.py
+++ b/models/qwen/pipeline_qwenimage.py
@@ -569,6 +569,8 @@ def __call__(
pipeline=None,
loras_slists=None,
joint_pass= True,
+ lora_inpaint = False,
+ outpainting_dims = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -704,7 +706,7 @@ def __call__(
image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
if (image_width,image_height) != image.size:
image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
- else:
+ elif not lora_inpaint:
# _, image_width, image_height = min(
# (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
# )
@@ -721,8 +723,16 @@ def __call__(
if image.size != (image_width, image_height):
image = image.resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
+ image = convert_image_to_tensor(image)
+ if lora_inpaint:
+ image_mask_rebuilt = torch.where(convert_image_to_tensor(image_mask)>-0.5, 1., 0. )[0:1]
+ image_mask_latents = None
+ green = torch.tensor([-1.0, 1.0, -1.0]).to(image)
+ green_image = green[:, None, None] .expand_as(image)
+ image = torch.where(image_mask_rebuilt > 0, green_image, image)
+ prompt_image = convert_tensor_to_image(image)
+ image = image.unsqueeze(0).unsqueeze(2)
# image.save("nnn.png")
- image = convert_image_to_tensor(image).unsqueeze(0).unsqueeze(2)
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
@@ -940,7 +950,7 @@ def __call__(
)
latents = latents / latents_std + latents_mean
output_image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
- if image_mask is not None:
+ if image_mask is not None and not lora_inpaint : #not (lora_inpaint and outpainting_dims is not None):
output_image = image.squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(image) * image_mask_rebuilt
diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py
index 80a909f84..010298ead 100644
--- a/models/qwen/qwen_handler.py
+++ b/models/qwen/qwen_handler.py
@@ -1,4 +1,6 @@
import torch
+import gradio as gr
+
def get_qwen_text_encoder_filename(text_encoder_quantization):
text_encoder_filename = "ckpts/Qwen2.5-VL-7B-Instruct/Qwen2.5-VL-7B-Instruct_bf16.safetensors"
@@ -29,6 +31,16 @@ def query_model_def(base_model_type, model_def):
"letters_filter": "KI",
}
extra_model_def["background_removal_label"]= "Remove Backgrounds only behind People / Objects except main Subject / Landscape"
+ extra_model_def["video_guide_outpainting"] = [2]
+ extra_model_def["model_modes"] = {
+ "choices": [
+ ("Lora Inpainting: Inpainted area completely unrelated to occulted content", 1),
+ ("Masked Denoising : Inpainted area may reuse some content that has been occulted", 0),
+ ],
+ "default": 1,
+ "label" : "Inpainting Method",
+ "image_modes" : [2],
+ }
return extra_model_def
@@ -58,7 +70,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder
}
@staticmethod
- def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False):
+ def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None):
from .qwen_main import model_factory
from mmgp import offload
@@ -99,5 +111,18 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
"video_prompt_type": "KI",
"denoising_strength" : 1.,
+ "model_mode" : 0,
})
+ def validate_generative_settings(base_model_type, model_def, inputs):
+ if base_model_type in ["qwen_image_edit_20B"]:
+ model_mode = inputs["model_mode"]
+ denoising_strength= inputs["denoising_strength"]
+ video_guide_outpainting= inputs["video_guide_outpainting"]
+ from wgp import get_outpainting_dims
+ outpainting_dims = get_outpainting_dims(video_guide_outpainting)
+
+ if denoising_strength < 1 and model_mode == 1:
+ gr.Info("Denoising Strength will be ignored while using Lora Inpainting")
+ if outpainting_dims is not None and model_mode == 0 :
+ return "Outpainting is not supported with Masked Denoising "
diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py
index e4c19eddb..da84e0d0e 100644
--- a/models/qwen/qwen_main.py
+++ b/models/qwen/qwen_main.py
@@ -44,7 +44,7 @@ def __init__(
save_quantized = False,
dtype = torch.bfloat16,
VAE_dtype = torch.float32,
- mixed_precision_transformer = False
+ mixed_precision_transformer = False,
):
@@ -117,6 +117,8 @@ def generate(
joint_pass = True,
sample_solver='default',
denoising_strength = 1.,
+ model_mode = 0,
+ outpainting_dims = None,
**bbargs
):
# Generate with different aspect ratios
@@ -205,8 +207,16 @@ def generate(
loras_slists=loras_slists,
joint_pass = joint_pass,
denoising_strength=denoising_strength,
- generator=torch.Generator(device="cuda").manual_seed(seed)
+ generator=torch.Generator(device="cuda").manual_seed(seed),
+ lora_inpaint = image_mask is not None and model_mode == 1,
+ outpainting_dims = outpainting_dims,
)
if image is None: return None
return image.transpose(0, 1)
+ def get_loras_transformer(self, get_model_recursive_prop, model_type, model_mode, **kwargs):
+ if model_mode == 0: return [], []
+ preloadURLs = get_model_recursive_prop(model_type, "preload_URLs")
+ return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1]
+
+
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index 40ed42fa5..dde1a6535 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -64,6 +64,7 @@ def __init__(
config,
checkpoint_dir,
model_filename = None,
+ submodel_no_list = None,
model_type = None,
model_def = None,
base_model_type = None,
@@ -126,50 +127,65 @@ def __init__(
forcedConfigPath = base_config_file if len(model_filename) > 1 else None
# forcedConfigPath = base_config_file = f"configs/flf2v_720p.json"
# model_filename[1] = xmodel_filename
-
+ self.model = self.model2 = None
source = model_def.get("source", None)
+ source2 = model_def.get("source2", None)
module_source = model_def.get("module_source", None)
+ module_source2 = model_def.get("module_source2", None)
if module_source is not None:
- model_filename = [] + model_filename
- model_filename[1] = module_source
- self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
- elif source is not None:
+ self.model = offload.fast_load_transformers_model(model_filename[:1] + [module_source], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
+ if module_source2 is not None:
+ self.model2 = offload.fast_load_transformers_model(model_filename[1:2] + [module_source2], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
+ if source is not None:
self.model = offload.fast_load_transformers_model(source, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file)
- elif self.transformer_switch:
- shared_modules= {}
- self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules)
- self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
- shared_modules = None
+ if source2 is not None:
+ self.model2 = offload.fast_load_transformers_model(source2, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file)
+
+ if self.model is not None or self.model2 is not None:
+ from wgp import save_model
+ from mmgp.safetensors2 import torch_load_file
else:
- self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
+ if self.transformer_switch:
+ if 0 in submodel_no_list[2:] and 1 in submodel_no_list:
+                    raise Exception("Shared and non-shared modules at the same time across multiple models is not supported")
+
+ if 0 in submodel_no_list[2:]:
+ shared_modules= {}
+ self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules)
+ self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
+ shared_modules = None
+ else:
+ modules_for_1 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==1 ]
+ modules_for_2 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==2 ]
+ self.model = offload.fast_load_transformers_model(model_filename[:1], modules = modules_for_1, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
+ self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = modules_for_2, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
+
+ else:
+ self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath)
- # self.model = offload.load_model_data(self.model, xmodel_filename )
- # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth")
- self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
- offload.change_dtype(self.model, dtype, True)
+ if self.model is not None:
+ self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
+ offload.change_dtype(self.model, dtype, True)
+ self.model.eval().requires_grad_(False)
if self.model2 is not None:
self.model2.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype)
offload.change_dtype(self.model2, dtype, True)
-
- # offload.save_model(self.model, "wan2.1_text2video_1.3B_mbf16.safetensors", do_quantize= False, config_file_path=base_config_file, filter_sd=sd)
- # offload.save_model(self.model, "wan2.2_image2video_14B_low_mbf16.safetensors", config_file_path=base_config_file)
- # offload.save_model(self.model, "wan2.2_image2video_14B_low_quanto_mbf16_int8.safetensors", do_quantize=True, config_file_path=base_config_file)
- self.model.eval().requires_grad_(False)
- if self.model2 is not None:
self.model2.eval().requires_grad_(False)
+
if module_source is not None:
- from wgp import save_model
- from mmgp.safetensors2 import torch_load_file
- filter = list(torch_load_file(module_source))
- save_model(self.model, model_type, dtype, None, is_module=True, filter=filter)
- elif not source is None:
- from wgp import save_model
- save_model(self.model, model_type, dtype, None)
+ save_model(self.model, model_type, dtype, None, is_module=True, filter=list(torch_load_file(module_source)), module_source_no=1)
+ if module_source2 is not None:
+ save_model(self.model2, model_type, dtype, None, is_module=True, filter=list(torch_load_file(module_source2)), module_source_no=2)
+ if not source is None:
+ save_model(self.model, model_type, dtype, None, submodel_no= 1)
+ if not source2 is None:
+ save_model(self.model2, model_type, dtype, None, submodel_no= 2)
if save_quantized:
from wgp import save_quantized_model
- save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
+ if self.model is not None:
+ save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file)
if self.model2 is not None:
save_quantized_model(self.model2, model_type, model_filename[1], dtype, base_config_file, submodel_no=2)
self.sample_neg_prompt = config.sample_neg_prompt
@@ -307,7 +323,7 @@ def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, full_
canvas = canvas.to(device)
return ref_img.to(device), canvas
- def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], start_frame = 0, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
+ def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
image_sizes = []
trim_video_guide = len(keep_video_guide_frames)
def conv_tensor(t, device):
@@ -659,13 +675,15 @@ def generate(self,
inject_from_start = False
if input_frames != None and denoising_strength < 1 :
color_reference_frame = input_frames[:, -1:].clone()
- if overlapped_latents != None:
- overlapped_latents_frames_num = overlapped_latents.shape[2]
- overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
+ if prefix_frames_count > 0:
+ overlapped_frames_num = prefix_frames_count
+                overlapped_latents_frames_num = (overlapped_frames_num - 1) // 4 + 1
+ # overlapped_latents_frames_num = overlapped_latents.shape[2]
+ # overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
else:
overlapped_latents_frames_num = overlapped_frames_num = 0
if len(keep_frames_parsed) == 0 or image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = []
- injection_denoising_step = int(sampling_steps * (1. - denoising_strength) )
+ injection_denoising_step = int( round(sampling_steps * (1. - denoising_strength),4) )
latent_keep_frames = []
if source_latents.shape[2] < lat_frames or len(keep_frames_parsed) > 0:
inject_from_start = True
diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py
index 39d0a70c8..7a5d3ea2d 100644
--- a/models/wan/df_handler.py
+++ b/models/wan/df_handler.py
@@ -78,7 +78,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder
return family_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization)
@staticmethod
- def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False):
+ def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None):
from .configs import WAN_CONFIGS
from .wan_handler import family_handler
cfg = WAN_CONFIGS['t2v-14B']
diff --git a/models/wan/multitalk/multitalk.py b/models/wan/multitalk/multitalk.py
index b04f65f32..52d2dd995 100644
--- a/models/wan/multitalk/multitalk.py
+++ b/models/wan/multitalk/multitalk.py
@@ -214,18 +214,20 @@ def process_tts_multi(text, save_dir, voice1, voice2):
return s1, s2, save_path_sum
-def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0):
+def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0, return_sum_only = False):
wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base")
# wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec")
pad = int(padded_frames_for_embeddings/ fps * sr)
new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad, min_audio_duration = min_audio_duration )
- audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
- audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
- full_audio_embs = []
- if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
- # if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
- if audio_guide2 != None: full_audio_embs.append(audio_embedding_2)
- if audio_guide2 == None and not duration_changed: sum_human_speechs = None
+ if return_sum_only:
+ full_audio_embs = None
+ else:
+ audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
+ audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps)
+ full_audio_embs = []
+ if audio_guide1 != None: full_audio_embs.append(audio_embedding_1)
+ if audio_guide2 != None: full_audio_embs.append(audio_embedding_2)
+ if audio_guide2 == None and not duration_changed: sum_human_speechs = None
return full_audio_embs, sum_human_speechs
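
The new `return_sum_only` flag lets a caller rebuild only the combined speech waveform (for example from the original, un-cleaned audio) without recomputing wav2vec embeddings, which is how wgp.py uses it further down. A hedged sketch of the intended two-pass pattern, assuming the module path matches the file above and with placeholder file names and frame counts:

```python
from models.wan.multitalk.multitalk import get_full_audio_embeddings

# Pass 1: embeddings from the vocal-only tracks (drives the talking-head conditioning).
embs, _ = get_full_audio_embeddings(
    audio_guide1="speaker1_clean.wav",
    audio_guide2="speaker2_clean.wav",
    combination_type="para",
    num_frames=81,
    fps=25,
)

# Pass 2: only the summed waveform, rebuilt from the original audio so the output
# video keeps music and background sound. Embedding extraction is skipped entirely.
_, summed_audio = get_full_audio_embeddings(
    audio_guide1="speaker1.wav",
    audio_guide2="speaker2.wav",
    combination_type="para",
    num_frames=81,
    fps=25,
    return_sum_only=True,
)
```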
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index 253bd9236..6586176bd 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -166,7 +166,8 @@ def query_model_def(base_model_type, model_def):
extra_model_def["lock_image_refs_ratios"] = True
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames"
-
+ extra_model_def["video_guide_outpainting"] = [0,1]
+
if base_model_type in ["standin"]:
extra_model_def["lock_image_refs_ratios"] = True
extra_model_def["image_ref_choices"] = {
@@ -293,7 +294,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder
@staticmethod
- def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False):
+ def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None):
from .configs import WAN_CONFIGS
if test_class_i2v(base_model_type):
@@ -306,6 +307,7 @@ def load_model(model_filename, model_type, base_model_type, model_def, quantizeT
config=cfg,
checkpoint_dir="ckpts",
model_filename=model_filename,
+ submodel_no_list = submodel_no_list,
model_type = model_type,
model_def = model_def,
base_model_type=base_model_type,
@@ -381,7 +383,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
if base_model_type in ["fantasy"]:
ui_defaults.update({
"audio_guidance_scale": 5.0,
- "sliding_window_size": 1,
+ "sliding_window_overlap" : 1,
})
elif base_model_type in ["multitalk"]:
@@ -398,6 +400,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
"guidance_scale": 5.0,
"flow_shift": 7, # 11 for 720p
"sliding_window_overlap" : 9,
+ "sliding_window_size": 81,
"sample_solver" : "euler",
"video_prompt_type": "QKI",
"remove_background_images_ref" : 0,
diff --git a/preprocessing/extract_vocals.py b/preprocessing/extract_vocals.py
new file mode 100644
index 000000000..656402644
--- /dev/null
+++ b/preprocessing/extract_vocals.py
@@ -0,0 +1,69 @@
+from pathlib import Path
+import os, tempfile
+import numpy as np
+import soundfile as sf
+import librosa
+import torch
+import gc
+
+from audio_separator.separator import Separator
+
+def get_vocals(src_path: str, dst_path: str, min_seconds: float = 8) -> str:
+ """
+ If the source audio is shorter than `min_seconds`, pad with trailing silence
+ in a temporary file, then run separation and save only the vocals to dst_path.
+ Returns the full path to the vocals file.
+ """
+
+ default_device = torch.get_default_device()
+ torch.set_default_device('cpu')
+
+ dst = Path(dst_path)
+ dst.parent.mkdir(parents=True, exist_ok=True)
+
+ # Quick duration check
+ duration = librosa.get_duration(path=src_path)
+
+ use_path = src_path
+ temp_path = None
+ try:
+ if duration < min_seconds:
+ # Load (resample) and pad in memory
+ y, sr = librosa.load(src_path, sr=None, mono=False)
+ if y.ndim == 1: # ensure shape (channels, samples)
+ y = y[np.newaxis, :]
+ target_len = int(min_seconds * sr)
+ pad = max(0, target_len - y.shape[1])
+ if pad:
+ y = np.pad(y, ((0, 0), (0, pad)), mode="constant")
+
+ # Write a temp WAV for the separator
+ fd, temp_path = tempfile.mkstemp(suffix=".wav")
+ os.close(fd)
+ sf.write(temp_path, y.T, sr) # soundfile expects (frames, channels)
+ use_path = temp_path
+
+ # Run separation: emit only the vocals, with your exact filename
+ sep = Separator(
+ output_dir=str(dst.parent),
+ output_format=(dst.suffix.lstrip(".") or "wav"),
+ output_single_stem="Vocals",
+ model_file_dir="ckpts/roformer/"  # model_bs_roformer_ep_317_sdr_12.9755.ckpt
+ )
+ sep.load_model()
+ out_files = sep.separate(use_path, {"Vocals": dst.stem})
+
+ out = Path(out_files[0])
+ return str(out if out.is_absolute() else (dst.parent / out))
+ finally:
+ if temp_path and os.path.exists(temp_path):
+ os.remove(temp_path)
+
+ torch.cuda.empty_cache()
+ gc.collect()
+ torch.set_default_device(default_device)
+
+# Example:
+# final = get_vocals("in/clip.mp3", "out/vocals.wav")
+# print(final)
+
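
A hedged usage sketch for the new helper; the input and output paths here are placeholders, and the BS-RoFormer checkpoint is expected under ckpts/roformer/ as wired up in download_models later in this patch:

```python
from preprocessing.extract_vocals import get_vocals

# Clips shorter than min_seconds are padded with trailing silence before separation;
# the function returns the path of the vocals-only file it wrote.
vocals_path = get_vocals("inputs/interview.mp3", "outputs/interview_vocals.wav", min_seconds=8)
print(vocals_path)  # e.g. outputs/interview_vocals.wav
```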
diff --git a/preprocessing/speakers_separator.py b/preprocessing/speakers_separator.py
index 79cde9b1e..d5f256335 100644
--- a/preprocessing/speakers_separator.py
+++ b/preprocessing/speakers_separator.py
@@ -100,7 +100,7 @@ def __init__(self, hf_token: str = None, local_model_path: str = None,
self.hf_token = hf_token
self._overlap_pipeline = None
- def separate_audio(self, audio_path: str, output1, output2 ) -> Dict[str, str]:
+ def separate_audio(self, audio_path: str, output1, output2, audio_original_path: str = None ) -> Dict[str, str]:
"""Optimized main separation function with memory management."""
xprint("Starting optimized audio separation...")
self._current_audio_path = os.path.abspath(audio_path)
@@ -128,7 +128,11 @@ def separate_audio(self, audio_path: str, output1, output2 ) -> Dict[str, str]:
gc.collect()
# Save outputs efficiently
- output_paths = self._save_outputs_optimized(waveform, final_masks, sample_rate, audio_path, output1, output2)
+ if audio_original_path is None:
+ waveform_original = waveform
+ else:
+ waveform_original, sample_rate = self.load_audio(audio_original_path)
+ output_paths = self._save_outputs_optimized(waveform_original, final_masks, sample_rate, audio_path, output1, output2)
return output_paths
@@ -835,7 +839,7 @@ def print_summary(self, audio_path: str):
for turn, _, speaker in diarization.itertracks(yield_label=True):
xprint(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s")
-def extract_dual_audio(audio, output1, output2, verbose = False):
+def extract_dual_audio(audio, output1, output2, verbose = False, audio_original = None):
global verbose_output
verbose_output = verbose
separator = OptimizedPyannote31SpeakerSeparator(
@@ -848,7 +852,7 @@ def extract_dual_audio(audio, output1, output2, verbose = False):
import time
start_time = time.time()
- outputs = separator.separate_audio(audio, output1, output2)
+ outputs = separator.separate_audio(audio, output1, output2, audio_original)
elapsed_time = time.time() - start_time
xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===")
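
With the new `audio_original` argument, diarization and masking still run on the (possibly vocal-cleaned) `audio`, but the masks are applied to the original waveform when the two speaker files are written, so background sound survives in the saved tracks. A sketch of the call, with placeholder file names and assuming the usual pyannote models and token are available as elsewhere in the repo:

```python
from preprocessing.speakers_separator import extract_dual_audio

extract_dual_audio(
    audio="interview_vocals_only.wav",   # cleaned track used to compute speaker masks
    output1="speaker1.wav",
    output2="speaker2.wav",
    verbose=True,
    audio_original="interview.wav",      # masks are applied to this waveform on save
)
```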
diff --git a/requirements.txt b/requirements.txt
index 1f008c75f..e6c85fda7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,14 +21,15 @@ mutagen
pyloudnorm
librosa==0.11.0
speechbrain==1.0.3
-
+audio-separator==0.36.1
+
# UI & interaction
gradio==5.29.0
dashscope
loguru
# Vision & segmentation
-opencv-python>=4.9.0.80
+opencv-python>=4.12.0.88
segment-anything
rembg[gpu]==2.0.65
onnxruntime-gpu
diff --git a/wgp.py b/wgp.py
index d085be76c..de8e2f66f 100644
--- a/wgp.py
+++ b/wgp.py
@@ -23,6 +23,7 @@
from shared.utils import notification_sound
from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers
from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask
+from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image
from shared.utils.audio_video import save_image_metadata, read_image_metadata
from shared.match_archi import match_nvidia_architecture
@@ -61,8 +62,8 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.55"
-settings_version = 2.34
+WanGP_version = "8.6"
+settings_version = 2.35
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -199,9 +200,11 @@ def clean_image_list(gradio_list):
def process_prompt_and_add_tasks(state, model_choice):
-
+ def ret():
+ return gr.update(), gr.update()
+
if state.get("validate_success",0) != 1:
- return
+ return ret()
state["validate_success"] = 0
model_filename = state["model_filename"]
@@ -218,7 +221,7 @@ def process_prompt_and_add_tasks(state, model_choice):
if inputs == None:
gr.Warning("Internal state error: Could not retrieve inputs for the model.")
queue = gen.get("queue", [])
- return get_queue_table(queue)
+ return ret()
model_def = get_model_def(model_type)
model_handler = get_model_handler(model_type)
image_outputs = inputs["image_mode"] > 0
@@ -247,7 +250,7 @@ def process_prompt_and_add_tasks(state, model_choice):
if len(temporal_upsampling) >0: prompt += ["Temporal Upsampling"]
if has_image_file_extension(edit_video_source) and len(temporal_upsampling) > 0:
gr.Info("Temporal Upsampling can not be used with an Image")
- return
+ return ret()
film_grain_intensity = inputs.get("film_grain_intensity",0)
film_grain_saturation = inputs.get("film_grain_saturation",0.5)
# if film_grain_intensity >0: prompt += [f"Film Grain: intensity={film_grain_intensity}, saturation={film_grain_saturation}"]
@@ -263,7 +266,7 @@ def process_prompt_and_add_tasks(state, model_choice):
else:
if audio_source is None:
gr.Info("You must provide a custom Audio")
- return
+ return ret()
prompt += ["Custom Audio"]
repeat_generation == 1
@@ -273,32 +276,32 @@ def process_prompt_and_add_tasks(state, model_choice):
gr.Info("You must choose at least one Remux Method")
else:
gr.Info("You must choose at least one Post Processing Method")
- return
+ return ret()
inputs["prompt"] = ", ".join(prompt)
add_video_task(**inputs)
gen["prompts_max"] = 1 + gen.get("prompts_max",0)
state["validate_success"] = 1
queue= gen.get("queue", [])
- return update_queue_data(queue)
+ return ret()
if hasattr(model_handler, "validate_generative_settings"):
error = model_handler.validate_generative_settings(model_type, model_def, inputs)
if error is not None and len(error) > 0:
gr.Info(error)
- return
+ return ret()
if inputs.get("cfg_star_switch", 0) != 0 and inputs.get("apg_switch", 0) != 0:
gr.Info("Adaptive Progressive Guidance and Classifier Free Guidance Star can not be set at the same time")
- return
+ return ret()
prompt = inputs["prompt"]
if len(prompt) ==0:
gr.Info("Prompt cannot be empty.")
gen = get_gen_info(state)
queue = gen.get("queue", [])
- return get_queue_table(queue)
+ return ret()
prompt, errors = prompt_parser.process_template(prompt)
if len(errors) > 0:
gr.Info("Error processing prompt template: " + errors)
- return
+ return ret()
model_filename = get_model_filename(model_type)
prompts = prompt.replace("\r", "").split("\n")
prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")]
@@ -306,7 +309,7 @@ def process_prompt_and_add_tasks(state, model_choice):
gr.Info("Prompt cannot be empty.")
gen = get_gen_info(state)
queue = gen.get("queue", [])
- return get_queue_table(queue)
+ return ret()
resolution = inputs["resolution"]
width, height = resolution.split("x")
@@ -360,22 +363,22 @@ def process_prompt_and_add_tasks(state, model_choice):
_, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases= guidance_phases)
if len(errors) > 0:
gr.Info(f"Error parsing Loras Multipliers: {errors}")
- return
+ return ret()
if guidance_phases == 3:
if switch_threshold < switch_threshold2:
gr.Info(f"Phase 1-2 Switch Noise Level ({switch_threshold}) should be Greater than Phase 2-3 Switch Noise Level ({switch_threshold2}). As a reminder, noise will gradually go down from 1000 to 0.")
- return
+ return ret()
else:
model_switch_phase = 1
if not any_steps_skipping: skip_steps_cache_type = ""
if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20:
gr.Info("The minimum number of steps should be 20")
- return
+ return ret()
if skip_steps_cache_type == "mag":
if num_inference_steps > 50:
gr.Info("Mag Cache maximum number of steps is 50")
- return
+ return ret()
if image_mode > 0:
audio_prompt_type = ""
@@ -385,7 +388,7 @@ def process_prompt_and_add_tasks(state, model_choice):
speakers_bboxes, error = parse_speakers_locations(speakers_locations)
if len(error) > 0:
gr.Info(error)
- return
+ return ret()
if MMAudio_setting != 0 and server_config.get("mmaudio_enabled", 0) != 0 and video_length <16: #should depend on the architecture
gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long")
@@ -393,23 +396,24 @@ def process_prompt_and_add_tasks(state, model_choice):
if len(frames_positions.strip()) > 0:
positions = frames_positions.split(" ")
for pos_str in positions:
- if not is_integer(pos_str):
- gr.Info(f"Invalid Frame Position '{pos_str}'")
- return
- pos = int(pos_str)
- if pos <1 or pos > max_source_video_frames:
- gr.Info(f"Invalid Frame Position Value'{pos_str}'")
- return
+ if not pos_str in ["L", "l"] and len(pos_str)>0:
+ if not is_integer(pos_str):
+ gr.Info(f"Invalid Frame Position '{pos_str}'")
+ return ret()
+ pos = int(pos_str)
+ if pos <1 or pos > max_source_video_frames:
+ gr.Info(f"Invalid Frame Position Value '{pos_str}'")
+ return ret()
else:
frames_positions = None
if audio_source is not None and MMAudio_setting != 0:
gr.Info("MMAudio and Custom Audio Soundtrack can't not be used at the same time")
- return
+ return ret()
if len(filter_letters(image_prompt_type, "VLG")) > 0 and len(keep_frames_video_source) > 0:
if not is_integer(keep_frames_video_source) or int(keep_frames_video_source) == 0:
gr.Info("The number of frames to keep must be a non null integer")
- return
+ return ret()
else:
keep_frames_video_source = ""
@@ -419,18 +423,18 @@ def process_prompt_and_add_tasks(state, model_choice):
if "V" in image_prompt_type:
if video_source == None:
gr.Info("You must provide a Source Video file to continue")
- return
+ return ret()
else:
video_source = None
if "A" in audio_prompt_type:
if audio_guide == None:
gr.Info("You must provide an Audio Source")
- return
+ return ret()
if "B" in audio_prompt_type:
if audio_guide2 == None:
gr.Info("You must provide a second Audio Source")
- return
+ return ret()
else:
audio_guide2 = None
else:
@@ -444,23 +448,23 @@ def process_prompt_and_add_tasks(state, model_choice):
if model_def.get("one_image_ref_needed", False):
if image_refs == None :
gr.Info("You must provide an Image Reference")
- return
+ return ret()
if len(image_refs) > 1:
gr.Info("Only one Image Reference (a person) is supported for the moment by this model")
- return
+ return ret()
if model_def.get("at_least_one_image_ref_needed", False):
if image_refs == None :
gr.Info("You must provide at least one Image Reference")
- return
+ return ret()
if "I" in video_prompt_type:
if image_refs == None or len(image_refs) == 0:
gr.Info("You must provide at least one Reference Image")
- return
+ return ret()
image_refs = clean_image_list(image_refs)
if image_refs == None :
gr.Info("A Reference Image should be an Image")
- return
+ return ret()
else:
image_refs = None
@@ -468,35 +472,36 @@ def process_prompt_and_add_tasks(state, model_choice):
if image_outputs:
if image_guide is None:
gr.Info("You must provide a Control Image")
- return
+ return ret()
else:
if video_guide is None:
gr.Info("You must provide a Control Video")
- return
+ return ret()
if "A" in video_prompt_type and not "U" in video_prompt_type:
if image_outputs:
if image_mask is None:
gr.Info("You must provide a Image Mask")
- return
+ return ret()
else:
if video_mask is None:
gr.Info("You must provide a Video Mask")
- return
+ return ret()
else:
video_mask = None
image_mask = None
if "G" in video_prompt_type:
- gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(num_inference_steps * (1. - denoising_strength))} ")
+ if denoising_strength < 1.:
+ gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(round(num_inference_steps * (1. - denoising_strength),4))} ")
else:
denoising_strength = 1.0
if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]:
gr.Info("Keep Frames for Control Video is not supported with LTX Video")
- return
+ return ret()
_, error = parse_keep_frames_video_guide(keep_frames_video_guide, video_length)
if len(error) > 0:
gr.Info(f"Invalid Keep Frames property: {error}")
- return
+ return ret()
else:
video_guide = None
image_guide = None
@@ -516,14 +521,14 @@ def process_prompt_and_add_tasks(state, model_choice):
if "S" in image_prompt_type:
if image_start == None or isinstance(image_start, list) and len(image_start) == 0:
gr.Info("You must provide a Start Image")
- return
+ return ret()
image_start = clean_image_list(image_start)
if image_start == None :
gr.Info("Start Image should be an Image")
- return
+ return ret()
if multi_prompts_gen_type == 1 and len(image_start) > 1:
gr.Info("Only one Start Image is supported")
- return
+ return ret()
else:
image_start = None
@@ -532,19 +537,19 @@ def process_prompt_and_add_tasks(state, model_choice):
if "E" in image_prompt_type:
if image_end == None or isinstance(image_end, list) and len(image_end) == 0:
gr.Info("You must provide an End Image")
- return
+ return ret()
image_end = clean_image_list(image_end)
if image_end == None :
gr.Info("End Image should be an Image")
- return
+ return ret()
if multi_prompts_gen_type == 0:
if video_source is not None:
if len(image_end)> 1:
gr.Info("If a Video is to be continued and the option 'Each Text Prompt Will create a new generated Video' is set, there can be only one End Image")
- return
+ return ret()
elif len(image_start or []) != len(image_end or []):
gr.Info("The number of Start and End Images should be the same when the option 'Each Text Prompt Will create a new generated Video'")
- return
+ return ret()
else:
image_end = None
@@ -553,7 +558,7 @@ def process_prompt_and_add_tasks(state, model_choice):
if video_length > sliding_window_size:
if model_type in ["t2v"] and not "G" in video_prompt_type :
gr.Info(f"You have requested to Generate Sliding Windows with a Text to Video model. Unless you use the Video to Video feature this is useless as a t2v model doesn't see past frames and it will generate the same video in each new window.")
- return
+ return ret()
full_video_length = video_length if video_source is None else video_length + sliding_window_overlap -1
extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation"
no_windows = compute_sliding_window_no(full_video_length, sliding_window_size, sliding_window_discard_last_frames, sliding_window_overlap)
@@ -561,22 +566,22 @@ def process_prompt_and_add_tasks(state, model_choice):
if "recam" in model_filename:
if video_guide == None:
gr.Info("You must provide a Control Video")
- return
+ return ret()
computed_fps = get_computed_fps(force_fps, model_type , video_guide, video_source )
frames = get_resampled_video(video_guide, 0, 81, computed_fps)
if len(frames)<81:
gr.Info(f"Recammaster Control video should be at least 81 frames once the resampling at {computed_fps} fps has been done")
- return
+ return ret()
if "hunyuan_custom_custom_edit" in model_filename:
if len(keep_frames_video_guide) > 0:
gr.Info("Filtering Frames with this model is not supported")
- return
+ return ret()
if inputs["multi_prompts_gen_type"] != 0:
if image_start != None and len(image_start) > 1:
gr.Info("Only one Start Image must be provided if multiple prompts are used for different windows")
- return
+ return ret()
# if image_end != None and len(image_end) > 1:
# gr.Info("Only one End Image must be provided if multiple prompts are used for different windows")
@@ -624,7 +629,7 @@ def process_prompt_and_add_tasks(state, model_choice):
if len(prompts) >= len(image_start):
if len(prompts) % len(image_start) != 0:
gr.Info("If there are more text prompts than input images the number of text prompts should be dividable by the number of images")
- return
+ return ret()
rep = len(prompts) // len(image_start)
new_image_start = []
new_image_end = []
@@ -638,7 +643,7 @@ def process_prompt_and_add_tasks(state, model_choice):
else:
if len(image_start) % len(prompts) !=0:
gr.Info("If there are more input images than text prompts the number of images should be dividable by the number of text prompts")
- return
+ return ret()
rep = len(image_start) // len(prompts)
new_prompts = []
for i, _ in enumerate(image_start):
@@ -666,11 +671,13 @@ def process_prompt_and_add_tasks(state, model_choice):
override_inputs["prompt"] = "\n".join(prompts)
inputs.update(override_inputs)
add_video_task(**inputs)
-
- gen["prompts_max"] = new_prompts_count + gen.get("prompts_max",0)
+ new_prompts_count += gen.get("prompts_max",0)
+ gen["prompts_max"] = new_prompts_count
state["validate_success"] = 1
queue= gen.get("queue", [])
- return update_queue_data(queue)
+ first_time_in_queue = state.get("first_time_in_queue", True)
+ state["first_time_in_queue"] = False
+ return update_queue_data(queue, first_time_in_queue), gr.update(open=True) if new_prompts_count > 1 else gr.update()
def get_preview_images(inputs):
inputs_to_query = ["image_start", "video_source", "image_end", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ]
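
The refactor above rewires process_prompt_and_add_tasks to two Gradio outputs (the queue dataframe and an accordion), so every early exit goes through `ret()`, which returns one `gr.update()` per declared output and leaves both components untouched. A minimal standalone sketch of the pattern, not the actual callback:

```python
import gradio as gr

def validate_and_queue(prompt: str):
    def ret():
        # One no-op update per declared output keeps Gradio's output count
        # consistent without overwriting the components.
        return gr.update(), gr.update()

    if not prompt.strip():
        gr.Info("Prompt cannot be empty.")
        return ret()
    # ... enqueue the task here ...
    return gr.DataFrame(value=[[prompt]]), gr.update(open=True)
```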
@@ -724,7 +731,6 @@ def add_video_task(**inputs):
"start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data] if start_image_data != None else None,
"end_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in end_image_data] if end_image_data != None else None
})
- return update_queue_data(queue)
def update_task_thumbnails(task, inputs):
start_image_data, end_image_data, start_labels, end_labels = get_preview_images(inputs)
@@ -1341,16 +1347,12 @@ def get_queue_table(queue):
"✖"
])
return data
-def update_queue_data(queue):
+def update_queue_data(queue, first_time_in_queue =False):
update_global_queue_ref(queue)
data = get_queue_table(queue)
- if len(data) == 0:
- return gr.DataFrame(value=[], max_height=1)
- elif len(data) == 1:
- return gr.DataFrame(value=data, max_height= 83)
- else:
- return gr.DataFrame(value=data, max_height= 1000)
+ return gr.DataFrame(value=data)
+
def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True):
bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom"
@@ -2004,8 +2006,10 @@ def get_model_recursive_prop(model_type, prop = "URLs", sub_prop_name = None, re
raise Exception(f"Unknown model type '{model_type}'")
-def get_model_filename(model_type, quantization ="int8", dtype_policy = "", module_type = None, submodel_no = 1, stack=[]):
- if module_type is not None:
+def get_model_filename(model_type, quantization ="int8", dtype_policy = "", module_type = None, submodel_no = 1, URLs = None, stack=[]):
+ if URLs is not None:
+ pass
+ elif module_type is not None:
base_model_type = get_base_model_type(model_type)
# model_type_handler = model_types_handlers[base_model_type]
# modules_files = model_type_handler.query_modules_files() if hasattr(model_type_handler, "query_modules_files") else {}
@@ -2032,7 +2036,8 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", modu
if len(stack) > 10: raise Exception(f"Circular Reference in Model {key_name} dependencies: {stack}")
return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, submodel_no = submodel_no, stack = stack + [URLs])
- choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ]
+ # choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ]
+ choices = URLs
if len(quantization) == 0:
quantization = "bf16"
@@ -2042,13 +2047,13 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", modu
raw_filename = choices[0]
else:
if quantization in ("int8", "fp8"):
- sub_choices = [ name for name in choices if quantization in name or quantization.upper() in name]
+ sub_choices = [ name for name in choices if quantization in os.path.basename(name) or quantization.upper() in os.path.basename(name)]
else:
- sub_choices = [ name for name in choices if "quanto" not in name]
+ sub_choices = [ name for name in choices if "quanto" not in os.path.basename(name)]
if len(sub_choices) > 0:
dtype_str = "fp16" if dtype == torch.float16 else "bf16"
- new_sub_choices = [ name for name in sub_choices if dtype_str in name or dtype_str.upper() in name]
+ new_sub_choices = [ name for name in sub_choices if dtype_str in os.path.basename(name) or dtype_str.upper() in os.path.basename(name)]
sub_choices = new_sub_choices if len(new_sub_choices) > 0 else sub_choices
raw_filename = sub_choices[0]
else:
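
The selection now matches quantization and dtype markers against the basename only, so a URL whose directory happens to contain "fp16" or "quanto" no longer confuses the filter. A reduced sketch of that selection logic; the candidate URLs are made up:

```python
import os

def pick_checkpoint(urls, quantization="int8", dtype_str="bf16"):
    # Match markers on the file name only, never on the directory part of the URL.
    names = [os.path.basename(u) for u in urls]
    if quantization in ("int8", "fp8"):
        candidates = [u for u, n in zip(urls, names) if quantization in n or quantization.upper() in n]
    else:
        candidates = [u for u, n in zip(urls, names) if "quanto" not in n]
    preferred = [u for u in candidates if dtype_str in os.path.basename(u) or dtype_str.upper() in os.path.basename(u)]
    return (preferred or candidates or urls)[0]

urls = [
    "https://example.com/fp16_models/wan_t2v_quanto_int8_bf16.safetensors",
    "https://example.com/fp16_models/wan_t2v_bf16.safetensors",
]
assert pick_checkpoint(urls, "int8", "bf16").endswith("quanto_int8_bf16.safetensors")
assert pick_checkpoint(urls, "", "bf16").endswith("wan_t2v_bf16.safetensors")
```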
@@ -2107,6 +2112,10 @@ def fix_settings(model_type, ui_defaults):
audio_prompt_type ="A"
ui_defaults["audio_prompt_type"] = audio_prompt_type
+ if settings_version < 2.35 and any_audio_track(base_model_type):
+ audio_prompt_type = audio_prompt_type or ""
+ audio_prompt_type += "V"
+ ui_defaults["audio_prompt_type"] = audio_prompt_type
video_prompt_type = ui_defaults.get("video_prompt_type", "")
@@ -2156,13 +2165,14 @@ def fix_settings(model_type, ui_defaults):
video_prompt_type = ui_defaults.get("video_prompt_type", "")
image_ref_choices_list = model_def.get("image_ref_choices", {}).get("choices", [])
- if len(image_ref_choices_list)==0:
- video_prompt_type = del_in_sequence(video_prompt_type, "IK")
- else:
- first_choice = image_ref_choices_list[0][1]
- if "I" in first_choice and not "I" in video_prompt_type: video_prompt_type += "I"
- if len(image_ref_choices_list)==1 and "K" in first_choice and not "K" in video_prompt_type: video_prompt_type += "K"
- ui_defaults["video_prompt_type"] = video_prompt_type
+ if model_def.get("guide_custom_choices", None) is None:
+ if len(image_ref_choices_list)==0:
+ video_prompt_type = del_in_sequence(video_prompt_type, "IK")
+ else:
+ first_choice = image_ref_choices_list[0][1]
+ if "I" in first_choice and not "I" in video_prompt_type: video_prompt_type += "I"
+ if len(image_ref_choices_list)==1 and "K" in first_choice and not "K" in video_prompt_type: video_prompt_type += "K"
+ ui_defaults["video_prompt_type"] = video_prompt_type
model_handler = get_model_handler(base_model_type)
if hasattr(model_handler, "fix_settings"):
@@ -2200,7 +2210,8 @@ def get_default_prompt(i2v):
"slg_switch": 0,
"slg_layers": [9],
"slg_start_perc": 10,
- "slg_end_perc": 90
+ "slg_end_perc": 90,
+ "audio_prompt_type": "V",
}
model_handler = get_model_handler(model_type)
model_handler.update_default_settings(base_model_type, model_def, ui_defaults)
@@ -2348,7 +2359,7 @@ def init_model_def(model_type, model_def):
lock_ui_compile = True
-def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_module = False, filter = None, no_fp16_main_model = True ):
+def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_module = False, filter = None, no_fp16_main_model = True, module_source_no = 1):
model_def = get_model_def(model_type)
# To save module and quantized modules
# 1) set Transformer Model Quantization Type to 16 bits
@@ -2359,10 +2370,10 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod
if model_def == None: return
if is_module:
url_key = "modules"
- source_key = "module_source"
+ source_key = "module_source" if module_source_no <=1 else "module_source2"
else:
url_key = "URLs" if submodel_no <=1 else "URLs" + str(submodel_no)
- source_key = "source"
+ source_key = "source" if submodel_no <=1 else "source2"
URLs= model_def.get(url_key, None)
if URLs is None: return
if isinstance(URLs, str):
@@ -2377,6 +2388,9 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod
print("Target Module files are missing")
return
URLs= URLs[0]
+ if isinstance(URLs, dict):
+ url_dict_key = "URLs" if module_source_no ==1 else "URLs2"
+ URLs = URLs[url_dict_key]
for url in URLs:
if "quanto" not in url and dtypestr in url:
model_filename = os.path.basename(url)
@@ -2410,8 +2424,12 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod
elif not os.path.isfile(quanto_filename):
offload.save_model(model, quanto_filename, config_file_path=config_file, do_quantize= True, filter_sd=filter)
print(f"New quantized file '{quanto_filename}' had been created for finetune Id '{model_type}'.")
- model_def[url_key][0].append(quanto_filename)
- saved_finetune_def["model"][url_key][0].append(quanto_filename)
+ if isinstance(model_def[url_key][0],dict):
+ model_def[url_key][0][url_dict_key].append(quanto_filename)
+ saved_finetune_def["model"][url_key][0][url_dict_key].append(quanto_filename)
+ else:
+ model_def[url_key][0].append(quanto_filename)
+ saved_finetune_def["model"][url_key][0].append(quanto_filename)
update_model_def = True
if update_model_def:
with open(finetune_file, "w", encoding="utf-8") as writer:
@@ -2466,6 +2484,14 @@ def preprocessor_wrapper(sd):
return preprocessor_wrapper
+def get_local_model_filename(model_filename):
+ if model_filename.startswith("http"):
+ local_model_filename = os.path.join("ckpts", os.path.basename(model_filename))
+ else:
+ local_model_filename = model_filename
+ return local_model_filename
+
+
def process_files_def(repoId, sourceFolderList, fileList):
targetRoot = "ckpts/"
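
get_local_model_filename maps a checkpoint URL to the path it will occupy under ckpts/, while non-URL entries pass through unchanged; the reworked download_models below then downloads only when that local path is missing. A hedged sketch of the combined flow, assuming ckpts/ exists and using urlretrieve as a stand-in for the repo's download_file helper:

```python
import os
import urllib.request

def ensure_local(model_filename: str) -> str:
    # Resolve the ckpts/ path first, then download only when the file is
    # missing and the entry is actually a URL.
    local = os.path.join("ckpts", os.path.basename(model_filename)) if model_filename.startswith("http") else model_filename
    if os.path.isfile(local):
        return local
    if not model_filename.startswith("http"):
        raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it.")
    try:
        urllib.request.urlretrieve(model_filename, local)  # stand-in for download_file()
    except Exception as e:
        if os.path.isfile(local):
            os.remove(local)
        raise Exception(f"'{model_filename}' is invalid for Model '{local}': {e}")
    return local
```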
@@ -2491,7 +2517,7 @@ def download_mmaudio():
}
process_files_def(**enhancer_def)
-def download_models(model_filename = None, model_type= None, module_type = None, submodel_no = 1):
+def download_models(model_filename = None, model_type= None, module_type = False, submodel_no = 1):
def computeList(filename):
if filename == None:
return []
@@ -2506,11 +2532,12 @@ def computeList(filename):
shared_def = {
"repoId" : "DeepBeepMeep/Wan2.1",
- "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote", "det_align", "" ],
+ "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "roformer", "pyannote", "det_align", "" ],
"fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"],
["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"],
["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
["config.json", "pytorch_model.bin", "preprocessor_config.json"],
+ ["model_bs_roformer_ep_317_sdr_12.9755.ckpt", "model_bs_roformer_ep_317_sdr_12.9755.yaml", "download_checks.json"],
["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", "pytorch_model_segmentation-3.0.bin"], ["detface.pt"], [ "flownet.pkl" ] ]
}
process_files_def(**shared_def)
@@ -2558,37 +2585,25 @@ def download_file(url,filename):
base_model_type = get_base_model_type(model_type)
model_def = get_model_def(model_type)
- source = model_def.get("source", None)
- module_source = model_def.get("module_source", None)
+ any_source = ("source2" if submodel_no ==2 else "source") in model_def
+ any_module_source = ("module_source2" if submodel_no ==2 else "module_source") in model_def
model_type_handler = model_types_handlers[base_model_type]
-
- if source is not None and module_type is None or module_source is not None and module_type is not None:
+ local_model_filename = get_local_model_filename(model_filename)
+
+ if any_source and not module_type or any_module_source and module_type:
model_filename = None
else:
- if not os.path.isfile(model_filename):
- if module_type is not None:
- key_name = "modules"
- URLs = module_type
- if isinstance(module_type, str):
- URLs = get_model_recursive_prop(module_type, key_name, sub_prop_name="_list", return_list= False)
- else:
- key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}"
- URLs = get_model_recursive_prop(model_type, key_name, return_list= False)
- if isinstance(URLs, str):
- raise Exception("Missing model " + URLs)
- use_url = model_filename
- for url in URLs:
- if os.path.basename(model_filename) in url:
- use_url = url
- break
+ if not os.path.isfile(local_model_filename):
+ url = model_filename
+
if not url.startswith("http"):
- raise Exception(f"Model '{model_filename}' in field '{key_name}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.")
+ raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add a URL in the model definition file.")
try:
- download_file(use_url, model_filename)
+ download_file(url, local_model_filename)
except Exception as e:
- if os.path.isfile(model_filename): os.remove(model_filename)
- raise Exception(f"{key_name} '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'")
-
+ if os.path.isfile(local_model_filename): os.remove(local_model_filename)
+ raise Exception(f"'{url}' is invalid for Model '{local_model_filename}': {str(e)}")
+ if module_type: return
model_filename = None
preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True)
@@ -2614,6 +2629,7 @@ def download_file(url,filename):
except Exception as e:
if os.path.isfile(filename): os.remove(filename)
raise Exception(f"Lora URL '{url}' is invalid: {str(e)}'")
+ if module_type: return
model_files = model_type_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization)
if not isinstance(model_files, list): model_files = [model_files]
for one_repo in model_files:
@@ -2800,7 +2816,8 @@ def load_models(model_type, override_profile = -1):
model_filename2 = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy, submodel_no=2) # !!!!
else:
model_filename2 = None
- modules = get_model_recursive_prop(model_type, "modules", return_list= True)
+ modules = get_model_recursive_prop(model_type, "modules", return_list= True)
+ modules = [get_model_recursive_prop(module, "modules", sub_prop_name ="_list", return_list= True) if isinstance(module, str) else module for module in modules ]
if save_quantized and "quanto" in model_filename:
save_quantized = False
print("Need to provide a non quantized model to create a quantized model to be saved")
@@ -2825,27 +2842,40 @@ def load_models(model_type, override_profile = -1):
if model_filename2 != None:
model_file_list += [model_filename2]
model_type_list += [model_type]
- module_type_list += [None]
+ module_type_list += [False]
model_submodel_no_list += [2]
for module_type in modules:
- model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, module_type= module_type))
- model_type_list.append(model_type)
- module_type_list.append(module_type)
- model_submodel_no_list.append(0)
+ if isinstance(module_type,dict):
+ URLs1 = module_type.get("URLs", None)
+ if URLs1 is None: raise Exception(f"No URLs defined for Module {module_type}")
+ model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, URLs = URLs1))
+ URLs2 = module_type.get("URLs2", None)
+ if URLs2 is None: raise Exception(f"No URLs2 defined for Module {module_type}")
+ model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, URLs = URLs2))
+ model_type_list += [model_type] * 2
+ module_type_list += [True] * 2
+ model_submodel_no_list += [1,2]
+ else:
+ model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, module_type= module_type))
+ model_type_list.append(model_type)
+ module_type_list.append(True)
+ model_submodel_no_list.append(0)
+ local_model_file_list= []
for filename, file_model_type, file_module_type, submodel_no in zip(model_file_list, model_type_list, module_type_list, model_submodel_no_list):
download_models(filename, file_model_type, file_module_type, submodel_no)
+ local_model_file_list.append( get_local_model_filename(filename) )
VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float
mixed_precision_transformer = server_config.get("mixed_precision","0") == "1"
transformer_type = None
- for submodel_no, filename in zip(model_submodel_no_list, model_file_list):
- if submodel_no>=1:
+ for module_type, filename in zip(module_type_list, local_model_file_list):
+ if not module_type:
print(f"Loading Model '{filename}' ...")
else:
print(f"Loading Module '{filename}' ...")
wan_model, pipe = model_types_handlers[base_model_type].load_model(
- model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, text_encoder_quantization = text_encoder_quantization,
- dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized)
+ local_model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, text_encoder_quantization = text_encoder_quantization,
+ dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized, submodel_no_list = model_submodel_no_list, )
kwargs = {}
profile = init_pipe(pipe, kwargs, override_profile)
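
Module entries can now be plain module names or dicts carrying two URL lists, one per submodel; the loader expands a dict into two checkpoint files tagged submodel 1 and 2. A reduced sketch of that expansion, where resolve_filename stands in for get_model_filename:

```python
def expand_modules(modules, resolve_filename):
    # resolve_filename stands in for get_model_filename(...); it maps a module
    # name or an explicit URL list to a concrete checkpoint path.
    files, submodel_nos = [], []
    for module in modules:
        if isinstance(module, dict):
            for no, key in ((1, "URLs"), (2, "URLs2")):
                urls = module.get(key)
                if urls is None:
                    raise Exception(f"No {key} defined for Module {module}")
                files.append(resolve_filename(URLs=urls))
                submodel_nos.append(no)
        else:
            files.append(resolve_filename(module_type=module))
            submodel_nos.append(0)
    return files, submodel_nos

# Example with a toy resolver and made-up checkpoint names:
files, nos = expand_modules(
    [{"URLs": ["a_high_noise.safetensors"], "URLs2": ["a_low_noise.safetensors"]}, "vace_14B"],
    lambda **kw: (kw.get("URLs") or [kw.get("module_type")])[0],
)
assert nos == [1, 2, 0]
```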
@@ -2883,7 +2913,7 @@ def generate_header(model_type, compile, attention_mode):
description_container = [""]
get_model_name(model_type, description_container)
- model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) or ""
+ model_filename = os.path.basename(get_model_filename(model_type, transformer_quantization, transformer_dtype_policy)) or ""
description = description_container[0]
header = f"
{description}
"
overridden_attention = get_overridden_attention(model_type)
@@ -3517,6 +3547,48 @@ def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='t
# print(f"frame nos: {frame_nos}")
return frames_list
+# def get_resampled_video(video_in, start_frame, max_frames, target_fps):
+# from torchvision.io import VideoReader
+# import torch
+# from shared.utils.utils import resample
+
+# vr = VideoReader(video_in, "video")
+# meta = vr.get_metadata()["video"]
+
+# fps = round(float(meta["fps"][0]))
+# duration_s = float(meta["duration"][0])
+# num_src_frames = int(round(duration_s * fps)) # robust length estimate
+
+# if max_frames < 0:
+# max_frames = max(int(num_src_frames / fps * target_fps + max_frames), 0)
+
+# frame_nos = resample(
+# fps, num_src_frames,
+# max_target_frames_count=max_frames,
+# target_fps=target_fps,
+# start_target_frame=start_frame
+# )
+# if len(frame_nos) == 0:
+# return torch.empty((0,)) # nothing to return
+
+# target_ts = [i / fps for i in frame_nos]
+
+# # Read forward once, grabbing frames when we pass each target timestamp
+# frames = []
+# vr.seek(target_ts[0])
+# idx = 0
+# tol = 0.5 / fps # half-frame tolerance
+# for frame in vr:
+# t = float(frame["pts"]) # seconds
+# if idx < len(target_ts) and t + tol >= target_ts[idx]:
+# frames.append(frame["data"].permute(1,2,0)) # Tensor [H, W, C]
+# idx += 1
+# if idx >= len(target_ts):
+# break
+
+# return frames
+
+
def get_preprocessor(process_type, inpaint_color):
if process_type=="pose":
from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator
@@ -3605,11 +3677,29 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis
return results
-def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, block_size= 16, expand_scale = 2):
+def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, fit_crop = False, block_size= 16, expand_scale = 2, outpainting_dims = None, inpaint_color = 127):
frame_width, frame_height = input_image.size
+ if fit_crop:
+ input_image = rescale_and_crop(input_image, width, height)
+ if input_mask is not None:
+ input_mask = rescale_and_crop(input_mask, width, height)
+ return input_image, input_mask
+
+ if outpainting_dims != None:
+ if fit_canvas != None:
+ frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims)
+ else:
+ frame_height, frame_width = height, width
+
if fit_canvas != None:
height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size)
+
+ if outpainting_dims != None:
+ final_height, final_width = height, width
+ height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
+
+ if fit_canvas != None or outpainting_dims != None:
input_image = input_image.resize((width, height), resample=Image.Resampling.LANCZOS)
if input_mask is not None:
input_mask = input_mask.resize((width, height), resample=Image.Resampling.LANCZOS)
@@ -3621,10 +3711,23 @@ def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canva
input_mask = np.array(input_mask)
input_mask = op_expand(input_mask, kernel, iterations=3)
input_mask = Image.fromarray(input_mask)
+
+ if outpainting_dims != None:
+ inpaint_color = inpaint_color / 127.5 - 1
+ image = convert_image_to_tensor(input_image)
+ full_frame= torch.full( (image.shape[0], final_height, final_width), inpaint_color, dtype= torch.float, device= image.device)
+ full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = image
+ input_image = convert_tensor_to_image(full_frame)
+
+ if input_mask is not None:
+ mask = convert_image_to_tensor(input_mask)
+ full_frame= torch.full( (mask.shape[0], final_height, final_width), 1, dtype= torch.float, device= mask.device)
+ full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = mask
+ input_mask = convert_tensor_to_image(full_frame)
+
return input_image, input_mask
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
- from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
def mask_to_xyxy_box(mask):
rows, cols = np.where(mask == 255)
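
The outpainting branch added to preprocess_image_with_mask works in three steps: grow the reference dimensions to the full outpainted area, resize the guide into the sub-rectangle returned by get_outpainting_frame_location, then paste it onto a canvas filled with the inpaint colour (mapped from 0-255 into the [-1, 1] tensor range), while the mask canvas is filled with 1 so everything outside the guide is regenerated. A standalone sketch of the paste step, assuming CHW float tensors in [-1, 1]:

```python
import torch

def paste_on_canvas(image: torch.Tensor, final_height: int, final_width: int,
                    margin_top: int, margin_left: int, fill_value: float) -> torch.Tensor:
    # image: CHW tensor already resized to the inner rectangle.
    canvas = torch.full((image.shape[0], final_height, final_width), fill_value,
                        dtype=torch.float, device=image.device)
    _, h, w = image.shape
    canvas[:, margin_top:margin_top + h, margin_left:margin_left + w] = image
    return canvas

guide = torch.zeros(3, 64, 64)                                         # inner guide image
framed = paste_on_canvas(guide, 96, 96, 16, 16, 127 / 127.5 - 1)       # grey border, as in the patch
mask = paste_on_canvas(torch.zeros(1, 64, 64), 96, 96, 16, 16, 1.0)    # 1 = regenerate outside
assert framed.shape == (3, 96, 96) and mask[0, 0, 0].item() == 1.0
```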
@@ -4592,6 +4695,7 @@ def remove_temp_filenames(temp_filenames_list):
reuse_frames = min(sliding_window_size - 4, sliding_window_overlap)
else:
sliding_window = False
+ sliding_window_size = current_video_length
reuse_frames = 0
_, latent_size = get_model_min_frames_and_step(model_type)
@@ -4603,10 +4707,6 @@ def remove_temp_filenames(temp_filenames_list):
# Source Video or Start Image > Control Video > Image Ref (background or positioned frames only) > UI Width, Height
# Image Ref (non background and non positioned frames) are boxed in a white canvas in order to keep their own width/height ratio
frames_to_inject = []
- if image_refs is not None:
- frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions is not None and len(frames_positions)> 0 else []
- frames_positions_list = frames_positions_list[:len(image_refs)]
- nb_frames_positions = len(frames_positions_list)
any_background_ref = 0
if "K" in video_prompt_type:
any_background_ref = 2 if model_def.get("all_image_refs_are_background_ref", False) else 1
@@ -4642,29 +4742,47 @@ def remove_temp_filenames(temp_filenames_list):
output_new_audio_data = None
output_new_audio_filepath = None
original_audio_guide = audio_guide
+ original_audio_guide2 = audio_guide2
audio_proj_split = None
audio_proj_full = None
audio_scale = None
audio_context_lens = None
if (fantasy or multitalk or hunyuan_avatar or hunyuan_custom_audio) and audio_guide != None:
from models.wan.fantasytalking.infer import parse_audio
+ from preprocessing.extract_vocals import get_vocals
import librosa
duration = librosa.get_duration(path=audio_guide)
combination_type = "add"
+ clean_audio_files = "V" in audio_prompt_type
if audio_guide2 is not None:
duration2 = librosa.get_duration(path=audio_guide2)
if "C" in audio_prompt_type: duration += duration2
else: duration = min(duration, duration2)
combination_type = "para" if "P" in audio_prompt_type else "add"
+ if clean_audio_files:
+ audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, audio_guide, "_clean", ".wav"))
+ audio_guide2 = get_vocals(original_audio_guide2, get_available_filename(save_path, audio_guide2, "_clean2", ".wav"))
+ temp_filenames_list += [audio_guide, audio_guide2]
else:
if "X" in audio_prompt_type:
+ # dual speaker, voice separation
from preprocessing.speakers_separator import extract_dual_audio
combination_type = "para"
if args.save_speakers:
audio_guide, audio_guide2 = "speaker1.wav", "speaker2.wav"
else:
audio_guide, audio_guide2 = get_available_filename(save_path, audio_guide, "_tmp1", ".wav"), get_available_filename(save_path, audio_guide, "_tmp2", ".wav")
- extract_dual_audio(original_audio_guide, audio_guide, audio_guide2 )
+ temp_filenames_list += [audio_guide, audio_guide2]
+ if clean_audio_files:
+ clean_audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, original_audio_guide, "_clean", ".wav"))
+ temp_filenames_list += [clean_audio_guide]
+ extract_dual_audio(clean_audio_guide if clean_audio_files else original_audio_guide, audio_guide, audio_guide2)
+
+ elif clean_audio_files:
+ # Single Speaker
+ audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, audio_guide, "_clean", ".wav"))
+ temp_filenames_list += [audio_guide]
+
output_new_audio_filepath = original_audio_guide
current_video_length = min(int(fps * duration //latent_size) * latent_size + latent_size + 1, current_video_length)
@@ -4676,10 +4794,10 @@ def remove_temp_filenames(temp_filenames_list):
# pad audio_proj_full if aligned to beginning of window to simulate source window overlap
min_audio_duration = current_video_length/fps if reset_control_aligment else video_source_duration + current_video_length/fps
audio_proj_full, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = audio_guide, audio_guide2= audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0), min_audio_duration = min_audio_duration)
- if output_new_audio_data is not None: output_new_audio_filepath= None # need to build original speaker track if it changed size (due to padding at the end) or if it has been combined
- if not args.save_speakers and "X" in audio_prompt_type:
- os.remove(audio_guide)
- os.remove(audio_guide2)
+ if output_new_audio_data is not None: # not none if modified
+ if clean_audio_files: # need to rebuild the sum of audios with original audio
+ _, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = original_audio_guide, audio_guide2= original_audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0), min_audio_duration = min_audio_duration, return_sum_only= True)
+ output_new_audio_filepath= None # need to build original speaker track if it changed size (due to padding at the end) or if it has been combined
if hunyuan_custom_edit and video_guide != None:
import cv2
@@ -4687,8 +4805,6 @@ def remove_temp_filenames(temp_filenames_list):
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
current_video_length = min(current_video_length, length)
- if image_guide is not None:
- image_guide, image_mask = preprocess_image_with_mask(image_guide, image_mask, height, width, fit_canvas = None, block_size= block_size, expand_scale = mask_expand)
seed = set_seed(seed)
@@ -4727,7 +4843,7 @@ def remove_temp_filenames(temp_filenames_list):
if repeat_no >= total_generation: break
repeat_no +=1
gen["repeat_no"] = repeat_no
- src_video, src_mask, src_ref_images = None, None, None
+ src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = None
prefix_video = pre_video_frame = None
source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window
source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before )
@@ -4899,7 +5015,7 @@ def remove_temp_filenames(temp_filenames_list):
elif "R" in video_prompt_type: # sparse video to video
src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True)
- src_image, _, _ = calculate_dimensions_and_resize_image(src_image, new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size)
+ src_image, _, _ = calculate_dimensions_and_resize_image(src_image, image_size[0], image_size[1], sample_fit_canvas, fit_crop, block_size = block_size)
refresh_preview["video_guide"] = src_image
src_video = convert_image_to_tensor(src_image).unsqueeze(1)
if sample_fit_canvas != None:
@@ -4918,15 +5034,38 @@ def remove_temp_filenames(temp_filenames_list):
if pre_video_guide != None:
src_video = torch.cat( [pre_video_guide, src_video], dim=1)
elif image_guide is not None:
- image_guide, new_height, new_width = calculate_dimensions_and_resize_image(image_guide, height, width, sample_fit_canvas, fit_crop, block_size = block_size)
- image_size = (new_height, new_width)
- refresh_preview["image_guide"] = image_guide
- sample_fit_canvas = None
- if image_mask is not None:
- image_mask, _, _ = calculate_dimensions_and_resize_image(image_mask, new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size)
- refresh_preview["image_mask"] = image_mask
+ new_image_guide, new_image_mask = preprocess_image_with_mask(image_guide, image_mask, image_size[0], image_size[1], fit_canvas = sample_fit_canvas, fit_crop= fit_crop, block_size= block_size, expand_scale = mask_expand, outpainting_dims=outpainting_dims)
+ if sample_fit_canvas is not None:
+ image_size = (new_image_guide.size[1], new_image_guide.size[0])
+ sample_fit_canvas = None
+ refresh_preview["image_guide"] = new_image_guide
+ if new_image_mask is not None:
+ refresh_preview["image_mask"] = new_image_mask
if window_no == 1 and image_refs is not None and len(image_refs) > 0:
+ if repeat_no == 1:
+ frames_positions_list = []
+ if frames_positions is not None and len(frames_positions)> 0:
+ positions = frames_positions.split(" ")
+ cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) #if reset_control_aligment else 0
+ last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count
+ joker_used = False
+ project_window_no = 1
+ for pos in positions :
+ if len(pos) > 0:
+ if pos in ["L", "l"]:
+ cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
+ if cur_end_pos >= last_frame_no and not joker_used:
+ joker_used = True
+ cur_end_pos = last_frame_no -1
+ project_window_no += 1
+ frames_positions_list.append(cur_end_pos)
+ cur_end_pos -= sliding_window_discard_last_frames + reuse_frames
+ else:
+ frames_positions_list.append(int(pos)-1 + alignment_shift)
+ frames_positions_list = frames_positions_list[:len(image_refs)]
+ nb_frames_positions = len(frames_positions_list)
+
if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) :
from shared.utils.utils import get_outpainting_full_area_dimensions
w, h = image_refs[0].size
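
Frame positions may now contain the letter "L", meaning "last frame of the next window": the parser walks window ends (the first window spans current_video_length, later ones advance by sliding_window_size and then step back by the overlap and discarded frames) and clamps one overshooting "L" to the final frame. A reduced sketch of the walk, with made-up window parameters:

```python
def parse_frame_positions(positions, total_frames, first_window, window_size,
                          reuse_frames=0, discard_last=0):
    # Reduced version of the "L = end of the next window" parsing above;
    # numeric entries are 1-based frame numbers.
    out, cur_end, window_no, joker_used = [], -1, 1, False
    for pos in positions:
        if pos in ("L", "l"):
            cur_end += first_window if window_no == 1 else window_size
            if cur_end >= total_frames and not joker_used:
                joker_used, cur_end = True, total_frames - 1
            window_no += 1
            out.append(cur_end)
            cur_end -= discard_last + reuse_frames
        elif pos:
            out.append(int(pos) - 1)
    return out

# Two sliding windows of 81 frames with a 9-frame overlap, 145 frames requested:
print(parse_frame_positions(["1", "L", "L"], 145, 81, 81, reuse_frames=9))
# -> [0, 80, 144]   (the second "L" overshoots and is clamped to the final frame)
```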
@@ -4963,7 +5102,7 @@ def remove_temp_filenames(temp_filenames_list):
frames_to_inject[pos] = image_refs[i]
if vace :
- frames_to_inject_parsed = frames_to_inject[aligned_guide_start_frame: aligned_guide_end_frame]
+ frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_end_frame]
image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2],
@@ -4971,14 +5110,13 @@ def remove_temp_filenames(temp_filenames_list):
[image_refs_copy] if video_guide_processed2 == None else [image_refs_copy, image_refs_copy],
current_video_length, image_size = image_size, device ="cpu",
keep_video_guide_frames=keep_frames_parsed,
- start_frame = aligned_guide_start_frame,
pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide],
inject_frames= frames_to_inject_parsed,
outpainting_dims = outpainting_dims,
any_background_ref = any_background_ref
)
if len(frames_to_inject_parsed) or any_background_ref:
- new_image_refs = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
+ new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + aligned_guide_start_frame - aligned_window_start_frame) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
if any_background_ref:
new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:]
else:
@@ -5094,8 +5232,9 @@ def set_header_text(txt):
pre_video_frame = pre_video_frame,
original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [],
image_refs_relative_size = image_refs_relative_size,
- image_guide= image_guide,
- image_mask= image_mask,
+ image_guide= new_image_guide,
+ image_mask= new_image_mask,
+ outpainting_dims = outpainting_dims,
)
except Exception as e:
if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0:
@@ -6755,10 +6894,17 @@ def refresh_audio_prompt_type_remux(state, audio_prompt_type, remux):
audio_prompt_type = add_to_sequence(audio_prompt_type, remux)
return audio_prompt_type
+def refresh_remove_background_sound(state, audio_prompt_type, remove_background_sound):
+ audio_prompt_type = del_in_sequence(audio_prompt_type, "V")
+ if remove_background_sound:
+ audio_prompt_type = add_to_sequence(audio_prompt_type, "V")
+ return audio_prompt_type
+
+
def refresh_audio_prompt_type_sources(state, audio_prompt_type, audio_prompt_type_sources):
audio_prompt_type = del_in_sequence(audio_prompt_type, "XCPAB")
audio_prompt_type = add_to_sequence(audio_prompt_type, audio_prompt_type_sources)
- return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in audio_prompt_type))
+ return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in audio_prompt_type)), gr.update(visible= any_letters(audio_prompt_type, "ABX"))
def refresh_image_prompt_type_radio(state, image_prompt_type, image_prompt_type_radio):
image_prompt_type = del_in_sequence(image_prompt_type, "VLTS")
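
audio_prompt_type is a string of single-letter flags; the new refresh_remove_background_sound callback simply deletes "V" and re-adds it when the checkbox is ticked. A tiny sketch of the flag handling it relies on, using simplified stand-ins for del_in_sequence and add_to_sequence:

```python
def del_in_sequence(sequence: str, letters: str) -> str:
    # Drop every occurrence of the given flag letters.
    return "".join(ch for ch in sequence if ch not in letters)

def add_to_sequence(sequence: str, letters: str) -> str:
    # Append flags that are not already present.
    return sequence + "".join(ch for ch in letters if ch not in sequence)

flags = "AX"
flags = del_in_sequence(flags, "V")
flags = add_to_sequence(flags, "V")   # checkbox ticked -> keep vocals only
assert flags == "AXV"
```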
@@ -6775,7 +6921,7 @@ def refresh_image_prompt_type_endcheckbox(state, image_prompt_type, image_prompt
image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio)
return image_prompt_type, gr.update(visible = "E" in image_prompt_type )
-def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs):
+def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs, image_mode):
model_type = state["model_type"]
model_def = get_model_def(model_type)
image_ref_choices = model_def.get("image_ref_choices", None)
@@ -6785,11 +6931,10 @@ def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_
video_prompt_type = del_in_sequence(video_prompt_type, "KFI")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_image_refs)
visible = "I" in video_prompt_type
- vace= test_vace_module(state["model_type"])
-
+ any_outpainting= image_mode in model_def.get("video_guide_outpainting", [])
rm_bg_visible= visible and not model_def.get("no_background_removal", False)
img_rel_size_visible = visible and model_def.get("any_image_refs_relative_size", False)
- return video_prompt_type, gr.update(visible = visible),gr.update(visible = rm_bg_visible), gr.update(visible = img_rel_size_visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace )
+ return video_prompt_type, gr.update(visible = visible),gr.update(visible = rm_bg_visible), gr.update(visible = img_rel_size_visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and any_outpainting )
def switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
if image_mode == 0: return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
@@ -6840,13 +6985,13 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt
visible = "V" in video_prompt_type
model_type = state["model_type"]
model_def = get_model_def(model_type)
- vace= test_vace_module(model_type)
+ any_outpainting= image_mode in model_def.get("video_guide_outpainting", [])
mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
image_outputs = image_mode > 0
keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False)
image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value )
- return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible)
+ return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible)
def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode):
@@ -7112,6 +7257,12 @@ def refresh_video_length_label(state, current_video_length, force_fps, video_gui
computed_fps = get_computed_fps(force_fps, base_model_type , video_guide, video_source )
return gr.update(label= compute_video_length_label(computed_fps, current_video_length))
+def get_default_value(choices, current_value, default_value = None):
+ for label, value in choices:
+ if value == current_value:
+ return current_value
+ return default_value
+
def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None):
global inputs_names #, advanced
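`get_default_value` guards against a stale settings file: the saved value is only kept if it still appears among the dropdown's `(label, value)` pairs, otherwise the declared default wins (see the `model_mode` dropdown further down). A short illustration with made-up choices:

```python
def get_default_value(choices, current_value, default_value=None):
    # Same logic as the helper above: keep current_value only if it is a valid choice.
    for label, value in choices:
        if value == current_value:
            return current_value
    return default_value

choices = [("Animation", 0), ("Replacement", 1)]        # hypothetical model_modes choices
print(get_default_value(choices, 1, default_value=0))   # 1: the saved mode is still offered
print(get_default_value(choices, 7, default_value=0))   # 0: a stale setting falls back to the default
```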
@@ -7277,7 +7428,10 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "")
model_mode_choices = model_def.get("model_modes", None)
- with gr.Column(visible= image_mode_value == 0 and (len(image_prompt_types_allowed)> 0 or model_mode_choices is not None)) as image_prompt_column:
+ model_modes_visibility = [0,1,2]
+ if model_mode_choices is not None: model_modes_visibility= model_mode_choices.get("image_modes", model_modes_visibility)
+
+ with gr.Column(visible= image_mode_value == 0 and len(image_prompt_types_allowed)> 0 or model_mode_choices is not None and image_mode_value in model_modes_visibility ) as image_prompt_column:
# Video Continue / Start Frame / End Frame
image_prompt_type_value= ui_defaults.get("image_prompt_type","")
image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False)
@@ -7293,7 +7447,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
if "L" in image_prompt_types_allowed:
any_video_source = True
image_prompt_type_choices += [("Continue Last Video", "L")]
- with gr.Group(visible= len(image_prompt_types_allowed)>1) as image_prompt_type_group:
+ with gr.Group(visible= len(image_prompt_types_allowed)>1 and image_mode_value == 0) as image_prompt_type_group:
with gr.Row():
image_prompt_type_radio_allowed_values= filter_letters(image_prompt_types_allowed, "SVL")
image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1] if len(image_prompt_type_choices) > 0 else "")
@@ -7309,10 +7463,11 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new Videos in the Generation Queue", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value )
video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),)
image_end_row, image_end, image_end_extra = get_image_gallery(label= get_image_end_label(ui_defaults.get("multi_prompts_gen_type", 0)), value = ui_defaults.get("image_end", None), visible= any_letters(image_prompt_type_value, "SVL") and ("E" in image_prompt_type_value) )
- if model_mode_choices is None:
+ if model_mode_choices is None or image_mode_value not in model_modes_visibility:
model_mode = gr.Dropdown(value=None, visible=False)
else:
- model_mode = gr.Dropdown(choices=model_mode_choices["choices"], value=ui_defaults.get("model_mode", model_mode_choices["default"]), label=model_mode_choices["label"], visible=True)
+ model_mode_value = get_default_value(model_mode_choices["choices"], ui_defaults.get("model_mode", None), model_mode_choices["default"] )
+ model_mode = gr.Dropdown(choices=model_mode_choices["choices"], value=model_mode_value, label=model_mode_choices["label"], visible=True)
keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VL"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" )
any_control_video = any_control_image = False
@@ -7374,9 +7529,11 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
video_prompt_type_video_guide_alt_label = guide_custom_choices.get("label", "Control Video Process")
if image_outputs: video_prompt_type_video_guide_alt_label = video_prompt_type_video_guide_alt_label.replace("Video", "Image")
video_prompt_type_video_guide_alt_choices = [(label.replace("Video", "Image") if image_outputs else label, value) for label,value in guide_custom_choices["choices"] ]
+ guide_custom_choices_value = get_default_value(video_prompt_type_video_guide_alt_choices, filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"]), guide_custom_choices.get("default", "") )
video_prompt_type_video_guide_alt = gr.Dropdown(
choices= video_prompt_type_video_guide_alt_choices,
- value=filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"], guide_custom_choices.get("default", "") ),
+ # value=filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"], guide_custom_choices.get("default", "") ),
+ value=guide_custom_choices_value,
visible = guide_custom_choices.get("visible", True),
label= video_prompt_type_video_guide_alt_label, show_label= guide_custom_choices.get("show_label", True), scale = 2
)
@@ -7437,7 +7594,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
label=image_ref_choices.get("label", "Ref. Images Type"), show_label= True, scale = 2
)
- image_guide = gr.Image(label= "Control Image", height = 800, width=800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None))
+ image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None))
video_guide = gr.Video(label= "Control Video", height = gallery_height, visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None))
if image_mode_value >= 1:
image_guide_value = ui_defaults.get("image_guide", None)
@@ -7457,7 +7614,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
layers=False,
brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"),
# fixed_canvas= True,
- width=800,
+ # width=800,
height=800,
# transforms=None,
# interactive=True,
@@ -7472,12 +7629,12 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label=f"Denoising Strength (the Lower the Closer to the Control {'Image' if image_outputs else 'Video'})", visible = "G" in video_prompt_type_value, show_reset_button= False)
keep_frames_video_guide_visible = not image_outputs and "V" in video_prompt_type_value and not model_def.get("keep_frames_video_guide_not_supported", False)
keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= keep_frames_video_guide_visible , scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last
-
- with gr.Column(visible= ("V" in video_prompt_type_value or "K" in video_prompt_type_value or "F" in video_prompt_type_value) and vace) as video_guide_outpainting_col:
+ video_guide_outpainting_modes = model_def.get("video_guide_outpainting", [])
+ with gr.Column(visible= ("V" in video_prompt_type_value or "K" in video_prompt_type_value or "F" in video_prompt_type_value) and image_mode_value in video_guide_outpainting_modes) as video_guide_outpainting_col:
video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
with gr.Group():
- video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
+ video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
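The outpainting settings travel through the UI as one compact string: a leading `#` marks the feature as disabled, otherwise four space-separated integers are stored (the four sliders of this group; their exact order is not visible in this hunk). A hypothetical encode/decode pair that matches the parsing above:

```python
def decode_outpainting(value: str):
    # "#" or "#<digits>" means disabled; otherwise four space-separated integers.
    enabled = len(value) > 0 and not value.startswith("#")
    value = value[1:] if value.startswith("#") else value
    dims = [0] * 4 if len(value) == 0 else [int(v) for v in value.split(" ")]
    return enabled, dims

def encode_outpainting(enabled: bool, dims) -> str:
    text = " ".join(str(d) for d in dims)
    return text if enabled else "#" + text

print(decode_outpainting("#"))                    # (False, [0, 0, 0, 0])
print(decode_outpainting("10 0 20 0"))            # (True, [10, 0, 20, 0])
print(encode_outpainting(False, [10, 0, 20, 0]))  # "#10 0 20 0"
```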
@@ -7495,7 +7652,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "")
image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode)
- frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" )
+            frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames (1=first, L=last of a window, no position for other Image Refs)" )
image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Internally Rescale Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs)
no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None
@@ -7522,7 +7679,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
speaker_choices=[("None", "")]
if any_single_speaker: speaker_choices += [("One Person Speaking Only", "A")]
if any_multi_speakers:speaker_choices += [
- ("Two speakers, Auto Separation of Speakers (will work only if there is little background noise)", "XA"),
+ ("Two speakers, Auto Separation of Speakers (will work only if Voices are distinct)", "XA"),
("Two speakers, Speakers Audio sources are assumed to be played in a Row", "CAB"),
("Two speakers, Speakers Audio sources are assumed to be played in Parallel", "PAB")
]
@@ -7537,6 +7694,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
with gr.Row(visible = any_audio_voices_support and not image_outputs) as audio_guide_row:
audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= any_audio_voices_support and "A" in audio_prompt_type_value )
audio_guide2 = gr.Audio(value= ui_defaults.get("audio_guide2", None), type="filepath", label="Voice to follow #2", show_download_button= True, visible= any_audio_voices_support and "B" in audio_prompt_type_value )
+ remove_background_sound = gr.Checkbox(label="Video Motion ignores Background Music (to get a better LipSync)", value="V" in audio_prompt_type_value, visible = any_audio_voices_support and any_letters(audio_prompt_type_value, "ABX") and not image_outputs)
with gr.Row(visible = any_audio_voices_support and ("B" in audio_prompt_type_value or "X" in audio_prompt_type_value) and not image_outputs ) as speakers_locations_row:
speakers_locations = gr.Text( ui_defaults.get("speakers_locations", "0:45 55:100"), label="Speakers Locations separated by a Space. Each Location = Left:Right or a BBox Left:Top:Right:Bottom", visible= True)
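The `speakers_locations` field packs one location per speaker into a single space-separated string, each token being either `Left:Right` percentages or a full `Left:Top:Right:Bottom` box, as the label says. A hypothetical parser for that format (the actual parsing inside wgp.py may differ):

```python
def parse_speakers_locations(text: str):
    locations = []
    for token in text.split():
        parts = [int(p) for p in token.split(":")]
        if len(parts) == 2:    # Left:Right strip
            locations.append({"left": parts[0], "right": parts[1]})
        elif len(parts) == 4:  # Left:Top:Right:Bottom bounding box
            locations.append(dict(zip(("left", "top", "right", "bottom"), parts)))
        else:
            raise ValueError(f"unexpected location token: {token!r}")
    return locations

print(parse_speakers_locations("0:45 55:100"))
# [{'left': 0, 'right': 45}, {'left': 55, 'right': 100}]
```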
@@ -7916,7 +8074,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
sliding_window_defaults = model_def.get("sliding_window_defaults", {})
sliding_window_size = gr.Slider(5, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size")
sliding_window_overlap = gr.Slider(sliding_window_defaults.get("overlap_min", 1), sliding_window_defaults.get("overlap_max", 97), value=ui_defaults.get("sliding_window_overlap",sliding_window_defaults.get("overlap_default", 5)), step=sliding_window_defaults.get("overlap_step", 4), label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)")
- sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",1), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)", visible = True)
+ sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",0), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)", visible = True)
sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20 if vace else 0), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = vace)
sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True)
@@ -7926,7 +8084,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
("Aligned to the beginning of the First Window of the new Video Sample", "T"),
],
value=filter_letters(video_prompt_type_value, "T"),
- label="Control Video / Control Audio temporal alignment when any Source Video",
+        label="Control Video / Injected Frames / Control Audio temporal alignment when continuing a Video",
visible = vace or ltxv or t2v or infinitetalk
)
@@ -8112,7 +8270,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row,
video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row,
video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn,
- NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, guide_selection_row, image_mode_tabs,
+ NAG_col, remove_background_sound , speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, guide_selection_row, image_mode_tabs,
min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn, tab_inpaint, tab_t2v] + image_start_extra + image_end_extra + image_refs_extra # presets_column,
if update_form:
locals_dict = locals()
@@ -8131,11 +8289,11 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
gr.on(triggers=[video_length.release, force_fps.change, video_guide.change, video_source.change], fn=refresh_video_length_label, inputs=[state, video_length, force_fps, video_guide, video_source] , outputs = video_length, trigger_mode="always_last", show_progress="hidden" )
guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ])
audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type])
- audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row])
+ remove_background_sound.change(fn=refresh_remove_background_sound, inputs=[state, audio_prompt_type, remove_background_sound], outputs=[audio_prompt_type])
+ audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row, remove_background_sound])
image_prompt_type_radio.change(fn=refresh_image_prompt_type_radio, inputs=[state, image_prompt_type, image_prompt_type_radio], outputs=[image_prompt_type, image_start_row, image_end_row, video_source, keep_frames_video_source, image_prompt_type_endcheckbox], show_progress="hidden" )
image_prompt_type_endcheckbox.change(fn=refresh_image_prompt_type_endcheckbox, inputs=[state, image_prompt_type, image_prompt_type_radio, image_prompt_type_endcheckbox], outputs=[image_prompt_type, image_end_row] )
- # video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand])
- video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col], show_progress="hidden")
+ video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs,image_mode], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col], show_progress="hidden")
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand], show_progress="hidden")
video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode], outputs = [video_prompt_type, video_guide, image_guide, image_refs_row, denoising_strength ], show_progress="hidden")
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_mask, image_mask_guide, image_guide, image_mask, mask_expand], show_progress="hidden")
@@ -8338,7 +8496,7 @@ def activate_status(state):
outputs= None
).then(fn=process_prompt_and_add_tasks,
inputs = [state, model_choice],
- outputs= queue_df,
+ outputs= [queue_df, queue_accordion],
show_progress="hidden",
).then(fn=prepare_generate_video,
inputs= [state],
@@ -8346,11 +8504,6 @@ def activate_status(state):
).then(fn=activate_status,
inputs= [state],
outputs= [status_trigger],
- ).then(
- fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(),
- inputs=[state],
- outputs=[queue_accordion],
- show_progress="hidden",
).then(fn=process_tasks,
inputs= [state],
outputs= [preview_trigger, output_trigger],
@@ -8468,12 +8621,7 @@ def activate_status(state):
outputs= None
).then(fn=process_prompt_and_add_tasks,
inputs = [state, model_choice],
- outputs=queue_df,
- show_progress="hidden",
- ).then(
- fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(),
- inputs=[state],
- outputs=[queue_accordion],
+ outputs=[queue_df, queue_accordion],
show_progress="hidden",
).then(
fn=update_status,
From ae93d71217c080c73c17a7ee5fd04f5d5dc65d93 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 16 Sep 2025 22:38:45 +0200
Subject: [PATCH 027/155] fixed simple annoying bugs
---
README.md | 6 +++---
models/qwen/qwen_main.py | 1 +
models/wan/any2video.py | 2 +-
wgp.py | 2 +-
4 files changed, 6 insertions(+), 5 deletions(-)
diff --git a/README.md b/README.md
index 2411cea3d..8c5e4367a 100644
--- a/README.md
+++ b/README.md
@@ -24,12 +24,12 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
- The long awaited **Vace for Wan 2.2** is at last here or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fan Cocktail**)
-- **First Frame / Last Frame for Vace** : Vace model are so powerful that they could do *First frame / Last frame* since day one using the *Injected Frames* feature. However this required to compute by hand the locations of each end frame since this feature expects frames positions. I made it easier to compute these locations by using the "L" alias :
+- **First Frame / Last Frame for Vace**: Vace models are so powerful that they could do *First frame / Last frame* since day one using the *Injected Frames* feature. However, this required computing the locations of each end frame by hand, since this feature expects frame positions. I made it easier to compute these locations with the "L" alias:
For a video Gen from scratch *"1 L L L"* means the 4 Injected Frames will be injected like this: frame no 1 at the first position, the next frame at the end of the first window, then the following frame at the end of the next window, and so on ....
-If you *Continue a Video* , you just need *"L L L"* since the the first frame is the last frame of the *Source Video*. In any case remember that numeral frames positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab/ Control Video, Injected Frames aligment*.
+If you *Continue a Video*, you just need *"L L L"* since the first frame is the last frame of the *Source Video*. In any case remember that numeric frame positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab / Control Video, Injected Frames alignment*.
-- **Qwen Inpainting** exist now in two versions: the original version of the previous release and a Lora based version. Each version has its pros and cons. For instance the Lora version supports also **Outpainting** ! However it tends to change slightly the original image even outside the outpainted area.
+- **Qwen Edit Inpainting** now exists in two versions: the original version of the previous release and a Lora-based version. Each version has its pros and cons. For instance the Lora version also supports **Outpainting**! However, it tends to slightly change the original image even outside the outpainted area.
- **Better Lipsync with all the Audio to Video models**: you probably noticed that *Multitalk*, *InfiniteTalk* or *Hunyuan Avatar* had so-so lipsync when the audio provided contained some background music. The problem should be solved now thanks to an automated background music removal done entirely by AI. Don't worry, you will still hear the music as it is added back in the generated Video.
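A hypothetical sketch of how the `L` alias described in the *First Frame / Last Frame for Vace* note above could resolve to absolute frame positions once the window end frames are known; the real resolution happens inside wgp.py and also honours the alignment setting mentioned in the note:

```python
def resolve_frame_positions(spec: str, window_end_frames):
    """Resolve a frames_positions spec such as "1 L L L".

    Numeric tokens are kept as absolute (1-based) frame numbers; each "L"
    consumes the end frame of the next sliding window. Purely illustrative.
    """
    window_ends = iter(window_end_frames)
    positions = []
    for token in spec.split():
        if token.upper() == "L":
            positions.append(next(window_ends))
        else:
            positions.append(int(token))
    return positions

# Three windows ending at (hypothetical) frames 81, 145 and 209:
print(resolve_frame_positions("1 L L L", [81, 145, 209]))   # [1, 81, 145, 209]
```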
diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py
index da84e0d0e..cec0b8e05 100644
--- a/models/qwen/qwen_main.py
+++ b/models/qwen/qwen_main.py
@@ -217,6 +217,7 @@ def generate(
def get_loras_transformer(self, get_model_recursive_prop, model_type, model_mode, **kwargs):
if model_mode == 0: return [], []
preloadURLs = get_model_recursive_prop(model_type, "preload_URLs")
+ if len(preloadURLs) == 0: return [], []
return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1]
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index dde1a6535..3e85f7963 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -146,7 +146,7 @@ def __init__(
from mmgp.safetensors2 import torch_load_file
else:
if self.transformer_switch:
- if 0 in submodel_no_list[2:] and 1 in submodel_no_list:
+ if 0 in submodel_no_list[2:] and 1 in submodel_no_list[2:]:
raise Exception("Shared and non shared modules at the same time across multipe models is not supported")
if 0 in submodel_no_list[2:]:
diff --git a/wgp.py b/wgp.py
index de8e2f66f..ee26376b4 100644
--- a/wgp.py
+++ b/wgp.py
@@ -62,7 +62,7 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.6"
+WanGP_version = "8.61"
settings_version = 2.35
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
From fc615ffb3cb063bd4579ee98b2f81e7ca6633971 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 16 Sep 2025 23:01:54 +0200
Subject: [PATCH 028/155] fixed simple annoying bugs
---
wgp.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/wgp.py b/wgp.py
index ee26376b4..ee6b5f62e 100644
--- a/wgp.py
+++ b/wgp.py
@@ -356,7 +356,7 @@ def ret():
outpainting_dims = get_outpainting_dims(video_guide_outpainting)
- if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None:
+ if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"):
gr.Info("Output Resolution Cropping will be not used for this Generation as it is not compatible with Video Outpainting")
if len(loras_multipliers) > 0:
From 84010bd861545fd2d8365aaece6b4017d40da9aa Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Mon, 22 Sep 2025 17:11:25 +0200
Subject: [PATCH 029/155] commit in case there is an unrecoverable code
 hemorrhage
---
configs/animate.json | 15 +
configs/lucy_edit.json | 14 +
defaults/animate.json | 13 +
defaults/lucy_edit.json | 18 ++
defaults/ti2v_2_2_fastwan.json | 1 +
models/qwen/pipeline_qwenimage.py | 146 +++++----
models/wan/any2video.py | 193 +++++------
models/wan/modules/model.py | 53 ++-
models/wan/wan_handler.py | 81 +++--
shared/convert/convert_diffusers_to_flux.py | 342 ++++++++++++++++++++
shared/inpainting/__init__.py | 0
shared/inpainting/lanpaint.py | 240 ++++++++++++++
shared/inpainting/utils.py | 301 +++++++++++++++++
shared/utils/audio_video.py | 3 +
shared/utils/download.py | 110 +++++++
shared/utils/utils.py | 246 ++++++--------
wgp.py | 230 +++++++++----
17 files changed, 1623 insertions(+), 383 deletions(-)
create mode 100644 configs/animate.json
create mode 100644 configs/lucy_edit.json
create mode 100644 defaults/animate.json
create mode 100644 defaults/lucy_edit.json
create mode 100644 shared/convert/convert_diffusers_to_flux.py
create mode 100644 shared/inpainting/__init__.py
create mode 100644 shared/inpainting/lanpaint.py
create mode 100644 shared/inpainting/utils.py
create mode 100644 shared/utils/download.py
diff --git a/configs/animate.json b/configs/animate.json
new file mode 100644
index 000000000..7e98ca95e
--- /dev/null
+++ b/configs/animate.json
@@ -0,0 +1,15 @@
+{
+ "_class_name": "WanModel",
+ "_diffusers_version": "0.30.0",
+ "dim": 5120,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_dim": 36,
+ "model_type": "i2v",
+ "num_heads": 40,
+ "num_layers": 40,
+ "out_dim": 16,
+ "text_len": 512,
+ "motion_encoder_dim": 512
+}
\ No newline at end of file
diff --git a/configs/lucy_edit.json b/configs/lucy_edit.json
new file mode 100644
index 000000000..4983cedd1
--- /dev/null
+++ b/configs/lucy_edit.json
@@ -0,0 +1,14 @@
+{
+ "_class_name": "WanModel",
+ "_diffusers_version": "0.33.0",
+ "dim": 3072,
+ "eps": 1e-06,
+ "ffn_dim": 14336,
+ "freq_dim": 256,
+ "in_dim": 96,
+ "model_type": "ti2v2_2",
+ "num_heads": 24,
+ "num_layers": 30,
+ "out_dim": 48,
+ "text_len": 512
+}
diff --git a/defaults/animate.json b/defaults/animate.json
new file mode 100644
index 000000000..fee26d76a
--- /dev/null
+++ b/defaults/animate.json
@@ -0,0 +1,13 @@
+{
+ "model": {
+ "name": "Wan2.2 Animate",
+ "architecture": "animate",
+ "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'animation' or 'replacement' mode.",
+ "URLs": [
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors"
+ ],
+ "group": "wan2_2"
+ }
+}
\ No newline at end of file
diff --git a/defaults/lucy_edit.json b/defaults/lucy_edit.json
new file mode 100644
index 000000000..6344dfffd
--- /dev/null
+++ b/defaults/lucy_edit.json
@@ -0,0 +1,18 @@
+{
+ "model": {
+ "name": "Wan2.2 Lucy Edit 5B",
+ "architecture": "lucy_edit",
+ "description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
+ "URLs": [
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors"
+ ],
+ "group": "wan2_2"
+ },
+ "video_length": 81,
+ "guidance_scale": 5,
+ "flow_shift": 5,
+ "num_inference_steps": 30,
+ "resolution": "1280x720"
+}
\ No newline at end of file
diff --git a/defaults/ti2v_2_2_fastwan.json b/defaults/ti2v_2_2_fastwan.json
index 064c2b4ba..fa69f82ec 100644
--- a/defaults/ti2v_2_2_fastwan.json
+++ b/defaults/ti2v_2_2_fastwan.json
@@ -7,6 +7,7 @@
"loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"],
"group": "wan2_2"
},
+ "prompt" : "Put the person into a clown outfit.",
"video_length": 121,
"guidance_scale": 1,
"flow_shift": 3,
diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py
index be982aabc..14728862d 100644
--- a/models/qwen/pipeline_qwenimage.py
+++ b/models/qwen/pipeline_qwenimage.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from mmgp import offload
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
@@ -387,7 +386,8 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
return latent_image_ids.to(device=device, dtype=dtype)
@staticmethod
- def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ def _pack_latents(latents):
+ batch_size, num_channels_latents, _, height, width = latents.shape
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
@@ -479,7 +479,7 @@ def prepare_latents(
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
- shape = (batch_size, 1, num_channels_latents, height, width)
+ shape = (batch_size, num_channels_latents, 1, height, width)
image_latents = None
if image is not None:
@@ -499,10 +499,7 @@ def prepare_latents(
else:
image_latents = torch.cat([image_latents], dim=0)
- image_latent_height, image_latent_width = image_latents.shape[3:]
- image_latents = self._pack_latents(
- image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
- )
+ image_latents = self._pack_latents(image_latents)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -511,7 +508,7 @@ def prepare_latents(
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents)
else:
latents = latents.to(device=device, dtype=dtype)
@@ -713,11 +710,12 @@ def __call__(
image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of)
# image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
height, width = image_height, image_width
- image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS))
+ image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 8, height // 8), resample=Image.Resampling.LANCZOS))
image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
- image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0)
+ image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0)
# convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
- image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device)
+ image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1)
+ image_mask_latents = self._pack_latents(image_mask_latents)
prompt_image = image
if image.size != (image_width, image_height):
@@ -822,6 +820,7 @@ def __call__(
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
)
morph, first_step = False, 0
+ lanpaint_proc = None
if image_mask_latents is not None:
randn = torch.randn_like(original_image_latents)
if denoising_strength < 1.:
@@ -833,7 +832,8 @@ def __call__(
timesteps = timesteps[first_step:]
self.scheduler.timesteps = timesteps
self.scheduler.sigmas= self.scheduler.sigmas[first_step:]
-
+ # from shared.inpainting.lanpaint import LanPaint
+ # lanpaint_proc = LanPaint()
# 6. Denoising loop
self.scheduler.set_begin_index(0)
updated_num_steps= len(timesteps)
@@ -847,48 +847,52 @@ def __call__(
offload.set_step_no_for_lora(self.transformer, first_step + i)
if self.interrupt:
continue
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph:
latent_noise_factor = t/1000
latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor
- self._current_timestep = t
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
- timestep = t.expand(latents.shape[0]).to(latents.dtype)
- latent_model_input = latents
- if image_latents is not None:
- latent_model_input = torch.cat([latents, image_latents], dim=1)
-
- if do_true_cfg and joint_pass:
- noise_pred, neg_noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep / 1000,
- guidance=guidance,
- encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask],
- encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
- img_shapes=img_shapes,
- txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens],
- attention_kwargs=self.attention_kwargs,
- **kwargs
- )
- if noise_pred == None: return None
- noise_pred = noise_pred[:, : latents.size(1)]
- neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
- else:
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep / 1000,
- guidance=guidance,
- encoder_hidden_states_mask_list=[prompt_embeds_mask],
- encoder_hidden_states_list=[prompt_embeds],
- img_shapes=img_shapes,
- txt_seq_lens_list=[txt_seq_lens],
- attention_kwargs=self.attention_kwargs,
- **kwargs
- )[0]
- if noise_pred == None: return None
- noise_pred = noise_pred[:, : latents.size(1)]
+ latents_dtype = latents.dtype
+
+ # latent_model_input = latents
+ def denoise(latent_model_input, true_cfg_scale):
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+ do_true_cfg = true_cfg_scale > 1
+ if do_true_cfg and joint_pass:
+ noise_pred, neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance, #!!!!
+ encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask],
+ encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds],
+ img_shapes=img_shapes,
+ txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens],
+ attention_kwargs=self.attention_kwargs,
+ **kwargs
+ )
+ if noise_pred == None: return None, None
+ noise_pred = noise_pred[:, : latents.size(1)]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ else:
+ neg_noise_pred = None
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask_list=[prompt_embeds_mask],
+ encoder_hidden_states_list=[prompt_embeds],
+ img_shapes=img_shapes,
+ txt_seq_lens_list=[txt_seq_lens],
+ attention_kwargs=self.attention_kwargs,
+ **kwargs
+ )[0]
+ if noise_pred == None: return None, None
+ noise_pred = noise_pred[:, : latents.size(1)]
if do_true_cfg:
neg_noise_pred = self.transformer(
@@ -902,27 +906,43 @@ def __call__(
attention_kwargs=self.attention_kwargs,
**kwargs
)[0]
- if neg_noise_pred == None: return None
+ if neg_noise_pred == None: return None, None
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ return noise_pred, neg_noise_pred
+ def cfg_predictions( noise_pred, neg_noise_pred, guidance, t):
+ if do_true_cfg:
+ comb_pred = neg_noise_pred + guidance * (noise_pred - neg_noise_pred)
+ if comb_pred == None: return None
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
- if do_true_cfg:
- comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
- if comb_pred == None: return None
+ return noise_pred
- cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
- noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
- noise_pred = comb_pred * (cond_norm / noise_norm)
- neg_noise_pred = None
+
+ if lanpaint_proc is not None and i<=3:
+ latents = lanpaint_proc(denoise, cfg_predictions, true_cfg_scale, 1., latents, original_image_latents, randn, t/1000, image_mask_latents, height=height , width= width, vae_scale_factor= 8)
+ if latents is None: return None
+
+ noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
+ if noise_pred == None: return None
+ noise_pred = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
+ neg_noise_pred = None
# compute the previous noisy sample x_t -> x_t-1
- latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+ noise_pred = None
+
if image_mask_latents is not None:
- next_t = timesteps[i+1] if i0:
+ msk[:, :nb_frames_unchanged] = 1
+ msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+ msk = msk.transpose(1,2)[0]
+ return msk
+
def generate(self,
input_prompt,
input_frames= None,
input_masks = None,
- input_ref_images = None,
+ input_ref_images = None,
+ input_ref_masks = None,
+ input_faces = None,
input_video = None,
image_start = None,
image_end = None,
@@ -541,14 +506,18 @@ def generate(self,
infinitetalk = model_type in ["infinitetalk"]
standin = model_type in ["standin", "vace_standin_14B"]
recam = model_type in ["recam_1.3B"]
- ti2v = model_type in ["ti2v_2_2"]
+ ti2v = model_type in ["ti2v_2_2", "lucy_edit"]
+ lucy_edit= model_type in ["lucy_edit"]
+ animate= model_type in ["animate"]
start_step_no = 0
ref_images_count = 0
trim_frames = 0
- extended_overlapped_latents = None
+ extended_overlapped_latents = clip_image_start = clip_image_end = None
no_noise_latents_injection = infinitetalk
timestep_injection = False
lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
+ extended_input_dim = 0
+ ref_images_before = False
# image2video
if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]:
any_end_frame = False
@@ -598,17 +567,7 @@ def generate(self,
if image_end is not None:
img_end_frame = image_end.unsqueeze(1).to(self.device)
-
- if hasattr(self, "clip"):
- clip_image_size = self.clip.model.image_size
- image_start = resize_lanczos(image_start, clip_image_size, clip_image_size)
- image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) if image_end is not None else image_start
- if model_type == "flf2v_720p":
- clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]])
- else:
- clip_context = self.clip.visual([image_start[:, None, :, :]])
- else:
- clip_context = None
+ clip_image_start, clip_image_end = image_start, image_end
if any_end_frame:
enc= torch.concat([
@@ -647,21 +606,62 @@ def generate(self,
if infinitetalk:
lat_y = self.vae.encode([input_video], VAE_tile_size)[0]
extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0)
- # if control_pre_frames_count != pre_frames_count:
lat_y = input_video = None
kwargs.update({ 'y': y})
- if not clip_context is None:
- kwargs.update({'clip_fea': clip_context})
- # Recam Master
- if recam:
- target_camera = model_mode
- height,width = input_frames.shape[-2:]
- input_frames = input_frames.to(dtype=self.dtype , device=self.device)
- source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
+ # Animate
+ if animate:
+ pose_pixels = input_frames * input_masks
+ input_masks = 1. - input_masks
+ pose_pixels -= input_masks
+ save_video(pose_pixels, "pose.mp4")
+ pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
+ input_frames = input_frames * input_masks
+            if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should be black (-1) in background frames
+ if prefix_frames_count > 0:
+ input_frames[:, :prefix_frames_count] = input_video
+ input_masks[:, :prefix_frames_count] = 1
+ save_video(input_frames, "input_frames.mp4")
+ save_video(input_masks, "input_masks.mp4", value_range=(0,1))
+ lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
+ msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device)
+ msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device)
+ msk = torch.concat([msk_ref, msk_control], dim=1)
+ clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device)
+ lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1)
+ y = torch.concat([msk, lat_y])
+ kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
+ lat_y = msk = msk_control = msk_ref = pose_pixels = None
+ ref_images_before = True
+ ref_images_count = 1
+ lat_frames = int((input_frames.shape[1] - 1) // self.vae_stride[0]) + 1
+
+ # Clip image
+ if hasattr(self, "clip") and clip_image_start is not None:
+ clip_image_size = self.clip.model.image_size
+ clip_image_start = resize_lanczos(clip_image_start, clip_image_size, clip_image_size)
+ clip_image_end = resize_lanczos(clip_image_end, clip_image_size, clip_image_size) if clip_image_end is not None else clip_image_start
+ if model_type == "flf2v_720p":
+ clip_context = self.clip.visual([clip_image_start[:, None, :, :], clip_image_end[:, None, :, :] if clip_image_end is not None else clip_image_start[:, None, :, :]])
+ else:
+ clip_context = self.clip.visual([clip_image_start[:, None, :, :]])
+ clip_image_start = clip_image_end = None
+ kwargs.update({'clip_fea': clip_context})
+
+ # Recam Master & Lucy Edit
+ if recam or lucy_edit:
+ frame_num, height,width = input_frames.shape[-3:]
+ lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
+ frame_num = (lat_frames -1) * self.vae_stride[0] + 1
+ input_frames = input_frames[:, :frame_num].to(dtype=self.dtype , device=self.device)
+ extended_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device)
+ extended_input_dim = 2 if recam else 1
del input_frames
+
+ if recam:
# Process target camera (recammaster)
+ target_camera = model_mode
from shared.utils.cammmaster_tools import get_camera_embedding
cam_emb = get_camera_embedding(target_camera)
cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
@@ -715,6 +715,8 @@ def generate(self,
height, width = input_video.shape[-2:]
source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0)
timestep_injection = True
+ if extended_input_dim > 0:
+ extended_latents[:, :, :source_latents.shape[2]] = source_latents
# Vace
if vace :
@@ -722,6 +724,7 @@ def generate(self,
input_frames = [u.to(self.device) for u in input_frames]
input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
input_masks = [u.to(self.device) for u in input_masks]
+ ref_images_before = True
if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
m0 = self.vace_encode_masks(input_masks, input_ref_images)
@@ -771,9 +774,9 @@ def generate(self,
expand_shape = [batch_size] + [-1] * len(target_shape)
# Ropes
- if target_camera != None:
+ if extended_input_dim>=2:
shape = list(target_shape[1:])
- shape[0] *= 2
+ shape[extended_input_dim-2] *= 2
freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
else:
freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx)
@@ -901,8 +904,8 @@ def clear():
for zz in z:
zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor
- if target_camera != None:
- latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2)
+ if extended_input_dim > 0:
+ latent_model_input = torch.cat([latents, extended_latents.expand(*expand_shape)], dim=extended_input_dim)
else:
latent_model_input = latents
@@ -1030,7 +1033,7 @@ def clear():
if callback is not None:
latents_preview = latents
- if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
+ if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames]
if image_outputs: latents_preview= latents_preview[:, :,:1]
if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2)
@@ -1041,7 +1044,7 @@ def clear():
if timestep_injection:
latents[:, :, :source_latents.shape[2]] = source_latents
- if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
+ if ref_images_before and ref_images_count > 0: latents = latents[:, :, ref_images_count:]
if trim_frames > 0: latents= latents[:, :,:-trim_frames]
if return_latent_slice != None:
latent_slice = latents[:, :, return_latent_slice].clone()
@@ -1078,4 +1081,12 @@ def adapt_vace_model(self, model):
delattr(model, "vace_blocks")
+ def adapt_animate_model(self, model):
+ modules_dict= { k: m for k, m in model.named_modules()}
+ for animate_layer in range(8):
+ module = modules_dict[f"face_adapter.fuser_blocks.{animate_layer}"]
+ model_layer = animate_layer * 5
+ target = modules_dict[f"blocks.{model_layer}"]
+ setattr(target, "face_adapter_fuser_blocks", module )
+ delattr(model, "face_adapter")
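`adapt_animate_model` re-parents each face-adapter fuser block onto every fifth transformer block, which is what lets the block forward above call `self.face_adapter_fuser_blocks` directly when `block_no % 5 == 0`. A toy sketch of the same re-parenting pattern with stand-in modules (the classes here are placeholders, not the real `FaceAdapter`):

```python
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, num_blocks=40, num_adapter_layers=8):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Identity() for _ in range(num_blocks))
        # Stand-in for FaceAdapter: one fuser block per group of five transformer blocks.
        self.face_adapter = nn.Module()
        self.face_adapter.fuser_blocks = nn.ModuleList(nn.Identity() for _ in range(num_adapter_layers))

def adapt_animate_model(model: nn.Module):
    modules_dict = {k: m for k, m in model.named_modules()}
    for animate_layer in range(8):
        module = modules_dict[f"face_adapter.fuser_blocks.{animate_layer}"]
        target = modules_dict[f"blocks.{animate_layer * 5}"]
        # Block 0, 5, 10, ... gains a direct handle on its fuser block.
        setattr(target, "face_adapter_fuser_blocks", module)
    delattr(model, "face_adapter")

model = ToyModel()
adapt_animate_model(model)
print(hasattr(model.blocks[5], "face_adapter_fuser_blocks"))  # True
print(hasattr(model, "face_adapter"))                         # False
```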
diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py
index d3dc783e3..c98087b5d 100644
--- a/models/wan/modules/model.py
+++ b/models/wan/modules/model.py
@@ -16,6 +16,9 @@
from shared.attention import pay_attention
from torch.backends.cuda import sdp_kernel
from ..multitalk.multitalk_utils import get_attn_map_with_target
+from ..animate.motion_encoder import Generator
+from ..animate.face_blocks import FaceAdapter, FaceEncoder
+from ..animate.model_animate import after_patch_embedding
__all__ = ['WanModel']
@@ -499,6 +502,7 @@ def forward(
multitalk_masks=None,
ref_images_count=0,
standin_phase=-1,
+ motion_vec = None,
):
r"""
Args:
@@ -616,6 +620,10 @@ def forward(
x.add_(hint)
else:
x.add_(hint, alpha= scale)
+
+ if motion_vec is not None and self.block_no % 5 == 0:
+ x += self.face_adapter_fuser_blocks(x.to(self.face_adapter_fuser_blocks.linear1_kv.weight.dtype), motion_vec, None, False)
+
return x
class AudioProjModel(ModelMixin, ConfigMixin):
@@ -898,6 +906,7 @@ def __init__(self,
norm_input_visual=True,
norm_output_audio=True,
standin= False,
+ motion_encoder_dim=0,
):
super().__init__()
@@ -922,14 +931,15 @@ def __init__(self,
self.flag_causal_attention = False
self.block_mask = None
self.inject_sample_info = inject_sample_info
-
+ self.motion_encoder_dim = motion_encoder_dim
self.norm_output_audio = norm_output_audio
self.audio_window = audio_window
self.intermediate_dim = intermediate_dim
self.vae_scale = vae_scale
multitalk = multitalk_output_dim > 0
- self.multitalk = multitalk
+ self.multitalk = multitalk
+ animate = motion_encoder_dim > 0
# embeddings
self.patch_embedding = nn.Conv3d(
@@ -1027,6 +1037,25 @@ def __init__(self,
block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128)
block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128)
+ if animate:
+ self.pose_patch_embedding = nn.Conv3d(
+ 16, dim, kernel_size=patch_size, stride=patch_size
+ )
+
+ self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
+ self.face_adapter = FaceAdapter(
+ heads_num=self.num_heads,
+ hidden_dim=self.dim,
+ num_adapter_layers=self.num_layers // 5,
+ )
+
+ self.face_encoder = FaceEncoder(
+ in_dim=motion_encoder_dim,
+ hidden_dim=self.dim,
+ num_heads=4,
+ )
+
+
def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32):
layer_list = [self.head, self.head.head, self.patch_embedding]
target_dype= dtype
@@ -1208,6 +1237,9 @@ def forward(
ref_images_count = 0,
standin_freqs = None,
standin_ref = None,
+ pose_latents=None,
+ face_pixel_values=None,
+
):
# patch_dtype = self.patch_embedding.weight.dtype
modulation_dtype = self.time_projection[1].weight.dtype
@@ -1240,9 +1272,18 @@ def forward(
if bz > 1: y = y.expand(bz, -1, -1, -1, -1)
x = torch.cat([x, y], dim=1)
# embeddings
- # x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype)
x = self.patch_embedding(x).to(modulation_dtype)
grid_sizes = x.shape[2:]
+ x_list[i] = x
+ y = None
+
+ motion_vec_list = []
+ for i, x in enumerate(x_list):
+ # animate embeddings
+ motion_vec = None
+ if pose_latents is not None:
+ x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values)
+ motion_vec_list.append(motion_vec)
if chipmunk:
x = x.unsqueeze(-1)
x_og_shape = x.shape
@@ -1250,7 +1291,7 @@ def forward(
else:
x = x.flatten(2).transpose(1, 2)
x_list[i] = x
- x, y = None, None
+ x = None
block_mask = None
@@ -1450,9 +1491,9 @@ def forward(
continue
x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs)
else:
- for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc)):
+ for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list)):
if should_calc:
- x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, **kwargs)
+ x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec,**kwargs)
del x
context = hints = audio_embedding = None
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index 6586176bd..ff1570f3a 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -3,10 +3,10 @@
import gradio as gr
def test_class_i2v(base_model_type):
- return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ]
+ return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "animate" ]
def text_oneframe_overlap(base_model_type):
- return test_class_i2v(base_model_type) and not test_multitalk(base_model_type)
+ return test_class_i2v(base_model_type) and not (test_multitalk(base_model_type) or base_model_type in ["animate"]) or test_wan_5B(base_model_type)
def test_class_1_3B(base_model_type):
return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"]
@@ -17,6 +17,8 @@ def test_multitalk(base_model_type):
def test_standin(base_model_type):
return base_model_type in ["standin", "vace_standin_14B"]
+def test_wan_5B(base_model_type):
+ return base_model_type in ["ti2v_2_2", "lucy_edit"]
class family_handler():
@staticmethod
@@ -36,7 +38,7 @@ def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_st
def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181]
elif base_model_type in ["i2v_2_2"]:
def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902]
- elif base_model_type in ["ti2v_2_2"]:
+ elif test_wan_5B(base_model_type):
if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v
def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015]
else: # i2v
@@ -83,11 +85,13 @@ def query_model_def(base_model_type, model_def):
vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"]
extra_model_def["vace_class"] = vace_class
- if test_multitalk(base_model_type):
+ if base_model_type in ["animate"]:
+ fps = 30
+ elif test_multitalk(base_model_type):
fps = 25
elif base_model_type in ["fantasy"]:
fps = 23
- elif base_model_type in ["ti2v_2_2"]:
+ elif test_wan_5B(base_model_type):
fps = 24
else:
fps = 16
@@ -100,14 +104,14 @@ def query_model_def(base_model_type, model_def):
extra_model_def.update({
"frames_minimum" : frames_minimum,
"frames_steps" : frames_steps,
- "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2",
+ "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class, #"ti2v_2_2",
"multiple_submodels" : multiple_submodels,
"guidance_max_phases" : 3,
"skip_layer_guidance" : True,
"cfg_zero" : True,
"cfg_star" : True,
"adaptive_projected_guidance" : True,
- "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels),
+ "tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels),
"mag_cache" : True,
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
"convert_image_guide_to_video" : True,
@@ -146,6 +150,34 @@ def query_model_def(base_model_type, model_def):
}
# extra_model_def["at_least_one_image_ref_needed"] = True
+ if base_model_type in ["lucy_edit"]:
+ extra_model_def["keep_frames_video_guide_not_supported"] = True
+ extra_model_def["guide_preprocessing"] = {
+ "selection": ["UV"],
+ "labels" : { "UV": "Control Video"},
+ "visible": False,
+ }
+
+ if base_model_type in ["animate"]:
+ extra_model_def["guide_custom_choices"] = {
+ "choices":[
+ ("Animate Person in Reference Image using Motion of Person in Control Video", "PVBXAKI"),
+                ("Replace Person in Control Video with Person in Reference Image", "PVBAI"),
+ ],
+ "default": "KI",
+ "letters_filter": "PVBXAKI",
+ "label": "Type of Process",
+ "show_label" : False,
+ }
+ extra_model_def["video_guide_outpainting"] = [0,1]
+ extra_model_def["keep_frames_video_guide_not_supported"] = True
+ extra_model_def["extract_guide_from_window_start"] = True
+ extra_model_def["forced_guide_mask_inputs"] = True
+ extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)"
+ extra_model_def["background_ref_outpainted"] = False
+
+
+
if vace_class:
extra_model_def["guide_preprocessing"] = {
"selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"],
@@ -157,16 +189,19 @@ def query_model_def(base_model_type, model_def):
extra_model_def["image_ref_choices"] = {
"choices": [("None", ""),
- ("Inject only People / Objects", "I"),
- ("Inject Landscape and then People / Objects", "KI"),
- ("Inject Frames and then People / Objects", "FI"),
+ ("People / Objects", "I"),
+ ("Landscape followed by People / Objects (if any)", "KI"),
+ ("Positioned Frames followed by People / Objects (if any)", "FI"),
],
"letters_filter": "KFI",
}
extra_model_def["lock_image_refs_ratios"] = True
- extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames"
+ extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames"
extra_model_def["video_guide_outpainting"] = [0,1]
+ extra_model_def["pad_guide_video"] = True
+ extra_model_def["guide_inpaint_color"] = 127.5
+ extra_model_def["forced_guide_mask_inputs"] = True
if base_model_type in ["standin"]:
extra_model_def["lock_image_refs_ratios"] = True
@@ -209,10 +244,12 @@ def query_model_def(base_model_type, model_def):
"visible" : False,
}
- if vace_class or base_model_type in ["infinitetalk"]:
+ if vace_class or base_model_type in ["infinitetalk", "animate"]:
image_prompt_types_allowed = "TVL"
elif base_model_type in ["ti2v_2_2"]:
image_prompt_types_allowed = "TSVL"
+ elif base_model_type in ["lucy_edit"]:
+ image_prompt_types_allowed = "TVL"
elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]:
image_prompt_types_allowed = "SVL"
elif i2v:
@@ -234,8 +271,8 @@ def query_model_def(base_model_type, model_def):
def query_supported_types():
return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B",
"t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
- "recam_1.3B",
- "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
+ "recam_1.3B", "animate",
+ "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
@staticmethod
@@ -265,11 +302,12 @@ def query_family_infos():
@staticmethod
def get_vae_block_size(base_model_type):
- return 32 if base_model_type == "ti2v_2_2" else 16
+ return 32 if test_wan_5B(base_model_type) else 16
@staticmethod
def get_rgb_factors(base_model_type ):
from shared.RGB_factors import get_rgb_factors
+ if test_wan_5B(base_model_type): base_model_type = "ti2v_2_2"
latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type)
return latent_rgb_factors, latent_rgb_factors_bias
@@ -283,7 +321,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder
"fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ]
}]
- if base_model_type == "ti2v_2_2":
+ if test_wan_5B(base_model_type):
download_def += [ {
"repoId" : "DeepBeepMeep/Wan2.2",
"sourceFolderList" : [""],
@@ -377,8 +415,8 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
"sample_solver": "unipc",
})
- if test_class_i2v(base_model_type):
- ui_defaults["image_prompt_type"] = "S"
+ if test_class_i2v(base_model_type) and "S" in model_def["image_prompt_types_allowed"]:
+ ui_defaults["image_prompt_type"] = "S"
if base_model_type in ["fantasy"]:
ui_defaults.update({
@@ -434,10 +472,15 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
"image_prompt_type": "T",
})
- if base_model_type in ["recam_1.3B"]:
+ if base_model_type in ["recam_1.3B", "lucy_edit"]:
ui_defaults.update({
"video_prompt_type": "UV",
})
+ elif base_model_type in ["animate"]:
+ ui_defaults.update({
+ "video_prompt_type": "PVBXAKI",
+ "mask_expand": 20,
+ })
if text_oneframe_overlap(base_model_type):
ui_defaults["sliding_window_overlap"] = 1
diff --git a/shared/convert/convert_diffusers_to_flux.py b/shared/convert/convert_diffusers_to_flux.py
new file mode 100644
index 000000000..608b1765e
--- /dev/null
+++ b/shared/convert/convert_diffusers_to_flux.py
@@ -0,0 +1,342 @@
+#!/usr/bin/env python3
+"""
+Convert a Flux model from Diffusers (folder or single-file) into the original
+single-file Flux transformer checkpoint used by Black Forest Labs / ComfyUI.
+
+Input : /path/to/diffusers (root or .../transformer) OR /path/to/*.safetensors (single file)
+Output : /path/to/flux1-your-model.safetensors (transformer only)
+
+Usage:
+    python convert_diffusers_to_flux.py /path/to/diffusers /out/flux1-dev.safetensors
+    python convert_diffusers_to_flux.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors
+ # optional quantization:
+ # --fp8 (float8_e4m3fn, simple)
+ # --fp8-scaled (scaled float8 for 2D weights; adds .scale_weight tensors)
+"""
+
+import argparse
+import json
+from pathlib import Path
+from collections import OrderedDict
+
+import torch
+from safetensors import safe_open
+import safetensors.torch
+from tqdm import tqdm
+
+
+def parse_args():
+ ap = argparse.ArgumentParser()
+ ap.add_argument("diffusers_path", type=str,
+ help="Path to Diffusers checkpoint folder OR a single .safetensors file.")
+ ap.add_argument("output_path", type=str,
+ help="Output .safetensors path for the Flux transformer.")
+ ap.add_argument("--fp8", action="store_true",
+ help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).")
+ ap.add_argument("--fp8-scaled", action="store_true",
+ help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.")
+ return ap.parse_args()
+
+
+# Mapping from original Flux keys -> list of Diffusers keys (per block where applicable).
+DIFFUSERS_MAP = {
+ # global embeds
+ "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"],
+ "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"],
+ "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"],
+ "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"],
+
+ "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"],
+ "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"],
+ "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"],
+ "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"],
+
+ "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"],
+ "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"],
+ "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"],
+ "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"],
+
+ "txt_in.weight": ["context_embedder.weight"],
+ "txt_in.bias": ["context_embedder.bias"],
+ "img_in.weight": ["x_embedder.weight"],
+ "img_in.bias": ["x_embedder.bias"],
+
+ # dual-stream (image/text) blocks
+ "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"],
+ "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"],
+ "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"],
+ "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"],
+
+ "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"],
+ "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"],
+ "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"],
+ "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"],
+
+ "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"],
+ "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"],
+ "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"],
+ "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"],
+
+ "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"],
+ "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"],
+ "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"],
+ "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"],
+
+ "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"],
+ "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"],
+ "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"],
+ "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"],
+
+ "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"],
+ "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"],
+ "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"],
+ "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"],
+
+ # single-stream blocks
+ "single_blocks.().modulation.lin.weight": ["norm.linear.weight"],
+ "single_blocks.().modulation.lin.bias": ["norm.linear.bias"],
+ "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"],
+ "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"],
+ "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"],
+ "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"],
+ "single_blocks.().linear2.weight": ["proj_out.weight"],
+ "single_blocks.().linear2.bias": ["proj_out.bias"],
+
+ # final
+ "final_layer.linear.weight": ["proj_out.weight"],
+ "final_layer.linear.bias": ["proj_out.bias"],
+ # these two are built from norm_out.linear.{weight,bias} by swapping [shift,scale] -> [scale,shift]
+ "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"],
+ "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"],
+}
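+# Note (editor's addition): the "()" placeholder in the block-level keys above is
+# expanded per block index in main() below. For example, for dual-stream block 3,
+# "double_blocks.().img_mod.lin.weight" becomes "double_blocks.3.img_mod.lin.weight"
+# and is filled from the Diffusers key "transformer_blocks.3.norm1.linear.weight".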
+
+
+class DiffusersSource:
+ """
+ Uniform interface over:
+ 1) Folder with index JSON + shards
+ 2) Folder with exactly one .safetensors (no index)
+ 3) Single .safetensors file
+ Provides .has(key), .get(key)->Tensor, .base_keys (keys with 'model.' stripped for scanning)
+ """
+
+ POSSIBLE_PREFIXES = ["", "model."] # try in this order
+
+ def __init__(self, path: Path):
+ p = Path(path)
+ if p.is_dir():
+ # use 'transformer' subfolder if present
+ if (p / "transformer").is_dir():
+ p = p / "transformer"
+ self._init_from_dir(p)
+ elif p.is_file() and p.suffix == ".safetensors":
+ self._init_from_single_file(p)
+ else:
+ raise FileNotFoundError(f"Invalid path: {p}")
+
+ # ---------- common helpers ----------
+
+ @staticmethod
+ def _strip_prefix(k: str) -> str:
+ return k[6:] if k.startswith("model.") else k
+
+ def _resolve(self, want: str):
+ """
+ Return the actual stored key matching `want` by trying known prefixes.
+ """
+ for pref in self.POSSIBLE_PREFIXES:
+ k = pref + want
+ if k in self._all_keys:
+ return k
+ return None
+
+ def has(self, want: str) -> bool:
+ return self._resolve(want) is not None
+
+ def get(self, want: str) -> torch.Tensor:
+ real_key = self._resolve(want)
+ if real_key is None:
+ raise KeyError(f"Missing key: {want}")
+ return self._get_by_real_key(real_key).to("cpu")
+
+ @property
+ def base_keys(self):
+ # keys without 'model.' prefix for scanning
+ return [self._strip_prefix(k) for k in self._all_keys]
+
+ # ---------- modes ----------
+
+ def _init_from_single_file(self, file_path: Path):
+ self._mode = "single"
+ self._file = file_path
+ self._handle = safe_open(file_path, framework="pt", device="cpu")
+ self._all_keys = list(self._handle.keys())
+
+ def _get_by_real_key(real_key: str):
+ return self._handle.get_tensor(real_key)
+
+ self._get_by_real_key = _get_by_real_key
+
+ def _init_from_dir(self, dpath: Path):
+ index_json = dpath / "diffusion_pytorch_model.safetensors.index.json"
+ if index_json.exists():
+ with open(index_json, "r", encoding="utf-8") as f:
+ index = json.load(f)
+ weight_map = index["weight_map"] # full mapping
+ self._mode = "sharded"
+ self._dpath = dpath
+ self._weight_map = {k: dpath / v for k, v in weight_map.items()}
+ self._all_keys = list(self._weight_map.keys())
+ self._open_handles = {}
+
+ def _get_by_real_key(real_key: str):
+ fpath = self._weight_map[real_key]
+ h = self._open_handles.get(fpath)
+ if h is None:
+ h = safe_open(fpath, framework="pt", device="cpu")
+ self._open_handles[fpath] = h
+ return h.get_tensor(real_key)
+
+ self._get_by_real_key = _get_by_real_key
+ return
+
+ # no index: try exactly one safetensors in folder
+ files = sorted(dpath.glob("*.safetensors"))
+ if len(files) != 1:
+ raise FileNotFoundError(
+ f"No index found and {dpath} does not contain exactly one .safetensors file."
+ )
+ self._init_from_single_file(files[0])
+
+
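+# --- Editor's illustrative sketch (not part of the original script) ----------
+# Minimal standalone use of DiffusersSource; the checkpoint path is a placeholder.
+# Only the has()/get() interface defined above is assumed.
+def _example_read_context_embedder(path="/path/to/diffusers"):
+    src = DiffusersSource(Path(path))  # folder, sharded folder, or .safetensors file
+    if src.has("context_embedder.weight"):
+        weight = src.get("context_embedder.weight")  # returned on CPU
+        print("context_embedder.weight:", tuple(weight.shape))
+
+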
+def main():
+ args = parse_args()
+ src = DiffusersSource(Path(args.diffusers_path))
+
+ # Count blocks by scanning base keys (with any 'model.' prefix removed)
+ num_dual = 0
+ num_single = 0
+ for k in src.base_keys:
+ if k.startswith("transformer_blocks."):
+ try:
+ i = int(k.split(".")[1])
+ num_dual = max(num_dual, i + 1)
+ except Exception:
+ pass
+ elif k.startswith("single_transformer_blocks."):
+ try:
+ i = int(k.split(".")[1])
+ num_single = max(num_single, i + 1)
+ except Exception:
+ pass
+ print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks")
+
+ # Swap [shift, scale] -> [scale, shift] (weights are concatenated along dim=0)
+ def swap_scale_shift(vec: torch.Tensor) -> torch.Tensor:
+ shift, scale = vec.chunk(2, dim=0)
+ return torch.cat([scale, shift], dim=0)
+
+ orig = {}
+
+ # Per-block (dual)
+ for b in range(num_dual):
+ prefix = f"transformer_blocks.{b}."
+ for okey, dvals in DIFFUSERS_MAP.items():
+ if not okey.startswith("double_blocks."):
+ continue
+ dkeys = [prefix + v for v in dvals]
+ if not all(src.has(k) for k in dkeys):
+ continue
+ if len(dkeys) == 1:
+ orig[okey.replace("()", str(b))] = src.get(dkeys[0])
+ else:
+ orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0)
+
+ # Per-block (single)
+ for b in range(num_single):
+ prefix = f"single_transformer_blocks.{b}."
+ for okey, dvals in DIFFUSERS_MAP.items():
+ if not okey.startswith("single_blocks."):
+ continue
+ dkeys = [prefix + v for v in dvals]
+ if not all(src.has(k) for k in dkeys):
+ continue
+ if len(dkeys) == 1:
+ orig[okey.replace("()", str(b))] = src.get(dkeys[0])
+ else:
+ orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0)
+
+ # Globals (non-block)
+ for okey, dvals in DIFFUSERS_MAP.items():
+ if okey.startswith(("double_blocks.", "single_blocks.")):
+ continue
+ dkeys = dvals
+ if not all(src.has(k) for k in dkeys):
+ continue
+ if len(dkeys) == 1:
+ orig[okey] = src.get(dkeys[0])
+ else:
+ orig[okey] = torch.cat([src.get(k) for k in dkeys], dim=0)
+
+ # Fix final_layer.adaLN_modulation.1.{weight,bias} by swapping scale/shift halves
+ if "final_layer.adaLN_modulation.1.weight" in orig:
+ orig["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift(
+ orig["final_layer.adaLN_modulation.1.weight"]
+ )
+ if "final_layer.adaLN_modulation.1.bias" in orig:
+ orig["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift(
+ orig["final_layer.adaLN_modulation.1.bias"]
+ )
+
+ # Optional FP8 variants (experimental; not required for ComfyUI/BFL)
+ if args.fp8 or args.fp8_scaled:
+ dtype = torch.float8_e4m3fn # noqa
+ minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max
+
+ def stochastic_round_to(t):
+ t = t.float().clamp(minv, maxv)
+ lower = torch.floor(t * 256) / 256
+ upper = torch.ceil(t * 256) / 256
+ prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t))
+ rnd = torch.rand_like(t)
+ out = torch.where(rnd < prob, upper, lower)
+ return out.to(dtype)
+
+ def scale_to_8bit(weight, target_max=416.0):
+ absmax = weight.abs().max()
+ scale = absmax / target_max if absmax > 0 else torch.tensor(1.0)
+ scaled = (weight / scale).clamp(minv, maxv).to(dtype)
+ return scaled, scale
+
+ scales = {}
+ for k in tqdm(list(orig.keys()), desc="Quantizing to fp8"):
+ t = orig[k]
+ if args.fp8:
+ orig[k] = stochastic_round_to(t)
+ else:
+ if k.endswith(".weight") and t.dim() == 2:
+ qt, s = scale_to_8bit(t)
+ orig[k] = qt
+ scales[k[:-len(".weight")] + ".scale_weight"] = s
+ else:
+ orig[k] = t.clamp(minv, maxv).to(dtype)
+ if args.fp8_scaled:
+ orig.update(scales)
+ orig["scaled_fp8"] = torch.tensor([], dtype=dtype)
+ else:
+ # Default: save in bfloat16
+ for k in list(orig.keys()):
+ orig[k] = orig[k].to(torch.bfloat16).cpu()
+
+ out_path = Path(args.output_path)
+ out_path.parent.mkdir(parents=True, exist_ok=True)
+ meta = OrderedDict()
+ meta["format"] = "pt"
+ meta["modelspec.date"] = __import__("datetime").date.today().strftime("%Y-%m-%d")
+ print(f"Saving transformer to: {out_path}")
+ safetensors.torch.save_file(orig, str(out_path), metadata=meta)
+ print("Done.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/shared/inpainting/__init__.py b/shared/inpainting/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/shared/inpainting/lanpaint.py b/shared/inpainting/lanpaint.py
new file mode 100644
index 000000000..3165e7b01
--- /dev/null
+++ b/shared/inpainting/lanpaint.py
@@ -0,0 +1,240 @@
+import torch
+from .utils import *
+from functools import partial
+
+# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/)
+
+def _pack_latents(latents):
+ batch_size, num_channels_latents, _, height, width = latents.shape
+
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+def _unpack_latents(latents, height, width, vae_scale_factor=8):
+ batch_size, num_patches, channels = latents.shape
+
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
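+# --- Editor's illustrative sketch (not part of the original module) ----------
+# Round-trip check of the Flux-style 2x2 packing above: a 16-channel,
+# single-frame latent for a 720x1280 image (vae_scale_factor=8) packs to
+# (B, 45*80, 64) and unpacks back to the identical tensor.
+def _example_pack_roundtrip():
+    lat = torch.randn(1, 16, 1, 720 // 8, 1280 // 8)   # (B, C, F=1, H/8, W/8)
+    packed = _pack_latents(lat)                        # -> (1, 3600, 64)
+    restored = _unpack_latents(packed, height=720, width=1280, vae_scale_factor=8)
+    assert restored.shape == lat.shape
+    assert torch.equal(restored, lat)
+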
+class LanPaint():
+ def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False):
+ self.n_steps = NSteps
+ self.chara_lamb = Lambda
+ self.IS_FLUX = IS_FLUX
+ self.IS_FLOW = IS_FLOW
+ self.step_size = StepSize
+ self.friction = Friction
+ self.chara_beta = Beta
+ self.img_dim_size = None
+ def add_none_dims(self, array):
+        # Create an index with ':' for the first dimension and None repeated for the remaining img_dim_size-1 dimensions (adds broadcast axes)
+ index = (slice(None),) + (None,) * (self.img_dim_size-1)
+ return array[index]
+ def remove_none_dims(self, array):
+        # Create an index with ':' for the first dimension and 0 repeated for the remaining img_dim_size-1 dimensions (drops the broadcast axes)
+ index = (slice(None),) + (0,) * (self.img_dim_size-1)
+ return array[index]
+ def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8):
+ latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor)
+ noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor)
+ x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor)
+ latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor)
+ self.height = height
+ self.width = width
+ self.vae_scale_factor = vae_scale_factor
+ self.img_dim_size = len(x.shape)
+ self.latent_image = latent_image
+ self.noise = noise
+ if n_steps is None:
+ n_steps = self.n_steps
+ out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW)
+ out = _pack_latents(out)
+ return out
+ def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW):
+ if IS_FLUX:
+ cfg_BIG = 1.0
+
+ def double_denoise(latents, t):
+ latents = _pack_latents(latents)
+ noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale)
+            if noise_pred is None: return None, None
+ predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t)
+ predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor)
+ if true_cfg_scale == cfg_BIG:
+ predict_big = predict_std
+ else:
+ predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t)
+ predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor)
+ return predict_std, predict_big
+
+ if len(sigma.shape) == 0:
+ sigma = torch.tensor([sigma.item()])
+ latent_mask = 1 - latent_mask
+ if IS_FLUX or IS_FLOW:
+ Flow_t = sigma
+ abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 )
+ VE_Sigma = Flow_t / (1 - Flow_t)
+ #print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item())
+ else:
+ VE_Sigma = sigma
+ abt = 1/( 1+VE_Sigma**2 )
+ Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5 )
+ # VE_Sigma, abt, Flow_t = current_times
+ current_times = (VE_Sigma, abt, Flow_t)
+
+ step_size = self.step_size * (1 - abt)
+ step_size = self.add_none_dims(step_size)
+ # self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values
+ # This is the replace step
+ # x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask
+
+ noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma
+ x = x * (1 - latent_mask) + noisy_image * latent_mask
+
+ if IS_FLUX or IS_FLOW:
+ x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
+ else:
+            x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance preserving x_t values
+
+ ############ LanPaint Iterations Start ###############
+ # after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region.
+ args = None
+ for i in range(n_steps):
+ score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise )
+ if score_func is None: return None
+ x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args)
+ if IS_FLUX or IS_FLOW:
+ x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 )
+ else:
+            x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch back to variance exploding x values
+ ############ LanPaint Iterations End ###############
+ # out is x_0
+ # out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed)
+ # out = out * (1-latent_mask) + self.latent_image * latent_mask
+ # return out
+ return x
+
+ def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func):
+
+ lamb = self.chara_lamb
+ if self.IS_FLUX or self.IS_FLOW:
+            # flow-matching models: rescale x_t and denoise at the flow time tflow
+ x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching
+ x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow))
+ if x_0 is None: return None
+ else:
+ x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding
+ x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma))
+ if x_0 is None: return None
+
+ score_x = -(x_t - x_0)
+ score_y = - (1 + lamb) * ( x_t - y ) + lamb * (x_t - x_0_BIG)
+ return score_x * (1 - mask) + score_y * mask
+ def sigma_x(self, abt):
+ # the time scale for the x_t update
+ return abt**0
+ def sigma_y(self, abt):
+ beta = self.chara_beta * abt ** 0
+ return beta
+
+ def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None):
+ # prepare the step size and time parameters
+ with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
+ step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y)
+ sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes
+ # print('mask',mask.device)
+ if torch.mean(dtx) <= 0.:
+ return x_t, args
+ # -------------------------------------------------------------------------
+        # Compute the Langevin dynamics update in variance preserving notation
+ # -------------------------------------------------------------------------
+ #x0 = self.x0_evalutation(x_t, score, sigma, args)
+ #C = abt**0.5 * x0 / (1-abt)
+ A = A_x * (1-mask) + A_y * mask
+ D = D_x * (1-mask) + D_y * mask
+ dt = dtx * (1-mask) + dty * mask
+ Gamma = Gamma_x * (1-mask) + Gamma_y * mask
+
+
+ def Coef_C(x_t):
+ x0 = self.x0_evalutation(x_t, score, sigma, args)
+ C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t
+ return C
+ def advance_time(x_t, v, dt, Gamma, A, C, D):
+ dtype = x_t.dtype
+ with torch.autocast(device_type=x_t.device.type, dtype=torch.float32):
+ osc = StochasticHarmonicOscillator(Gamma, A, C, D )
+ x_t, v = osc.dynamics(x_t, v, dt )
+ x_t = x_t.to(dtype)
+ v = v.to(dtype)
+ return x_t, v
+ if args is None:
+ #v = torch.zeros_like(x_t)
+ v = None
+ C = Coef_C(x_t)
+ #print(torch.squeeze(dtx), torch.squeeze(dty))
+ x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D)
+ else:
+ v, C = args
+
+ x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
+
+ C_new = Coef_C(x_t)
+ v = v + Gamma**0.5 * ( C_new - C) *dt
+
+ x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D)
+
+ C = C_new
+
+ return x_t, (v, C)
+
+ def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y):
+ # -------------------------------------------------------------------------
+ # Unpack current times parameters (sigma and abt)
+ sigma, abt, flow_t = current_times
+ sigma = self.add_none_dims(sigma)
+ abt = self.add_none_dims(abt)
+ # Compute time step (dtx, dty) for x and y branches.
+ dtx = 2 * step_size * sigma_x
+ dty = 2 * step_size * sigma_y
+
+ # -------------------------------------------------------------------------
+ # Define friction parameter Gamma_hat for each branch.
+ # Using dtx**0 provides a tensor of the proper device/dtype.
+
+ Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0
+ Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0
+ #print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item())
+ # adjust dt to match denoise-addnoise steps sizes
+ Gamma_hat_x /= 2.
+ Gamma_hat_y /= 2.
+ A_t_x = (1) / ( 1 - abt ) * dtx / 2
+ A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2
+
+
+ A_x = A_t_x / (dtx/2)
+ A_y = A_t_y / (dty/2)
+ Gamma_x = Gamma_hat_x / (dtx/2)
+ Gamma_y = Gamma_hat_y / (dty/2)
+
+ #D_x = (2 * (1 + sigma**2) )**0.5
+ #D_y = (2 * (1 + sigma**2) )**0.5
+ D_x = (2 * abt**0 )**0.5
+ D_y = (2 * abt**0 )**0.5
+ return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y
+
+
+
+ def x0_evalutation(self, x_t, score, sigma, args):
+ x0 = x_t + score(x_t)
+ return x0
\ No newline at end of file
diff --git a/shared/inpainting/utils.py b/shared/inpainting/utils.py
new file mode 100644
index 000000000..c017ab0a8
--- /dev/null
+++ b/shared/inpainting/utils.py
@@ -0,0 +1,301 @@
+import torch
+def epxm1_x(x):
+    # Compute the (exp(x) - 1) / x term, using a Taylor expansion for small |x| to avoid division by zero.
+ result = torch.special.expm1(x) / x
+ # replace NaN or inf values with 0
+ result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
+ mask = torch.abs(x) < 1e-2
+ result = torch.where(mask, 1 + x/2. + x**2 / 6., result)
+ return result
+def epxm1mx_x2(x):
+    # Compute the (exp(x) - 1 - x) / x**2 term, using a Taylor expansion for small |x| to avoid division by zero.
+ result = (torch.special.expm1(x) - x) / x**2
+ # replace NaN or inf values with 0
+ result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
+ mask = torch.abs(x**2) < 1e-2
+ result = torch.where(mask, 1/2. + x/6 + x**2 / 24 + x**3 / 120, result)
+ return result
+
+def expm1mxmhx2_x3(x):
+    # Compute the (exp(x) - 1 - x - x**2 / 2) / x**3 term, using a Taylor expansion for small |x| to avoid division by zero.
+ result = (torch.special.expm1(x) - x - x**2 / 2) / x**3
+ # replace NaN or inf values with 0
+ result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
+ mask = torch.abs(x**3) < 1e-2
+ result = torch.where(mask, 1/6 + x/24 + x**2 / 120 + x**3 / 720 + x**4 / 5040, result)
+ return result
+
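+# --- Editor's illustrative sketch (not part of the original module) ----------
+# Quick numerical check of the guarded helpers above: away from 0 they match the
+# naive formulas, and at x = 0 they return the analytic limits (1, 1/2, 1/6)
+# instead of NaN.
+def _example_check_guards():
+    x = torch.tensor([1.0, 0.5])
+    naive = torch.special.expm1(x) / x
+    assert torch.allclose(epxm1_x(x), naive)
+    zero = torch.tensor([0.0])
+    assert torch.allclose(epxm1_x(zero), torch.tensor([1.0]))
+    assert torch.allclose(epxm1mx_x2(zero), torch.tensor([0.5]))
+    assert torch.allclose(expm1mxmhx2_x3(zero), torch.tensor([1.0 / 6.0]))
+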
+def exp_1mcosh_GD(gamma_t, delta):
+ """
+ Compute e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )
+
+ Parameters:
+ gamma_t: Γ*t term (could be a scalar or tensor)
+ delta: Δ term (could be a scalar or tensor)
+
+ Returns:
+ Result of the computation with numerical stability handling
+ """
+ # Main computation
+ is_positive = delta > 0
+ sqrt_abs_delta = torch.sqrt(torch.abs(delta))
+ gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta
+ numerator_pos = torch.exp(-gamma_t) - (torch.exp(gamma_t * (sqrt_abs_delta - 1)) + torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
+ numerator_neg = torch.exp(-gamma_t) * ( 1 - torch.cos(gamma_t * sqrt_abs_delta ) )
+ numerator = torch.where(is_positive, numerator_pos, numerator_neg)
+ result = numerator / (delta * gamma_t**2 )
+ # Handle NaN/inf cases
+ result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
+ # Handle numerical instability for small delta
+ mask = torch.abs(gamma_t_sqrt_delta**2) < 5e-2
+ taylor = ( -0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2 ) * torch.exp(-gamma_t)
+ result = torch.where(mask, taylor, result)
+ return result
+
+def exp_sinh_GsqrtD(gamma_t, delta):
+ """
+ Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)
+
+ Parameters:
+ gamma_t: Γ*t term (could be a scalar or tensor)
+ delta: Δ term (could be a scalar or tensor)
+
+ Returns:
+ Result of the computation with numerical stability handling
+ """
+ # Main computation
+ is_positive = delta > 0
+ sqrt_abs_delta = torch.sqrt(torch.abs(delta))
+ gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta
+ numerator_pos = (torch.exp(gamma_t * (sqrt_abs_delta - 1)) - torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2
+ denominator_pos = gamma_t_sqrt_delta
+ result_pos = numerator_pos / gamma_t_sqrt_delta
+ result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos))
+
+ # Taylor expansion for small gamma_t_sqrt_delta
+ mask = torch.abs(gamma_t_sqrt_delta) < 1e-2
+ taylor = ( 1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2 ) * torch.exp(-gamma_t)
+ result_pos = torch.where(mask, taylor, result_pos)
+
+ # Handle negative delta
+ result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta/torch.pi)
+ result = torch.where(is_positive, result_pos, result_neg)
+ return result
+
+def exp_cosh(gamma_t, delta):
+ """
+ Compute e^(-Γt) * cosh(Γt√Δ)
+
+ Parameters:
+ gamma_t: Γ*t term (could be a scalar or tensor)
+ delta: Δ term (could be a scalar or tensor)
+
+ Returns:
+ Result of the computation with numerical stability handling
+ """
+ exp_1mcosh_GD_result = exp_1mcosh_GD(gamma_t, delta) # e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )
+ result = torch.exp(-gamma_t) - gamma_t**2 * delta * exp_1mcosh_GD_result
+ return result
+def exp_sinh_sqrtD(gamma_t, delta):
+ """
+ Compute e^(-Γt) * sinh(Γt√Δ) / √Δ
+ Parameters:
+ gamma_t: Γ*t term (could be a scalar or tensor)
+ delta: Δ term (could be a scalar or tensor)
+ Returns:
+ Result of the computation with numerical stability handling
+ """
+ exp_sinh_GsqrtD_result = exp_sinh_GsqrtD(gamma_t, delta) # e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ)
+ result = gamma_t * exp_sinh_GsqrtD_result
+ return result
+
+
+
+def zeta1(gamma_t, delta):
+ # Compute hyperbolic terms and exponential
+ half_gamma_t = gamma_t / 2
+ exp_cosh_term = exp_cosh(half_gamma_t, delta)
+ exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta)
+
+
+ # Main computation
+ numerator = 1 - (exp_cosh_term + exp_sinh_term)
+ denominator = gamma_t * (1 - delta) / 4
+ result = 1 - numerator / denominator
+
+ # Handle numerical instability
+ result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
+
+ # Taylor expansion for small x (similar to your epxm1Dx approach)
+ mask = torch.abs(denominator) < 5e-3
+ term1 = epxm1_x(-gamma_t)
+ term2 = epxm1mx_x2(-gamma_t)
+ term3 = expm1mxmhx2_x3(-gamma_t)
+ taylor = term1 + (1/2.+ term1-3*term2)*denominator + (-1/6. + term1/2 - 4 * term2 + 10 * term3) * denominator**2
+ result = torch.where(mask, taylor, result)
+
+ return result
+
+def exp_cosh_minus_terms(gamma_t, delta):
+ """
+    Compute e^(-Γt) * (cosh(Γt) - 1 - (cosh(Γt√Δ) - 1)/Δ) / (Γt(1 - Δ))
+
+ Parameters:
+ gamma_t: Γ*t term (could be a scalar or tensor)
+ delta: Δ term (could be a scalar or tensor)
+
+ Returns:
+ Result of the computation with numerical stability handling
+ """
+ exp_term = torch.exp(-gamma_t)
+ # Compute individual terms
+ exp_cosh_term = exp_cosh(gamma_t, gamma_t**0) - exp_term # E^(-tΓ) (Cosh[tΓ] - 1) term
+ exp_cosh_delta_term = - gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ term
+
+ #exp_1mcosh_GD e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ )
+ # Main computation
+ numerator = exp_cosh_term - exp_cosh_delta_term
+ denominator = gamma_t * (1 - delta)
+
+ result = numerator / denominator
+
+ # Handle numerical instability
+ result = torch.where(torch.isfinite(result), result, torch.zeros_like(result))
+
+ # Taylor expansion for small gamma_t and delta near 1
+ mask = (torch.abs(denominator) < 1e-1)
+ exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, delta**0)
+ taylor = (
+ gamma_t*exp_1mcosh_GD_term + 0.5 * gamma_t * exp_sinh_GsqrtD(gamma_t, delta**0)
+ - denominator / 4 * ( 0.5 * exp_cosh(gamma_t, delta**0) - 4 * exp_1mcosh_GD_term - 5 /2 * exp_sinh_GsqrtD(gamma_t, delta**0) )
+ )
+ result = torch.where(mask, taylor, result)
+
+ return result
+
+
+def zeta2(gamma_t, delta):
+ half_gamma_t = gamma_t / 2
+ return exp_sinh_GsqrtD(half_gamma_t, delta)
+
+def sig11(gamma_t, delta):
+ return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)
+
+
+def Zcoefs(gamma_t, delta):
+ Zeta1 = zeta1(gamma_t, delta)
+ Zeta2 = zeta2(gamma_t, delta)
+
+ sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8
+ amplitude = torch.sqrt(sq_total)
+ Zcoef1 = ( gamma_t**0.5 * Zeta2 / 2 **0.5 ) / amplitude
+ Zcoef2 = Zcoef1 * gamma_t *( - 2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta) ) ** 0.5
+ #cterm = exp_cosh_minus_terms(gamma_t, delta)
+ #sterm = exp_sinh_sqrtD(gamma_t, delta**0) + exp_sinh_sqrtD(gamma_t, delta)
+ #Zcoef3 = 2 * torch.sqrt( cterm / ( gamma_t * (1 - delta) * cterm + sterm ) )
+ Zcoef3 = torch.sqrt( torch.maximum(1 - Zcoef1**2 - Zcoef2**2, sq_total.new_zeros(sq_total.shape)) )
+
+ return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude
+
+def Zcoefs_asymp(gamma_t, delta):
+ A_t = (gamma_t * (1 - delta) )/4
+ return epxm1_x(- 2 * A_t)
+
+class StochasticHarmonicOscillator:
+ """
+ Simulates a stochastic harmonic oscillator governed by the equations:
+ dy(t) = q(t) dt
+ dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt
+
+ Also define v(t) = q(t) / √Γ, which is numerically more stable.
+
+ Where:
+ y(t) - Position variable
+ q(t) - Velocity variable
+ Γ - Damping coefficient
+ A - Harmonic potential strength
+ C - Constant force term
+ D - Noise amplitude
+ dw(t) - Wiener process (Brownian motion)
+ """
+ def __init__(self, Gamma, A, C, D):
+ self.Gamma = Gamma
+ self.A = A
+ self.C = C
+ self.D = D
+ self.Delta = 1 - 4 * A / Gamma
+ def sig11(self, gamma_t, delta):
+ return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta)
+ def sig22(self, gamma_t, delta):
+ return 1- zeta1(2*gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta)
+ def dynamics(self, y0, v0, t):
+ """
+ Calculates the position and velocity variables at time t.
+
+ Parameters:
+ y0 (float): Initial position
+ v0 (float): Initial velocity v(0) = q(0) / √Γ
+ t (float): Time at which to evaluate the dynamics
+ Returns:
+ tuple: (y(t), v(t))
+ """
+
+ dummyzero = y0.new_zeros(1) # convert scalar to tensor with same device and dtype as y0
+ Delta = self.Delta + dummyzero
+ Gamma_hat = self.Gamma * t + dummyzero
+ A = self.A + dummyzero
+ C = self.C + dummyzero
+ D = self.D + dummyzero
+ Gamma = self.Gamma + dummyzero
+ zeta_1 = zeta1( Gamma_hat, Delta)
+ zeta_2 = zeta2( Gamma_hat, Delta)
+ EE = 1 - Gamma_hat * zeta_2
+
+ if v0 is None:
+ v0 = torch.randn_like(y0) * D / 2 ** 0.5
+ #v0 = (C - A * y0)/Gamma**0.5
+
+ # Calculate mean position and velocity
+ term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t
+ y_mean = term1 + y0
+ v_mean = (1 - EE)*(C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0
+
+ cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta)
+ cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2
+ cov_yv = (zeta2(Gamma_hat, Delta) * Gamma_hat * D ) **2 / 2 / (Gamma ** 0.5)
+
+ # sample new position and velocity with multivariate normal distribution
+
+ batch_shape = y0.shape
+ cov_matrix = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
+ cov_matrix[..., 0, 0] = cov_yy
+ cov_matrix[..., 0, 1] = cov_yv
+ cov_matrix[..., 1, 0] = cov_yv # symmetric
+ cov_matrix[..., 1, 1] = cov_vv
+
+
+
+ # Compute the Cholesky decomposition to get scale_tril
+ #scale_tril = torch.linalg.cholesky(cov_matrix)
+ scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype)
+ tol = 1e-8
+ cov_yy = torch.clamp( cov_yy, min = tol )
+ sd_yy = torch.sqrt( cov_yy )
+ inv_sd_yy = 1/(sd_yy)
+
+ scale_tril[..., 0, 0] = sd_yy
+ scale_tril[..., 0, 1] = 0.
+ scale_tril[..., 1, 0] = cov_yv * inv_sd_yy
+ scale_tril[..., 1, 1] = torch.clamp( cov_vv - cov_yv**2 / cov_yy, min = tol ) ** 0.5
+ # check if it matches torch.linalg.
+ #assert torch.allclose(torch.linalg.cholesky(cov_matrix), scale_tril, atol = 1e-4, rtol = 1e-4 )
+ # Sample correlated noise from multivariate normal
+ mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype)
+ mean[..., 0] = y_mean
+ mean[..., 1] = v_mean
+ new_yv = torch.distributions.MultivariateNormal(
+ loc=mean,
+ scale_tril=scale_tril
+ ).sample()
+
+ return new_yv[...,0], new_yv[...,1]
\ No newline at end of file
diff --git a/shared/utils/audio_video.py b/shared/utils/audio_video.py
index b24530d5c..224cf3412 100644
--- a/shared/utils/audio_video.py
+++ b/shared/utils/audio_video.py
@@ -232,6 +232,9 @@ def save_video(tensor,
retry=5):
"""Save tensor as video with configurable codec and container options."""
+ if torch.is_tensor(tensor) and len(tensor.shape) == 4:
+ tensor = tensor.unsqueeze(0)
+
suffix = f'.{container}'
cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file
if not cache_file.endswith(suffix):
diff --git a/shared/utils/download.py b/shared/utils/download.py
new file mode 100644
index 000000000..ed035c0b7
--- /dev/null
+++ b/shared/utils/download.py
@@ -0,0 +1,110 @@
+import sys, time
+
+# Global variables to track download progress
+_start_time = None
+_last_time = None
+_last_downloaded = 0
+_speed_history = []
+_update_interval = 0.5 # Update speed every 0.5 seconds
+
+def progress_hook(block_num, block_size, total_size, filename=None):
+ """
+ Simple progress bar hook for urlretrieve
+
+ Args:
+ block_num: Number of blocks downloaded so far
+ block_size: Size of each block in bytes
+ total_size: Total size of the file in bytes
+ filename: Name of the file being downloaded (optional)
+ """
+ global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval
+
+ current_time = time.time()
+ downloaded = block_num * block_size
+
+ # Initialize timing on first call
+ if _start_time is None or block_num == 0:
+ _start_time = current_time
+ _last_time = current_time
+ _last_downloaded = 0
+ _speed_history = []
+
+ # Calculate download speed only at specified intervals
+ speed = 0
+ if current_time - _last_time >= _update_interval:
+ if _last_time > 0:
+ current_speed = (downloaded - _last_downloaded) / (current_time - _last_time)
+ _speed_history.append(current_speed)
+ # Keep only last 5 speed measurements for smoothing
+ if len(_speed_history) > 5:
+ _speed_history.pop(0)
+ # Average the recent speeds for smoother display
+ speed = sum(_speed_history) / len(_speed_history)
+
+ _last_time = current_time
+ _last_downloaded = downloaded
+ elif _speed_history:
+ # Use the last calculated average speed
+ speed = sum(_speed_history) / len(_speed_history)
+ # Format file sizes and speed
+ def format_bytes(bytes_val):
+ for unit in ['B', 'KB', 'MB', 'GB']:
+ if bytes_val < 1024:
+ return f"{bytes_val:.1f}{unit}"
+ bytes_val /= 1024
+ return f"{bytes_val:.1f}TB"
+
+ file_display = filename if filename else "Unknown file"
+
+ if total_size <= 0:
+ # If total size is unknown, show downloaded bytes
+ speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
+ line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}"
+ # Clear any trailing characters by padding with spaces
+ sys.stdout.write(line.ljust(80))
+ sys.stdout.flush()
+ return
+
+ downloaded = block_num * block_size
+ percent = min(100, (downloaded / total_size) * 100)
+
+ # Create progress bar (40 characters wide to leave room for other info)
+ bar_length = 40
+ filled = int(bar_length * percent / 100)
+ bar = '█' * filled + '░' * (bar_length - filled)
+
+ # Format file sizes and speed
+ def format_bytes(bytes_val):
+ for unit in ['B', 'KB', 'MB', 'GB']:
+ if bytes_val < 1024:
+ return f"{bytes_val:.1f}{unit}"
+ bytes_val /= 1024
+ return f"{bytes_val:.1f}TB"
+
+ speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
+
+ # Display progress with filename first
+ line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}"
+ # Clear any trailing characters by padding with spaces
+ sys.stdout.write(line.ljust(100))
+ sys.stdout.flush()
+
+ # Print newline when complete
+ if percent >= 100:
+ print()
+
+# Wrapper function to include filename in progress hook
+def create_progress_hook(filename):
+ """Creates a progress hook with the filename included"""
+ global _start_time, _last_time, _last_downloaded, _speed_history
+ # Reset timing variables for new download
+ _start_time = None
+ _last_time = None
+ _last_downloaded = 0
+ _speed_history = []
+
+ def hook(block_num, block_size, total_size):
+ return progress_hook(block_num, block_size, total_size, filename)
+ return hook
+
+
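+# --- Editor's illustrative sketch (not part of the original module) ----------
+# Typical wiring of the hook above into urllib's urlretrieve; the URL and
+# destination path are placeholders.
+def _example_download(url="https://example.com/model.safetensors",
+                      dest="model.safetensors"):
+    from urllib.request import urlretrieve
+    urlretrieve(url, dest, reporthook=create_progress_hook(dest))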
diff --git a/shared/utils/utils.py b/shared/utils/utils.py
index 17b9dde41..6e8a98bf2 100644
--- a/shared/utils/utils.py
+++ b/shared/utils/utils.py
@@ -1,4 +1,3 @@
-# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import os.path as osp
@@ -176,8 +175,9 @@ def remove_background(img, session=None):
def convert_image_to_tensor(image):
return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)
-def convert_tensor_to_image(t, frame_no = -1):
- t = t[:, frame_no] if frame_no >= 0 else t
+def convert_tensor_to_image(t, frame_no = 0):
+ if len(t.shape) == 4:
+ t = t[:, frame_no]
return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
def save_image(tensor_image, name, frame_no = -1):
@@ -257,16 +257,18 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi
image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
return image, new_height, new_width
-def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ):
+def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ):
if rm_background:
session = new_session()
output_list =[]
+ output_mask_list =[]
for i, img in enumerate(img_list):
width, height = img.size
+ resized_mask = None
if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
- if outpainting_dims is not None:
- resized_image =img
+ if outpainting_dims is not None and background_ref_outpainted:
+ resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True)
elif img.size != (budget_width, budget_height):
resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS)
else:
@@ -290,145 +292,103 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
# resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
- return output_list
+ output_mask_list.append(resized_mask)
+ return output_list, output_mask_list
+def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False):
+ from shared.utils.utils import save_image
+ inpaint_color = canvas_tf_bg / 127.5 - 1
-
-
-def str2bool(v):
- """
- Convert a string to a boolean.
-
- Supported true values: 'yes', 'true', 't', 'y', '1'
- Supported false values: 'no', 'false', 'f', 'n', '0'
-
- Args:
- v (str): String to convert.
-
- Returns:
- bool: Converted boolean value.
-
- Raises:
- argparse.ArgumentTypeError: If the value cannot be converted to boolean.
- """
- if isinstance(v, bool):
- return v
- v_lower = v.lower()
- if v_lower in ('yes', 'true', 't', 'y', '1'):
- return True
- elif v_lower in ('no', 'false', 'f', 'n', '0'):
- return False
+ ref_width, ref_height = ref_img.size
+ if (ref_height, ref_width) == image_size and outpainting_dims == None:
+ ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
+ canvas = torch.zeros_like(ref_img) if return_mask else None
else:
- raise argparse.ArgumentTypeError('Boolean value expected (True/False)')
-
-
-import sys, time
-
-# Global variables to track download progress
-_start_time = None
-_last_time = None
-_last_downloaded = 0
-_speed_history = []
-_update_interval = 0.5 # Update speed every 0.5 seconds
-
-def progress_hook(block_num, block_size, total_size, filename=None):
- """
- Simple progress bar hook for urlretrieve
-
- Args:
- block_num: Number of blocks downloaded so far
- block_size: Size of each block in bytes
- total_size: Total size of the file in bytes
- filename: Name of the file being downloaded (optional)
- """
- global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval
-
- current_time = time.time()
- downloaded = block_num * block_size
-
- # Initialize timing on first call
- if _start_time is None or block_num == 0:
- _start_time = current_time
- _last_time = current_time
- _last_downloaded = 0
- _speed_history = []
-
- # Calculate download speed only at specified intervals
- speed = 0
- if current_time - _last_time >= _update_interval:
- if _last_time > 0:
- current_speed = (downloaded - _last_downloaded) / (current_time - _last_time)
- _speed_history.append(current_speed)
- # Keep only last 5 speed measurements for smoothing
- if len(_speed_history) > 5:
- _speed_history.pop(0)
- # Average the recent speeds for smoother display
- speed = sum(_speed_history) / len(_speed_history)
-
- _last_time = current_time
- _last_downloaded = downloaded
- elif _speed_history:
- # Use the last calculated average speed
- speed = sum(_speed_history) / len(_speed_history)
- # Format file sizes and speed
- def format_bytes(bytes_val):
- for unit in ['B', 'KB', 'MB', 'GB']:
- if bytes_val < 1024:
- return f"{bytes_val:.1f}{unit}"
- bytes_val /= 1024
- return f"{bytes_val:.1f}TB"
-
- file_display = filename if filename else "Unknown file"
-
- if total_size <= 0:
- # If total size is unknown, show downloaded bytes
- speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
- line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}"
- # Clear any trailing characters by padding with spaces
- sys.stdout.write(line.ljust(80))
- sys.stdout.flush()
- return
-
- downloaded = block_num * block_size
- percent = min(100, (downloaded / total_size) * 100)
-
- # Create progress bar (40 characters wide to leave room for other info)
- bar_length = 40
- filled = int(bar_length * percent / 100)
- bar = '█' * filled + '░' * (bar_length - filled)
-
- # Format file sizes and speed
- def format_bytes(bytes_val):
- for unit in ['B', 'KB', 'MB', 'GB']:
- if bytes_val < 1024:
- return f"{bytes_val:.1f}{unit}"
- bytes_val /= 1024
- return f"{bytes_val:.1f}TB"
-
- speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else ""
-
- # Display progress with filename first
- line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}"
- # Clear any trailing characters by padding with spaces
- sys.stdout.write(line.ljust(100))
- sys.stdout.flush()
-
- # Print newline when complete
- if percent >= 100:
- print()
-
-# Wrapper function to include filename in progress hook
-def create_progress_hook(filename):
- """Creates a progress hook with the filename included"""
- global _start_time, _last_time, _last_downloaded, _speed_history
- # Reset timing variables for new download
- _start_time = None
- _last_time = None
- _last_downloaded = 0
- _speed_history = []
-
- def hook(block_num, block_size, total_size):
- return progress_hook(block_num, block_size, total_size, filename)
- return hook
+ if outpainting_dims != None:
+ final_height, final_width = image_size
+ canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
+ else:
+ canvas_height, canvas_width = image_size
+ if full_frame:
+ new_height = canvas_height
+ new_width = canvas_width
+ top = left = 0
+ else:
+ # if fill_max and (canvas_height - new_height) < 16:
+ # new_height = canvas_height
+ # if fill_max and (canvas_width - new_width) < 16:
+ # new_width = canvas_width
+ scale = min(canvas_height / ref_height, canvas_width / ref_width)
+ new_height = int(ref_height * scale)
+ new_width = int(ref_width * scale)
+ top = (canvas_height - new_height) // 2
+ left = (canvas_width - new_width) // 2
+ ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
+ ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
+ if outpainting_dims != None:
+ canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
+ canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img
+ else:
+ canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1]
+ canvas[:, :, top:top + new_height, left:left + new_width] = ref_img
+ ref_img = canvas
+ canvas = None
+ if return_mask:
+ if outpainting_dims != None:
+ canvas = torch.ones((1, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1]
+ canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0
+ else:
+ canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1]
+ canvas[:, :, top:top + new_height, left:left + new_width] = 0
+ canvas = canvas.to(device)
+ if return_image:
+ return convert_tensor_to_image(ref_img), canvas
+
+ return ref_img.to(device), canvas
+
+def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ):
+ src_videos, src_masks = [], []
+ inpaint_color = guide_inpaint_color/127.5 - 1
+ prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0
+ for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
+ src_video = src_mask = None
+ if cur_video_guide is not None:
+ src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w
+ if cur_video_mask is not None and any_mask:
+ src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w
+ if pre_video_guide is not None and not extract_guide_from_window_start:
+ src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
+ if any_mask:
+ src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1)
+ if src_video is None:
+ if any_guide_padding:
+ src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu")
+ if any_mask:
+ src_mask = torch.zeros_like(src_video[0:1])
+ elif src_video.shape[1] < current_video_length:
+ if any_guide_padding:
+ src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1)
+ if cur_video_mask is not None and any_mask:
+ src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
+ else:
+ new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
+ src_video = src_video[:, :new_num_frames]
+ if any_mask:
+ src_mask = src_mask[:, :new_num_frames]
+
+ for k, keep in enumerate(keep_video_guide_frames):
+ if not keep:
+ pos = prepend_count + k
+ src_video[:, pos:pos+1] = inpaint_color
+ src_mask[:, pos:pos+1] = 1
+
+ for k, frame in enumerate(inject_frames):
+ if frame != None:
+ pos = prepend_count + k
+ src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True)
+
+ src_videos.append(src_video)
+ src_masks.append(src_mask)
+ return src_videos, src_masks
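For illustration only (not part of the patch): a minimal sketch of how the new `prepare_video_guide_and_mask` helper added above might be called. The tensor sizes, dummy values and keyword choices are assumptions; guides and masks are expected as `(frames, height, width, channels)` uint8 tensors, and the helper returns guides in `[-1, 1]` and masks in `[0, 1]`, shaped `(channels, frames, height, width)`.

```python
import torch
from shared.utils.utils import prepare_video_guide_and_mask  # added by this patch

frames, height, width = 17, 256, 256
video_guide = torch.zeros((frames, height, width, 3), dtype=torch.uint8)     # dummy control video
video_mask = torch.full((frames, height, width, 1), 255, dtype=torch.uint8)  # dummy mask

src_videos, src_masks = prepare_video_guide_and_mask(
    [video_guide], [video_mask],
    pre_video_guide=None,            # no frames carried over from a previous sliding window
    image_size=(height, width),
    current_video_length=frames,
    latent_size=4,
    any_mask=True,
    any_guide_padding=True,          # pad the guide up to current_video_length if it is shorter
    guide_inpaint_color=127.5,
    keep_video_guide_frames=[True] * frames,
    inject_frames=[],
)
src_video, src_mask = src_videos[0], src_masks[0]
print(src_video.shape, src_mask.shape)   # (3, 17, 256, 256) and (1, 17, 256, 256)
```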
diff --git a/wgp.py b/wgp.py
index ee6b5f62e..0e6a0b20e 100644
--- a/wgp.py
+++ b/wgp.py
@@ -394,7 +394,7 @@ def ret():
gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long")
if "F" in video_prompt_type:
if len(frames_positions.strip()) > 0:
- positions = frames_positions.split(" ")
+ positions = frames_positions.replace(","," ").split(" ")
for pos_str in positions:
if not pos_str in ["L", "l"] and len(pos_str)>0:
if not is_integer(pos_str):
@@ -2528,7 +2528,7 @@ def computeList(filename):
from urllib.request import urlretrieve
- from shared.utils.utils import create_progress_hook
+ from shared.utils.download import create_progress_hook
shared_def = {
"repoId" : "DeepBeepMeep/Wan2.1",
@@ -3726,6 +3726,60 @@ def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canva
input_mask = convert_tensor_to_image(full_frame)
return input_image, input_mask
+def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512):
+ if not input_video_path or max_frames <= 0:
+ return None
+ pad_frames = 0
+ if start_frame < 0:
+ pad_frames= -start_frame
+ max_frames += start_frame
+ start_frame = 0
+
+ any_mask = input_mask_path != None
+ video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps)
+ if len(video) == 0: return None
+ if any_mask:
+ mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps)
+ frame_height, frame_width, _ = video[0].shape
+
+ num_frames = min(len(video), len(mask_video)) if any_mask else len(video)
+ if num_frames == 0: return None
+ video = video[:num_frames]
+ if any_mask: mask_video = mask_video[:num_frames]
+
+ from preprocessing.face_preprocessor import FaceProcessor
+ face_processor = FaceProcessor()
+
+ face_list = []
+ for frame_idx in range(num_frames):
+ frame = video[frame_idx].cpu().numpy()
+ # video[frame_idx] = None
+ if any_mask:
+ mask = Image.fromarray(mask_video[frame_idx].cpu().numpy())
+ # mask_video[frame_idx] = None
+ if (frame_width, frame_height) != mask.size:
+ mask = mask.resize((frame_width, frame_height), resample=Image.Resampling.LANCZOS)
+ mask = np.array(mask)
+ alpha_mask = np.zeros((frame_height, frame_width, 3), dtype=np.uint8)
+ alpha_mask[mask > 127] = 1
+ frame = frame * alpha_mask
+ frame = Image.fromarray(frame)
+ face = face_processor.process(frame, resize_to=size, face_crop_scale = 1)
+ face_list.append(face)
+
+ face_processor = None
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ face_tensor= torch.tensor(np.stack(face_list, dtype= np.float32) / 127.5 - 1).permute(-1, 0, 1, 2 ) # t h w c -> c t h w
+ if pad_frames > 0:
+ face_tensor= torch.cat([face_tensor[:, -1:].expand(-1, pad_frames, -1, -1), face_tensor ], dim=1)
+
+ if args.save_masks:
+ from preprocessing.dwpose.pose import save_one_video
+ saved_faces_frames = [np.array(face) for face in face_list ]
+ save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None)
+ return face_tensor
+
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
@@ -3742,7 +3796,13 @@ def mask_to_xyxy_box(mask):
box = [xmin, ymin, xmax, ymax]
box = [int(x) for x in box]
return box
-
+ inpaint_color = int(inpaint_color)
+ pad_frames = 0
+ if start_frame < 0:
+ pad_frames= -start_frame
+ max_frames += start_frame
+ start_frame = 0
+
if not input_video_path or max_frames <= 0:
return None, None
any_mask = input_mask_path != None
@@ -3909,6 +3969,9 @@ def prep_prephase(frame_idx):
preproc_outside = None
gc.collect()
torch.cuda.empty_cache()
+ if pad_frames > 0:
+ masked_frames = [masked_frames[0]] * pad_frames + masked_frames
+ if any_mask: masks = [masks[0]] * pad_frames + masks
return torch.stack(masked_frames), torch.stack(masks) if any_mask else None
@@ -4646,7 +4709,8 @@ def remove_temp_filenames(temp_filenames_list):
current_video_length = video_length
# VAE Tiling
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
-
+ guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5)
+ extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False)
i2v = test_class_i2v(model_type)
diffusion_forcing = "diffusion_forcing" in model_filename
t2v = base_model_type in ["t2v"]
@@ -4662,6 +4726,7 @@ def remove_temp_filenames(temp_filenames_list):
multitalk = model_def.get("multitalk_class", False)
standin = model_def.get("standin_class", False)
infinitetalk = base_model_type in ["infinitetalk"]
+ animate = base_model_type in ["animate"]
if "B" in audio_prompt_type or "X" in audio_prompt_type:
from models.wan.multitalk.multitalk import parse_speakers_locations
@@ -4822,7 +4887,6 @@ def remove_temp_filenames(temp_filenames_list):
repeat_no = 0
extra_generation = 0
initial_total_windows = 0
-
discard_last_frames = sliding_window_discard_last_frames
default_requested_frames_to_generate = current_video_length
if sliding_window:
@@ -4843,7 +4907,7 @@ def remove_temp_filenames(temp_filenames_list):
if repeat_no >= total_generation: break
repeat_no +=1
gen["repeat_no"] = repeat_no
- src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = None
+ src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None
prefix_video = pre_video_frame = None
source_video_overlap_frames_count = 0 # number of frames overlapped in source video for first window
source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before)
@@ -4899,7 +4963,7 @@ def remove_temp_filenames(temp_filenames_list):
return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) )
refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {}
- src_ref_images = image_refs
+ src_ref_images, src_ref_masks = image_refs, None
image_start_tensor = image_end_tensor = None
if window_no == 1 and (video_source is not None or image_start is not None):
if image_start is not None:
@@ -4943,16 +5007,52 @@ def remove_temp_filenames(temp_filenames_list):
from models.wan.multitalk.multitalk import get_window_audio_embeddings
# special treatment of the start frame position when alignment to the first frame is requested, as otherwise the start frame number would be negative due to overlapped frames (previously compensated later with padding)
audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length)
- if vace:
- video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
+
+ if repeat_no == 1 and window_no == 1 and image_refs is not None and len(image_refs) > 0:
+ frames_positions_list = []
+ if frames_positions is not None and len(frames_positions)> 0:
+ positions = frames_positions.replace(","," ").split(" ")
+ cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count)
+ last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count
+ joker_used = False
+ project_window_no = 1
+ for pos in positions :
+ if len(pos) > 0:
+ if pos in ["L", "l"]:
+ cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
+ if cur_end_pos >= last_frame_no and not joker_used:
+ joker_used = True
+ cur_end_pos = last_frame_no -1
+ project_window_no += 1
+ frames_positions_list.append(cur_end_pos)
+ cur_end_pos -= sliding_window_discard_last_frames + reuse_frames
+ else:
+ frames_positions_list.append(int(pos)-1 + alignment_shift)
+ frames_positions_list = frames_positions_list[:len(image_refs)]
+ nb_frames_positions = len(frames_positions_list)
+ if nb_frames_positions > 0:
+ frames_to_inject = [None] * (max(frames_positions_list) + 1)
+ for i, pos in enumerate(frames_positions_list):
+ frames_to_inject[pos] = image_refs[i]
+
if video_guide is not None:
- keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
+ keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
- keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ]
-
- if vace:
+ guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame
+ keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else []
+ keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ]
+ guide_frames_extract_count = len(keep_frames_parsed)
+
+ if "B" in video_prompt_type:
+ send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")])
+ src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps)
+ if src_faces is not None and src_faces.shape[1] < current_video_length:
+ src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1)
+
+ if vace or animate:
+ video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
context_scale = [ control_net_weight]
if "V" in video_prompt_type:
process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
@@ -4971,10 +5071,10 @@ def remove_temp_filenames(temp_filenames_list):
if preprocess_type2 is not None:
context_scale = [ control_net_weight /2, control_net_weight2 /2]
send_cmd("progress", [0, get_latest_status(state, status_info)])
- inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask=="inpaint" else 127
- video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
+ inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color
+ video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
if preprocess_type2 != None:
- video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
+ video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
if video_guide_processed != None:
if sample_fit_canvas != None:
@@ -4985,7 +5085,37 @@ def remove_temp_filenames(temp_filenames_list):
refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())]
if video_mask_processed != None:
refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy())
- elif ltxv:
+
+ frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
+
+ if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)):
+ any_mask = True
+ any_guide_padding = model_def.get("pad_guide_video", False)
+ from shared.utils.utils import prepare_video_guide_and_mask
+ src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2],
+ [video_mask_processed, video_mask_processed2],
+ pre_video_guide, image_size, current_video_length, latent_size,
+ any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start,
+ keep_frames_parsed, frames_to_inject_parsed , outpainting_dims)
+
+ src_video, src_video2 = src_videos
+ src_mask, src_mask2 = src_masks
+ if src_video is None:
+ abort = True
+ break
+ if src_faces is not None:
+ if src_faces.shape[1] < src_video.shape[1]:
+ src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1)
+ else:
+ src_faces = src_faces[:, :src_video.shape[1]]
+ if args.save_masks:
+ save_video( src_video, "masked_frames.mp4", fps)
+ if src_video2 is not None:
+ save_video( src_video2, "masked_frames2.mp4", fps)
+ if any_mask:
+ save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
+
+ elif ltxv:
preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
status_info = "Extracting " + processes_names[preprocess_type]
send_cmd("progress", [0, get_latest_status(state, status_info)])
@@ -5023,7 +5153,7 @@ def remove_temp_filenames(temp_filenames_list):
sample_fit_canvas = None
else: # video to video
- video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps)
+ video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size)
if video_guide_processed is None:
src_video = pre_video_guide
else:
@@ -5043,29 +5173,6 @@ def remove_temp_filenames(temp_filenames_list):
refresh_preview["image_mask"] = new_image_mask
if window_no == 1 and image_refs is not None and len(image_refs) > 0:
- if repeat_no == 1:
- frames_positions_list = []
- if frames_positions is not None and len(frames_positions)> 0:
- positions = frames_positions.split(" ")
- cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) #if reset_control_aligment else 0
- last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count
- joker_used = False
- project_window_no = 1
- for pos in positions :
- if len(pos) > 0:
- if pos in ["L", "l"]:
- cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
- if cur_end_pos >= last_frame_no and not joker_used:
- joker_used = True
- cur_end_pos = last_frame_no -1
- project_window_no += 1
- frames_positions_list.append(cur_end_pos)
- cur_end_pos -= sliding_window_discard_last_frames + reuse_frames
- else:
- frames_positions_list.append(int(pos)-1 + alignment_shift)
- frames_positions_list = frames_positions_list[:len(image_refs)]
- nb_frames_positions = len(frames_positions_list)
-
if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) :
from shared.utils.utils import get_outpainting_full_area_dimensions
w, h = image_refs[0].size
@@ -5089,20 +5196,16 @@ def remove_temp_filenames(temp_filenames_list):
if remove_background_images_ref > 0:
send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
# keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested
- image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
+ image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
remove_background_images_ref > 0, any_background_ref,
fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1,
block_size=block_size,
- outpainting_dims =outpainting_dims )
+ outpainting_dims =outpainting_dims,
+ background_ref_outpainted = model_def.get("background_ref_outpainted", True) )
refresh_preview["image_refs"] = image_refs
- if nb_frames_positions > 0:
- frames_to_inject = [None] * (max(frames_positions_list) + 1)
- for i, pos in enumerate(frames_positions_list):
- frames_to_inject[pos] = image_refs[i]
if vace :
- frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_end_frame]
image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2],
@@ -5116,7 +5219,7 @@ def remove_temp_filenames(temp_filenames_list):
any_background_ref = any_background_ref
)
if len(frames_to_inject_parsed) or any_background_ref:
- new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + aligned_guide_start_frame - aligned_window_start_frame) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
+ new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame)) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
if any_background_ref:
new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:]
else:
@@ -5165,10 +5268,14 @@ def set_header_text(txt):
input_prompt = prompt,
image_start = image_start_tensor,
image_end = image_end_tensor,
- input_frames = src_video,
+ input_frames = src_video,
+ input_frames2 = src_video2,
input_ref_images= src_ref_images,
+ input_ref_masks = src_ref_masks,
input_masks = src_mask,
+ input_masks2 = src_mask2,
input_video= pre_video_guide,
+ input_faces = src_faces,
denoising_strength=denoising_strength,
prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames,
frame_num= (current_video_length // latent_size)* latent_size + 1,
@@ -5302,6 +5409,7 @@ def set_header_text(txt):
send_cmd("output")
else:
sample = samples.cpu()
+ abort = not is_image and sample.shape[1] < current_video_length
# if True: # for testing
# torch.save(sample, "output.pt")
# else:
@@ -6980,7 +7088,7 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t
def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
old_video_prompt_type = video_prompt_type
- video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUV")
+ video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUVB")
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
visible = "V" in video_prompt_type
model_type = state["model_type"]
@@ -7437,7 +7545,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False)
image_prompt_type_choices = []
if "T" in image_prompt_types_allowed:
- image_prompt_type_choices += [("Text Prompt Only", "")]
+ image_prompt_type_choices += [("Text Prompt Only" if "S" in image_prompt_types_allowed else "New Video", "")]
if "S" in image_prompt_types_allowed:
image_prompt_type_choices += [("Start Video with Image", "S")]
any_start_image = True
@@ -7516,7 +7624,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image")
video_prompt_type_video_guide = gr.Dropdown(
guide_preprocessing_choices,
- value=filter_letters(video_prompt_type_value, "PDESLCMUV", guide_preprocessing.get("default", "") ),
+ value=filter_letters(video_prompt_type_value, "PDESLCMUVB", guide_preprocessing.get("default", "") ),
label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True,
)
any_control_video = True
@@ -7560,13 +7668,13 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
}
mask_preprocessing_choices = []
- mask_preprocessing_labels = guide_preprocessing.get("labels", {})
+ mask_preprocessing_labels = mask_preprocessing.get("labels", {})
for process_type in mask_preprocessing["selection"]:
process_label = mask_preprocessing_labels.get(process_type, None)
process_label = mask_preprocessing_labels_all.get(process_type, process_type) if process_label is None else process_label
mask_preprocessing_choices.append( (process_label, process_type) )
- video_prompt_type_video_mask_label = guide_preprocessing.get("label", "Area Processed")
+ video_prompt_type_video_mask_label = mask_preprocessing.get("label", "Area Processed")
video_prompt_type_video_mask = gr.Dropdown(
mask_preprocessing_choices,
value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")),
@@ -7591,7 +7699,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
choices= image_ref_choices["choices"],
value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]),
visible = image_ref_choices.get("visible", True),
- label=image_ref_choices.get("label", "Ref. Images Type"), show_label= True, scale = 2
+ label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 2
)
image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None))
@@ -7634,7 +7742,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#")
video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False)
with gr.Group():
- video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
+ video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Positioned Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") )
with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row:
video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value
video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")]
@@ -7649,14 +7757,14 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value )
image_refs_single_image_mode = model_def.get("one_image_ref_needed", False)
- image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "")
+ image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will be associated to a Sliding Window)" if infinitetalk else "")
image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode)
frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames (1=first, L=last of a window; no position needed for other Image Refs)" )
image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internally Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs)
no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None
- background_removal_label = model_def.get("background_removal_label", "Remove Backgrounds behind People / Objects")
+ background_removal_label = model_def.get("background_removal_label", "Remove Background behind People / Objects")
remove_background_images_ref = gr.Dropdown(
choices=[
@@ -7664,7 +7772,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
(background_removal_label, 1),
],
value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1),
- label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal
+ label="Automatic Removal of Background behind People or Objects in Reference Images", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal
)
any_audio_voices_support = any_audio_track(base_model_type)
@@ -8084,7 +8192,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
("Aligned to the beginning of the First Window of the new Video Sample", "T"),
],
value=filter_letters(video_prompt_type_value, "T"),
- label="Control Video / Injected Frames / Control Audio temporal alignment when any Video to continue",
+ label="Control Video / Control Audio / Positioned Frames Temporal Alignment when any Video to continue",
visible = vace or ltxv or t2v or infinitetalk
)
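For illustration only (not part of the patch): the `frames_positions` parsing introduced in this commit accepts 1-based integers and the letter "L", meaning "last frame of the next sliding window" (clamped once to the final frame). A minimal standalone sketch of that loop, assuming no source video to continue and no alignment shift:

```python
def parse_frames_positions(frames_positions, video_length, window_size,
                           total_frames, reuse_frames=0, discard_last=0):
    # positions are separated by spaces or commas; "L" advances to the end of the next window
    positions = frames_positions.replace(",", " ").split(" ")
    result, cur_end_pos, window_no, joker_used = [], -1, 1, False
    for pos in positions:
        if not pos:
            continue
        if pos in ("L", "l"):
            cur_end_pos += window_size if window_no > 1 else video_length
            if cur_end_pos >= total_frames and not joker_used:
                joker_used, cur_end_pos = True, total_frames - 1
            window_no += 1
            result.append(cur_end_pos)
            cur_end_pos -= discard_last + reuse_frames
        else:
            result.append(int(pos) - 1)      # 1-based -> 0-based
    return result

# "1, 20 L" -> frames 0 and 19, plus the last frame of the first window
print(parse_frames_positions("1, 20 L", video_length=81, window_size=81, total_frames=81))  # [0, 19, 80]
```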
From e28c95ae912addee1abdb1b798bd0f554cd658cb Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 23 Sep 2025 23:04:44 +0200
Subject: [PATCH 030/155] Vace Contenders are in Town
---
README.md | 8 +
defaults/lucy_edit.json | 3 +-
defaults/lucy_edit_fastwan.json | 16 ++
models/flux/flux_handler.py | 2 +-
models/flux/flux_main.py | 14 +-
models/hyvideo/hunyuan.py | 14 +-
models/ltx_video/ltxv_handler.py | 2 +
models/qwen/qwen_handler.py | 2 +-
models/qwen/qwen_main.py | 12 +-
models/wan/animate/animate_utils.py | 143 ++++++++++
models/wan/animate/face_blocks.py | 382 +++++++++++++++++++++++++
models/wan/animate/model_animate.py | 31 ++
models/wan/animate/motion_encoder.py | 308 ++++++++++++++++++++
models/wan/any2video.py | 153 +++-------
models/wan/df_handler.py | 1 +
models/wan/wan_handler.py | 8 +-
shared/utils/utils.py | 102 ++++---
wgp.py | 406 ++++++++++-----------------
18 files changed, 1160 insertions(+), 447 deletions(-)
create mode 100644 defaults/lucy_edit_fastwan.json
create mode 100644 models/wan/animate/animate_utils.py
create mode 100644 models/wan/animate/face_blocks.py
create mode 100644 models/wan/animate/model_animate.py
create mode 100644 models/wan/animate/motion_encoder.py
diff --git a/README.md b/README.md
index 8c5e4367a..6c338ac7c 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,14 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
+### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena!
+
+So in today's release you will find two Vace wannabes. Each covers only a subset of Vace's features, but each offers some interesting advantages:
+- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion* transfer, and it does that very well. You can use it either to *Replace* a person in an existing Video or to *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. Under the hood, *Wan 2.2 Animate* appears to be derived from an i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusivity, you will find support for *Outpainting*.
+
+- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video, ask it to change it (it is specialized in clothes changing) and voila! The nice thing is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.
+
+
### September 15 2025: WanGP v8.6 - Attack of the Clones
- The long awaited **Vace for Wan 2.2** is at last here or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fan Cocktail**)
diff --git a/defaults/lucy_edit.json b/defaults/lucy_edit.json
index 6344dfffd..a8f67ad64 100644
--- a/defaults/lucy_edit.json
+++ b/defaults/lucy_edit.json
@@ -2,7 +2,7 @@
"model": {
"name": "Wan2.2 Lucy Edit 5B",
"architecture": "lucy_edit",
- "description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
+ "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
@@ -10,6 +10,7 @@
],
"group": "wan2_2"
},
+ "prompt": "change the clothes to red",
"video_length": 81,
"guidance_scale": 5,
"flow_shift": 5,
diff --git a/defaults/lucy_edit_fastwan.json b/defaults/lucy_edit_fastwan.json
new file mode 100644
index 000000000..c67c795d7
--- /dev/null
+++ b/defaults/lucy_edit_fastwan.json
@@ -0,0 +1,16 @@
+{
+ "model": {
+ "name": "Wan2.2 FastWan Lucy Edit 5B",
+ "architecture": "lucy_edit",
+ "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.",
+ "URLs": "lucy_edit",
+ "group": "wan2_2",
+ "loras": "ti2v_2_2"
+ },
+ "prompt": "change the clothes to red",
+ "video_length": 81,
+ "guidance_scale": 1,
+ "flow_shift": 3,
+ "num_inference_steps": 5,
+ "resolution": "1280x720"
+}
\ No newline at end of file
diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py
index d16888122..83de7c3fc 100644
--- a/models/flux/flux_handler.py
+++ b/models/flux/flux_handler.py
@@ -56,7 +56,7 @@ def query_model_def(base_model_type, model_def):
}
- extra_model_def["lock_image_refs_ratios"] = True
+ extra_model_def["fit_into_canvas_image_refs"] = 0
return extra_model_def
diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py
index 686371129..6746a2375 100644
--- a/models/flux/flux_main.py
+++ b/models/flux/flux_main.py
@@ -142,8 +142,8 @@ def generate(
n_prompt: str = None,
sampling_steps: int = 20,
input_ref_images = None,
- image_guide= None,
- image_mask= None,
+ input_frames= None,
+ input_masks= None,
width= 832,
height=480,
embedded_guidance_scale: float = 2.5,
@@ -197,10 +197,12 @@ def generate(
for new_img in input_ref_images[1:]:
stiched = stitch_images(stiched, new_img)
input_ref_images = [stiched]
- elif image_guide is not None:
- input_ref_images = [image_guide]
+ elif input_frames is not None:
+ input_ref_images = [convert_tensor_to_image(input_frames) ]
else:
input_ref_images = None
+ image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
+
if self.name in ['flux-dev-uso', 'flux-dev-umo'] :
inp, height, width = prepare_multi_ip(
@@ -253,8 +255,8 @@ def unpack_latent(x):
if image_mask is not None:
from shared.utils.utils import convert_image_to_tensor
img_msk_rebuilt = inp["img_msk_rebuilt"]
- img= convert_image_to_tensor(image_guide)
- x = img.squeeze(2) * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt
+ img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide)
+ x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt
x = x.clamp(-1, 1)
x = x.transpose(0, 1)
diff --git a/models/hyvideo/hunyuan.py b/models/hyvideo/hunyuan.py
index 181a9a774..aa6c3b3e6 100644
--- a/models/hyvideo/hunyuan.py
+++ b/models/hyvideo/hunyuan.py
@@ -865,7 +865,7 @@ def generate(
freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
else:
if input_frames != None:
- target_height, target_width = input_frames.shape[-3:-1]
+ target_height, target_width = input_frames.shape[-2:]
elif input_video != None:
target_height, target_width = input_video.shape[-2:]
@@ -894,9 +894,10 @@ def generate(
pixel_value_bg = input_video.unsqueeze(0)
pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0)
if input_frames != None:
- pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float()
- pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
- pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
+ pixel_value_video_bg = input_frames.unsqueeze(0) #.permute(-1,0,1,2).unsqueeze(0).float()
+ # pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.)
+ # pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float()
+ pixel_value_video_mask = input_masks.repeat(3,1,1,1).unsqueeze(0)
if input_video != None:
pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2)
pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2)
@@ -908,10 +909,11 @@ def generate(
if pixel_value_bg.shape[2] < frame_num:
padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:])
pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2)
- pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
+ # pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
+ pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 1, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2)
bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample()
- pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.)
+ pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.) # unmasked pixels are -1 (not 0 as usual) and masked pixels are 1
mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample()
bg_latents = torch.cat([bg_latents, mask_latents], dim=1)
bg_latents.mul_(self.vae.config.scaling_factor)
diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py
index 2845fdb0b..8c322e153 100644
--- a/models/ltx_video/ltxv_handler.py
+++ b/models/ltx_video/ltxv_handler.py
@@ -35,6 +35,8 @@ def query_model_def(base_model_type, model_def):
"selection": ["", "A", "NA", "XA", "XNA"],
}
+ extra_model_def["extra_control_frames"] = 1
+ extra_model_def["dont_cat_preguide"]= True
return extra_model_def
@staticmethod
diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py
index 010298ead..cc6a76484 100644
--- a/models/qwen/qwen_handler.py
+++ b/models/qwen/qwen_handler.py
@@ -17,7 +17,7 @@ def query_model_def(base_model_type, model_def):
("Default", "default"),
("Lightning", "lightning")],
"guidance_max_phases" : 1,
- "lock_image_refs_ratios": True,
+ "fit_into_canvas_image_refs": 0,
}
if base_model_type in ["qwen_image_edit_20B"]:
diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py
index cec0b8e05..abd5c5cfe 100644
--- a/models/qwen/qwen_main.py
+++ b/models/qwen/qwen_main.py
@@ -17,7 +17,7 @@
from diffusers import FlowMatchEulerDiscreteScheduler
from .pipeline_qwenimage import QwenImagePipeline
from PIL import Image
-from shared.utils.utils import calculate_new_dimensions
+from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image
def stitch_images(img1, img2):
# Resize img2 to match img1's height
@@ -103,8 +103,8 @@ def generate(
n_prompt = None,
sampling_steps: int = 20,
input_ref_images = None,
- image_guide= None,
- image_mask= None,
+ input_frames= None,
+ input_masks= None,
width= 832,
height=480,
guide_scale: float = 4,
@@ -179,8 +179,10 @@ def generate(
if n_prompt is None or len(n_prompt) == 0:
n_prompt= "text, watermark, copyright, blurry, low resolution"
- if image_guide is not None:
- input_ref_images = [image_guide]
+
+ image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
+ if input_frames is not None:
+ input_ref_images = [convert_tensor_to_image(input_frames) ]
elif input_ref_images is not None:
# image stiching method
stiched = input_ref_images[0]
diff --git a/models/wan/animate/animate_utils.py b/models/wan/animate/animate_utils.py
new file mode 100644
index 000000000..9474dce3e
--- /dev/null
+++ b/models/wan/animate/animate_utils.py
@@ -0,0 +1,143 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+import numbers
+from peft import LoraConfig
+
+
+def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
+ target_modules = []
+ for name, module in transformer.named_modules():
+ if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear):
+ target_modules.append(name)
+
+ transformer_lora_config = LoraConfig(
+ r=rank,
+ lora_alpha=alpha,
+ init_lora_weights=init_lora_weights,
+ target_modules=target_modules,
+ )
+ return transformer_lora_config
+
+
+
+class TensorList(object):
+
+ def __init__(self, tensors):
+ """
+ tensors: a list of torch.Tensor objects. No need to have uniform shape.
+ """
+ assert isinstance(tensors, (list, tuple))
+ assert all(isinstance(u, torch.Tensor) for u in tensors)
+ assert len(set([u.ndim for u in tensors])) == 1
+ assert len(set([u.dtype for u in tensors])) == 1
+ assert len(set([u.device for u in tensors])) == 1
+ self.tensors = tensors
+
+ def to(self, *args, **kwargs):
+ return TensorList([u.to(*args, **kwargs) for u in self.tensors])
+
+ def size(self, dim):
+ assert dim == 0, 'only support get the 0th size'
+ return len(self.tensors)
+
+ def pow(self, *args, **kwargs):
+ return TensorList([u.pow(*args, **kwargs) for u in self.tensors])
+
+ def squeeze(self, dim):
+ assert dim != 0
+ if dim > 0:
+ dim -= 1
+ return TensorList([u.squeeze(dim) for u in self.tensors])
+
+ def type(self, *args, **kwargs):
+ return TensorList([u.type(*args, **kwargs) for u in self.tensors])
+
+ def type_as(self, other):
+ assert isinstance(other, (torch.Tensor, TensorList))
+ if isinstance(other, torch.Tensor):
+ return TensorList([u.type_as(other) for u in self.tensors])
+ else:
+ return TensorList([u.type(other.dtype) for u in self.tensors])
+
+ @property
+ def dtype(self):
+ return self.tensors[0].dtype
+
+ @property
+ def device(self):
+ return self.tensors[0].device
+
+ @property
+ def ndim(self):
+ return 1 + self.tensors[0].ndim
+
+ def __getitem__(self, index):
+ return self.tensors[index]
+
+ def __len__(self):
+ return len(self.tensors)
+
+ def __add__(self, other):
+ return self._apply(other, lambda u, v: u + v)
+
+ def __radd__(self, other):
+ return self._apply(other, lambda u, v: v + u)
+
+ def __sub__(self, other):
+ return self._apply(other, lambda u, v: u - v)
+
+ def __rsub__(self, other):
+ return self._apply(other, lambda u, v: v - u)
+
+ def __mul__(self, other):
+ return self._apply(other, lambda u, v: u * v)
+
+ def __rmul__(self, other):
+ return self._apply(other, lambda u, v: v * u)
+
+ def __floordiv__(self, other):
+ return self._apply(other, lambda u, v: u // v)
+
+ def __truediv__(self, other):
+ return self._apply(other, lambda u, v: u / v)
+
+ def __rfloordiv__(self, other):
+ return self._apply(other, lambda u, v: v // u)
+
+ def __rtruediv__(self, other):
+ return self._apply(other, lambda u, v: v / u)
+
+ def __pow__(self, other):
+ return self._apply(other, lambda u, v: u ** v)
+
+ def __rpow__(self, other):
+ return self._apply(other, lambda u, v: v ** u)
+
+ def __neg__(self):
+ return TensorList([-u for u in self.tensors])
+
+ def __iter__(self):
+ for tensor in self.tensors:
+ yield tensor
+
+ def __repr__(self):
+ return 'TensorList: \n' + repr(self.tensors)
+
+ def _apply(self, other, op):
+ if isinstance(other, (list, tuple, TensorList)) or (
+ isinstance(other, torch.Tensor) and (
+ other.numel() > 1 or other.ndim > 1
+ )
+ ):
+ assert len(other) == len(self.tensors)
+ return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
+ elif isinstance(other, numbers.Number) or (
+ isinstance(other, torch.Tensor) and (
+ other.numel() == 1 and other.ndim <= 1
+ )
+ ):
+ return TensorList([op(u, other) for u in self.tensors])
+ else:
+ raise TypeError(
+ f'unsupported operand for *: "TensorList" and "{type(other)}"'
+ )
\ No newline at end of file
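For illustration only (not part of the patch): `TensorList` wraps a list of tensors that share dtype, device and ndim but not necessarily spatial shape, and broadcasts arithmetic over the list. A small sketch (importing the module also requires `peft`, since it imports `LoraConfig` at module level):

```python
import torch
from models.wan.animate.animate_utils import TensorList

latents = TensorList([torch.randn(4, 8, 16, 16), torch.randn(4, 8, 20, 12)])
scaled = latents * 0.5 + 1.0      # scalar ops are applied per tensor
summed = latents + latents        # TensorList-to-TensorList ops zip the two lists
print(len(scaled), scaled.dtype, scaled.device, summed[1].shape)
```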
diff --git a/models/wan/animate/face_blocks.py b/models/wan/animate/face_blocks.py
new file mode 100644
index 000000000..8ddb829fa
--- /dev/null
+++ b/models/wan/animate/face_blocks.py
@@ -0,0 +1,382 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+from torch import nn
+import torch
+from typing import Tuple, Optional
+from einops import rearrange
+import torch.nn.functional as F
+import math
+from shared.attention import pay_attention
+
+MEMORY_LAYOUT = {
+ "flash": (
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
+ lambda x: x,
+ ),
+ "torch": (
+ lambda x: x.transpose(1, 2),
+ lambda x: x.transpose(1, 2),
+ ),
+ "vanilla": (
+ lambda x: x.transpose(1, 2),
+ lambda x: x.transpose(1, 2),
+ ),
+}
+
+
+def attention(
+ q,
+ k,
+ v,
+ mode="torch",
+ drop_rate=0,
+ attn_mask=None,
+ causal=False,
+ max_seqlen_q=None,
+ batch_size=1,
+):
+ """
+ Perform QKV self attention.
+
+ Args:
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
+ k (torch.Tensor): Key tensor with shape [b, s1, a, d]
+ v (torch.Tensor): Value tensor with shape [b, s1, a, d]
+ mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
+ drop_rate (float): Dropout rate in attention map. (default: 0)
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
+ (default: None)
+ causal (bool): Whether to use causal attention. (default: False)
+ max_seqlen_q (int): The maximum sequence length in the batch of q; used together with
+ batch_size to reshape the output of the 'flash' path.
+ batch_size (int): Batch size used to reshape the output of the 'flash' path. (default: 1)
+
+ Returns:
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
+ """
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
+
+ if mode == "torch":
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
+ attn_mask = attn_mask.to(q.dtype)
+ x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
+
+ elif mode == "flash":
+ x = flash_attn_func(
+ q,
+ k,
+ v,
+ )
+ x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
+ elif mode == "vanilla":
+ scale_factor = 1 / math.sqrt(q.size(-1))
+
+ b, a, s, _ = q.shape
+ s1 = k.size(2)
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
+ if causal:
+ # Only applied to self attention
+ assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+ attn_bias.to(q.dtype)
+
+ if attn_mask is not None:
+ if attn_mask.dtype == torch.bool:
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+ else:
+ attn_bias += attn_mask
+
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
+ attn += attn_bias
+ attn = attn.softmax(dim=-1)
+ attn = torch.dropout(attn, p=drop_rate, train=True)
+ x = attn @ v
+ else:
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
+
+ x = post_attn_layout(x)
+ b, s, a, d = x.shape
+ out = x.reshape(b, s, -1)
+ return out
+
+
+class CausalConv1d(nn.Module):
+
+ def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
+ super().__init__()
+
+ self.pad_mode = pad_mode
+ padding = (kernel_size - 1, 0) # T
+ self.time_causal_padding = padding
+
+ self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def forward(self, x):
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
+ return self.conv(x)
+
+
+
+class FaceEncoder(nn.Module):
+ def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+
+ self.num_heads = num_heads
+ self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
+ self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+ self.act = nn.SiLU()
+ self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
+ self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
+
+ self.out_proj = nn.Linear(1024, hidden_dim)
+ self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
+
+ def forward(self, x):
+
+ x = rearrange(x, "b t c -> b c t")
+ b, c, t = x.shape
+
+ x = self.conv1_local(x)
+ x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
+
+ x = self.norm1(x)
+ x = self.act(x)
+ x = rearrange(x, "b t c -> b c t")
+ x = self.conv2(x)
+ x = rearrange(x, "b c t -> b t c")
+ x = self.norm2(x)
+ x = self.act(x)
+ x = rearrange(x, "b t c -> b c t")
+ x = self.conv3(x)
+ x = rearrange(x, "b c t -> b t c")
+ x = self.norm3(x)
+ x = self.act(x)
+ x = self.out_proj(x)
+ x = rearrange(x, "(b n) t c -> b t n c", b=b)
+ padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
+ x = torch.cat([x, padding], dim=-2)
+ x_local = x.clone()
+
+ return x_local
+
+
+
+class RMSNorm(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ elementwise_affine=True,
+ eps: float = 1e-6,
+ device=None,
+ dtype=None,
+ ):
+ """
+ Initialize the RMSNorm normalization layer.
+
+ Args:
+ dim (int): The dimension of the input tensor.
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+ Attributes:
+ eps (float): A small value added to the denominator for numerical stability.
+ weight (nn.Parameter): Learnable scaling parameter.
+
+ """
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ self.eps = eps
+ if elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
+
+ def _norm(self, x):
+ """
+ Apply the RMSNorm normalization to the input tensor.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The normalized tensor.
+
+ """
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+ def forward(self, x):
+ """
+ Forward pass through the RMSNorm layer.
+
+ Args:
+ x (torch.Tensor): The input tensor.
+
+ Returns:
+ torch.Tensor: The output tensor after applying RMSNorm.
+
+ """
+ output = self._norm(x.float()).type_as(x)
+ if hasattr(self, "weight"):
+ output = output * self.weight
+ return output
+
+
+def get_norm_layer(norm_layer):
+ """
+ Get the normalization layer.
+
+ Args:
+ norm_layer (str): The type of normalization layer.
+
+ Returns:
+ norm_layer (nn.Module): The normalization layer.
+ """
+ if norm_layer == "layer":
+ return nn.LayerNorm
+ elif norm_layer == "rms":
+ return RMSNorm
+ else:
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
+
+
+class FaceAdapter(nn.Module):
+ def __init__(
+ self,
+ hidden_dim: int,
+ heads_num: int,
+ qk_norm: bool = True,
+ qk_norm_type: str = "rms",
+ num_adapter_layers: int = 1,
+ dtype=None,
+ device=None,
+ ):
+
+ factory_kwargs = {"dtype": dtype, "device": device}
+ super().__init__()
+ self.hidden_size = hidden_dim
+ self.heads_num = heads_num
+ self.fuser_blocks = nn.ModuleList(
+ [
+ FaceBlock(
+ self.hidden_size,
+ self.heads_num,
+ qk_norm=qk_norm,
+ qk_norm_type=qk_norm_type,
+ **factory_kwargs,
+ )
+ for _ in range(num_adapter_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ motion_embed: torch.Tensor,
+ idx: int,
+ freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
+ freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
+ ) -> torch.Tensor:
+
+ return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
+
+
+
+class FaceBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int,
+ heads_num: int,
+ qk_norm: bool = True,
+ qk_norm_type: str = "rms",
+ qk_scale: float = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+
+ self.deterministic = False
+ self.hidden_size = hidden_size
+ self.heads_num = heads_num
+ head_dim = hidden_size // heads_num
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
+ self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
+
+ self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
+
+ qk_norm_layer = get_norm_layer(qk_norm_type)
+ self.q_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+ self.k_norm = (
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
+ )
+
+ self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
+ self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
+
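+    # Cross-attention fuser (descriptive note): queries come from the video hidden
+    # states, keys/values from the per-frame motion embedding, so each latent frame
+    # attends to the motion tokens of its corresponding face frame. The result is
+    # projected back and, when a motion mask is given, gated to the selected region.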
+ def forward(
+ self,
+ x: torch.Tensor,
+ motion_vec: torch.Tensor,
+ motion_mask: Optional[torch.Tensor] = None,
+ use_context_parallel=False,
+ ) -> torch.Tensor:
+
+ B, T, N, C = motion_vec.shape
+ T_comp = T
+
+ x_motion = self.pre_norm_motion(motion_vec)
+ x_feat = self.pre_norm_feat(x)
+
+ kv = self.linear1_kv(x_motion)
+ q = self.linear1_q(x_feat)
+
+ k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
+ q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
+
+ # Apply QK-Norm if needed.
+ q = self.q_norm(q).to(v)
+ k = self.k_norm(k).to(v)
+
+ k = rearrange(k, "B L N H D -> (B L) N H D")
+ v = rearrange(v, "B L N H D -> (B L) N H D")
+
+ if use_context_parallel:
+ q = gather_forward(q, dim=1)
+
+ q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
+ # Compute attention.
+ # Size([batches, tokens, heads, head_features])
+ qkv_list = [q, k, v]
+ del q,k,v
+ attn = pay_attention(qkv_list)
+ # attn = attention(
+ # q,
+ # k,
+ # v,
+ # max_seqlen_q=q.shape[1],
+ # batch_size=q.shape[0],
+ # )
+
+ attn = attn.reshape(*attn.shape[:2], -1)
+ attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
+ # if use_context_parallel:
+ # attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]
+
+ output = self.linear2(attn)
+
+ if motion_mask is not None:
+ output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
+
+ return output
\ No newline at end of file
diff --git a/models/wan/animate/model_animate.py b/models/wan/animate/model_animate.py
new file mode 100644
index 000000000..d07f762ee
--- /dev/null
+++ b/models/wan/animate/model_animate.py
@@ -0,0 +1,31 @@
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import math
+import types
+from copy import deepcopy
+from einops import rearrange
+from typing import List
+import numpy as np
+import torch
+import torch.cuda.amp as amp
+import torch.nn as nn
+
+def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
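+    # Add the pose conditioning to every latent frame except the first, encode the
+    # face crops through the motion encoder in chunks of `encode_bs` frames, then
+    # prepend a zero "pad face" token at the front of the motion sequence.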
+ pose_latents = self.pose_patch_embedding(pose_latents)
+ x[:, :, 1:] += pose_latents
+
+ b,c,T,h,w = face_pixel_values.shape
+ face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
+ encode_bs = 8
+ face_pixel_values_tmp = []
+ for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
+ face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
+
+ motion_vec = torch.cat(face_pixel_values_tmp)
+
+ motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
+ motion_vec = self.face_encoder(motion_vec)
+
+ B, L, H, C = motion_vec.shape
+ pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
+ motion_vec = torch.cat([pad_face, motion_vec], dim=1)
+ return x, motion_vec
diff --git a/models/wan/animate/motion_encoder.py b/models/wan/animate/motion_encoder.py
new file mode 100644
index 000000000..02b00408c
--- /dev/null
+++ b/models/wan/animate/motion_encoder.py
@@ -0,0 +1,308 @@
+# Modified from ``https://github.com/wyhsirius/LIA``
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+import math
+
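+# torch.linalg.qr does not support half precision, so for bf16/fp16 inputs the
+# decomposition is run in float32 and the factors are cast back to the input dtype.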
+def custom_qr(input_tensor):
+ original_dtype = input_tensor.dtype
+ if original_dtype in [torch.bfloat16, torch.float16]:
+ q, r = torch.linalg.qr(input_tensor.to(torch.float32))
+ return q.to(original_dtype), r.to(original_dtype)
+ return torch.linalg.qr(input_tensor)
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
+ return F.leaky_relu(input + bias, negative_slope) * scale
+
+
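+# upfirdn2d: upsample by zero-insertion, apply a 2D FIR filter, then downsample.
+# This is a plain PyTorch (pad + conv2d) version of the resampling op used in
+# StyleGAN2-style networks; in this file it is used by the anti-aliasing Blur below.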
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
+ _, minor, in_h, in_w = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, minor, in_h, 1, in_w, 1)
+ out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
+ out = out.view(-1, minor, in_h * up_y, in_w * up_x)
+
+ out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
+ max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
+
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
+ return out[:, :, ::down_y, ::down_x]
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+
+
+def make_kernel(k):
+ k = torch.tensor(k, dtype=torch.float32)
+ if k.ndim == 1:
+ k = k[None, :] * k[:, None]
+ k /= k.sum()
+ return k
+
+
+class FusedLeakyReLU(nn.Module):
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
+ super().__init__()
+ self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+ return out
+
+
+class Blur(nn.Module):
+ def __init__(self, kernel, pad, upsample_factor=1):
+ super().__init__()
+
+ kernel = make_kernel(kernel)
+
+ if upsample_factor > 1:
+ kernel = kernel * (upsample_factor ** 2)
+
+ self.register_buffer('kernel', kernel)
+
+ self.pad = pad
+
+ def forward(self, input):
+ return upfirdn2d(input, self.kernel, pad=self.pad)
+
+
+class ScaledLeakyReLU(nn.Module):
+ def __init__(self, negative_slope=0.2):
+ super().__init__()
+
+ self.negative_slope = negative_slope
+
+ def forward(self, input):
+ return F.leaky_relu(input, negative_slope=self.negative_slope)
+
+
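+# Equalized-learning-rate layers (as in StyleGAN): weights are stored at unit variance
+# and rescaled at runtime by 1/sqrt(fan_in), so the effective update magnitude stays
+# comparable across layers of different width.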
+class EqualConv2d(nn.Module):
+ def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
+ self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
+
+ self.stride = stride
+ self.padding = padding
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_channel))
+ else:
+ self.bias = None
+
+ def forward(self, input):
+
+ return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
+
+ def __repr__(self):
+ return (
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+ )
+
+
+class EqualLinear(nn.Module):
+ def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
+ super().__init__()
+
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
+
+ if bias:
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
+ else:
+ self.bias = None
+
+ self.activation = activation
+
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
+ self.lr_mul = lr_mul
+
+ def forward(self, input):
+
+ if self.activation:
+ out = F.linear(input, self.weight * self.scale)
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
+ else:
+ out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
+
+ return out
+
+ def __repr__(self):
+ return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
+
+
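+# Conv block with optional anti-aliased downsampling: a Blur precedes the stride-2
+# convolution, and the activation is the fused bias + LeakyReLU when a bias is used,
+# otherwise a plain scaled LeakyReLU.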
+class ConvLayer(nn.Sequential):
+ def __init__(
+ self,
+ in_channel,
+ out_channel,
+ kernel_size,
+ downsample=False,
+ blur_kernel=[1, 3, 3, 1],
+ bias=True,
+ activate=True,
+ ):
+ layers = []
+
+ if downsample:
+ factor = 2
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
+ pad0 = (p + 1) // 2
+ pad1 = p // 2
+
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
+
+ stride = 2
+ self.padding = 0
+
+ else:
+ stride = 1
+ self.padding = kernel_size // 2
+
+ layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
+ bias=bias and not activate))
+
+ if activate:
+ if bias:
+ layers.append(FusedLeakyReLU(out_channel))
+ else:
+ layers.append(ScaledLeakyReLU(0.2))
+
+ super().__init__(*layers)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
+ super().__init__()
+
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
+
+ self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
+
+ def forward(self, input):
+ out = self.conv1(input)
+ out = self.conv2(out)
+
+ skip = self.skip(input)
+ out = (out + skip) / math.sqrt(2)
+
+ return out
+
+
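+# Appearance encoder: a StyleGAN-style pyramid of ResBlocks that shrinks the input
+# down to 4x4, followed by a 4x4 convolution that yields a w_dim feature vector.
+# The intermediate feature maps are returned as well (in reverse resolution order).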
+class EncoderApp(nn.Module):
+ def __init__(self, size, w_dim=512):
+ super(EncoderApp, self).__init__()
+
+ channels = {
+ 4: 512,
+ 8: 512,
+ 16: 512,
+ 32: 512,
+ 64: 256,
+ 128: 128,
+ 256: 64,
+ 512: 32,
+ 1024: 16
+ }
+
+ self.w_dim = w_dim
+ log_size = int(math.log(size, 2))
+
+ self.convs = nn.ModuleList()
+ self.convs.append(ConvLayer(3, channels[size], 1))
+
+ in_channel = channels[size]
+ for i in range(log_size, 2, -1):
+ out_channel = channels[2 ** (i - 1)]
+ self.convs.append(ResBlock(in_channel, out_channel))
+ in_channel = out_channel
+
+ self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
+
+ def forward(self, x):
+
+ res = []
+ h = x
+ for conv in self.convs:
+ h = conv(h)
+ res.append(h)
+
+ return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
+
+
+class Encoder(nn.Module):
+ def __init__(self, size, dim=512, dim_motion=20):
+ super(Encoder, self).__init__()
+
+        # appearance network
+ self.net_app = EncoderApp(size, dim)
+
+ # motion network
+ fc = [EqualLinear(dim, dim)]
+ for i in range(3):
+ fc.append(EqualLinear(dim, dim))
+
+ fc.append(EqualLinear(dim, dim_motion))
+ self.fc = nn.Sequential(*fc)
+
+ def enc_app(self, x):
+ h_source = self.net_app(x)
+ return h_source
+
+ def enc_motion(self, x):
+ h, _ = self.net_app(x)
+ h_motion = self.fc(h)
+ return h_motion
+
+
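+# Learned bank of motion directions. QR-orthonormalization of the weight matrix gives
+# an orthonormal basis; each input coefficient scales its corresponding direction and
+# the scaled directions are summed into a single 512-d motion offset. Calling with
+# input=None returns the basis itself.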
+class Direction(nn.Module):
+ def __init__(self, motion_dim):
+ super(Direction, self).__init__()
+ self.weight = nn.Parameter(torch.randn(512, motion_dim))
+
+ def forward(self, input):
+
+ weight = self.weight + 1e-8
+ Q, R = custom_qr(weight)
+ if input is None:
+ return Q
+ else:
+ input_diag = torch.diag_embed(input) # alpha, diagonal matrix
+ out = torch.matmul(input_diag, Q.T)
+ out = torch.sum(out, dim=1)
+ return out
+
+
+class Synthesis(nn.Module):
+ def __init__(self, motion_dim):
+ super(Synthesis, self).__init__()
+ self.direction = Direction(motion_dim)
+
+
+class Generator(nn.Module):
+ def __init__(self, size, style_dim=512, motion_dim=20):
+ super().__init__()
+
+ self.enc = Encoder(size, style_dim, motion_dim)
+ self.dec = Synthesis(motion_dim)
+
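+    # Image -> motion code: run the encoder's motion head in float32 (under autocast)
+    # and project the resulting coefficients through the learned direction basis.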
+ def get_motion(self, img):
+ #motion_feat = self.enc.enc_motion(img)
+ # motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
+ with torch.cuda.amp.autocast(dtype=torch.float32):
+ motion_feat = self.enc.enc_motion(img)
+ motion = self.dec.direction(motion_feat)
+ return motion
\ No newline at end of file
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index c03752b58..41d6d6316 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -203,10 +203,7 @@ def __init__(
self.use_timestep_transform = True
def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None):
- if ref_images is None:
- ref_images = [None] * len(frames)
- else:
- assert len(frames) == len(ref_images)
+ ref_images = [ref_images] * len(frames)
if masks is None:
latents = self.vae.encode(frames, tile_size = tile_size)
@@ -238,11 +235,7 @@ def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, over
return cat_latents
def vace_encode_masks(self, masks, ref_images=None):
- if ref_images is None:
- ref_images = [None] * len(masks)
- else:
- assert len(masks) == len(ref_images)
-
+ ref_images = [ref_images] * len(masks)
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
@@ -270,79 +263,6 @@ def vace_encode_masks(self, masks, ref_images=None):
result_masks.append(mask)
return result_masks
- def vace_latent(self, z, m):
- return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
-
-
- def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False):
- image_sizes = []
- trim_video_guide = len(keep_video_guide_frames)
- def conv_tensor(t, device):
- return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device)
-
- for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
- prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
- num_frames = total_frames - prepend_count
- num_frames = min(num_frames, trim_video_guide) if trim_video_guide > 0 and sub_src_video != None else num_frames
- if sub_src_mask is not None and sub_src_video is not None:
- src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
- src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device)
- # src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255])
- # src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
- if prepend_count > 0:
- src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
- src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1)
- src_video_shape = src_video[i].shape
- if src_video_shape[1] != total_frames:
- src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
- src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
- src_mask[i] = torch.clamp((src_mask[i][:, :, :, :] + 1) / 2, min=0, max=1)
- image_sizes.append(src_video[i].shape[2:])
- elif sub_src_video is None:
- if prepend_count > 0:
- src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
- src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
- else:
- src_video[i] = torch.zeros((3, total_frames, image_size[0], image_size[1]), device=device)
- src_mask[i] = torch.ones_like(src_video[i], device=device)
- image_sizes.append(image_size)
- else:
- src_video[i] = conv_tensor(sub_src_video[:num_frames], device)
- src_mask[i] = torch.ones_like(src_video[i], device=device)
- if prepend_count > 0:
- src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
- src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
- src_video_shape = src_video[i].shape
- if src_video_shape[1] != total_frames:
- src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
- src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
- image_sizes.append(src_video[i].shape[2:])
- for k, keep in enumerate(keep_video_guide_frames):
- if not keep:
- pos = prepend_count + k
- src_video[i][:, pos:pos+1] = 0
- src_mask[i][:, pos:pos+1] = 1
-
- for k, frame in enumerate(inject_frames):
- if frame != None:
- pos = prepend_count + k
- src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True)
-
-
- self.background_mask = None
- for i, ref_images in enumerate(src_ref_images):
- if ref_images is not None:
- image_size = image_sizes[i]
- for j, ref_img in enumerate(ref_images):
- if ref_img is not None and not torch.is_tensor(ref_img):
- if j==0 and any_background_ref:
- if self.background_mask == None: self.background_mask = [None] * len(src_ref_images)
- src_ref_images[i][j], self.background_mask[i] = fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True)
- else:
- src_ref_images[i][j], _ = fit_image_into_canvas(ref_img, image_size, 1, device)
- if self.background_mask != None:
- self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref
- return src_video, src_mask, src_ref_images
def get_vae_latents(self, ref_images, device, tile_size= 0):
ref_vae_latents = []
@@ -369,7 +289,9 @@ def get_i2v_mask(self, lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=No
def generate(self,
input_prompt,
input_frames= None,
+ input_frames2= None,
input_masks = None,
+ input_masks2 = None,
input_ref_images = None,
input_ref_masks = None,
input_faces = None,
@@ -615,21 +537,22 @@ def generate(self,
pose_pixels = input_frames * input_masks
input_masks = 1. - input_masks
pose_pixels -= input_masks
- save_video(pose_pixels, "pose.mp4")
pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
input_frames = input_frames * input_masks
if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames
if prefix_frames_count > 0:
input_frames[:, :prefix_frames_count] = input_video
input_masks[:, :prefix_frames_count] = 1
- save_video(input_frames, "input_frames.mp4")
- save_video(input_masks, "input_masks.mp4", value_range=(0,1))
+ # save_video(pose_pixels, "pose.mp4")
+ # save_video(input_frames, "input_frames.mp4")
+ # save_video(input_masks, "input_masks.mp4", value_range=(0,1))
lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device)
msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device)
msk = torch.concat([msk_ref, msk_control], dim=1)
- clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device)
- lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1)
+ image_ref = input_ref_images[0].to(self.device)
+ clip_image_start = image_ref.squeeze(1)
+ lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1)
y = torch.concat([msk, lat_y])
kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
lat_y = msk = msk_control = msk_ref = pose_pixels = None
@@ -701,12 +624,11 @@ def generate(self,
# Phantom
if phantom:
- input_ref_images_neg = None
- if input_ref_images != None: # Phantom Ref images
- input_ref_images = self.get_vae_latents(input_ref_images, self.device)
- input_ref_images_neg = torch.zeros_like(input_ref_images)
- ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0
- trim_frames = input_ref_images.shape[1]
+ lat_input_ref_images_neg = None
+ if input_ref_images is not None: # Phantom Ref images
+ lat_input_ref_images = self.get_vae_latents(input_ref_images, self.device)
+ lat_input_ref_images_neg = torch.zeros_like(lat_input_ref_images)
+ ref_images_count = trim_frames = lat_input_ref_images.shape[1]
if ti2v:
if input_video is None:
@@ -721,25 +643,23 @@ def generate(self,
# Vace
if vace :
# vace context encode
- input_frames = [u.to(self.device) for u in input_frames]
- input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
- input_masks = [u.to(self.device) for u in input_masks]
+ input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)])
+ input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)])
+ input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images]
+ input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks]
ref_images_before = True
- if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask]
z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents )
m0 = self.vace_encode_masks(input_masks, input_ref_images)
- if self.background_mask != None:
- color_reference_frame = input_ref_images[0][0].clone()
- zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size )
- mbg = self.vace_encode_masks(self.background_mask, None)
+ if input_ref_masks is not None and len(input_ref_masks) > 0 and input_ref_masks[0] is not None:
+ color_reference_frame = input_ref_images[0].clone()
+ zbg = self.vace_encode_frames( input_ref_images[:1] * len(input_frames), None, masks=input_ref_masks[0], tile_size = VAE_tile_size )
+ mbg = self.vace_encode_masks(input_ref_masks[:1] * len(input_frames), None)
for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg):
zz0[:, 0:1] = zzbg
mm0[:, 0:1] = mmbg
-
- self.background_mask = zz0 = mm0 = zzbg = mmbg = None
- z = self.vace_latent(z0, m0)
-
- ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0
+ zz0 = mm0 = zzbg = mmbg = None
+ z = [torch.cat([zz, mm], dim=0) for zz, mm in zip(z0, m0)]
+            ref_images_count = len(input_ref_images) if input_ref_images is not None else 0
context_scale = context_scale if context_scale != None else [1.0] * len(z)
kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count })
if overlapped_latents != None :
@@ -747,15 +667,8 @@ def generate(self,
extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0)
if prefix_frames_count > 0:
color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone()
-
- target_shape = list(z0[0].shape)
- target_shape[0] = int(target_shape[0] / 2)
- lat_h, lat_w = target_shape[-2:]
- height = self.vae_stride[1] * lat_h
- width = self.vae_stride[2] * lat_w
-
- else:
- target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2])
+ lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2]
+ target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, lat_h, lat_w)
if multitalk:
if audio_proj is None:
@@ -860,7 +773,9 @@ def clear():
apg_norm_threshold = 55
text_momentumbuffer = MomentumBuffer(apg_momentum)
audio_momentumbuffer = MomentumBuffer(apg_momentum)
-
+        input_frames = input_frames2 = input_masks = input_masks2 = input_video = input_ref_images = input_ref_masks = pre_video_frame = None
+ gc.collect()
+ torch.cuda.empty_cache()
# denoising
trans = self.model
@@ -878,7 +793,7 @@ def clear():
kwargs.update({"t": timestep, "current_step": start_step_no + i})
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
- if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step:
+ if denoising_strength < 1 and i <= injection_denoising_step:
sigma = t / 1000
noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
if inject_from_start:
@@ -912,8 +827,8 @@ def clear():
any_guidance = guide_scale != 1
if phantom:
gen_args = {
- "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
- [ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
+ "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 +
+ [ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]),
"context": [context, context_null, context_null] ,
}
elif fantasy:
diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py
index 7a5d3ea2d..31d5da6c8 100644
--- a/models/wan/df_handler.py
+++ b/models/wan/df_handler.py
@@ -21,6 +21,7 @@ def query_model_def(base_model_type, model_def):
extra_model_def["fps"] =fps
extra_model_def["frames_minimum"] = 17
extra_model_def["frames_steps"] = 20
+ extra_model_def["latent_size"] = 4
extra_model_def["sliding_window"] = True
extra_model_def["skip_layer_guidance"] = True
extra_model_def["tea_cache"] = True
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index ff1570f3a..f2a8fb459 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -114,7 +114,6 @@ def query_model_def(base_model_type, model_def):
"tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels),
"mag_cache" : True,
"keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"],
- "convert_image_guide_to_video" : True,
"sample_solvers":[
("unipc", "unipc"),
("euler", "euler"),
@@ -175,6 +174,8 @@ def query_model_def(base_model_type, model_def):
extra_model_def["forced_guide_mask_inputs"] = True
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)"
extra_model_def["background_ref_outpainted"] = False
+ extra_model_def["return_image_refs_tensor"] = True
+ extra_model_def["guide_inpaint_color"] = 0
@@ -196,15 +197,15 @@ def query_model_def(base_model_type, model_def):
"letters_filter": "KFI",
}
- extra_model_def["lock_image_refs_ratios"] = True
extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames"
extra_model_def["video_guide_outpainting"] = [0,1]
extra_model_def["pad_guide_video"] = True
extra_model_def["guide_inpaint_color"] = 127.5
extra_model_def["forced_guide_mask_inputs"] = True
+ extra_model_def["return_image_refs_tensor"] = True
if base_model_type in ["standin"]:
- extra_model_def["lock_image_refs_ratios"] = True
+ extra_model_def["fit_into_canvas_image_refs"] = 0
extra_model_def["image_ref_choices"] = {
"choices": [
("No Reference Image", ""),
@@ -480,6 +481,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
"video_prompt_type": "PVBXAKI",
"mask_expand": 20,
+ "audio_prompt_type_value": "R",
})
if text_oneframe_overlap(base_model_type):
diff --git a/shared/utils/utils.py b/shared/utils/utils.py
index 6e8a98bf2..bb2d5ffb6 100644
--- a/shared/utils/utils.py
+++ b/shared/utils/utils.py
@@ -32,6 +32,14 @@ def seed_everything(seed: int):
if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)
+def has_video_file_extension(filename):
+ extension = os.path.splitext(filename)[-1].lower()
+ return extension in [".mp4"]
+
+def has_image_file_extension(filename):
+ extension = os.path.splitext(filename)[-1].lower()
+ return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]
+
def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ):
import math
@@ -94,7 +102,7 @@ def get_video_info(video_path):
return fps, width, height, frame_count
-def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor:
+def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, target_fps = None, return_PIL = True) -> torch.Tensor:
"""Extract nth frame from video as PyTorch tensor normalized to [-1, 1]."""
cap = cv2.VideoCapture(file_name)
@@ -102,7 +110,10 @@ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool
raise ValueError(f"Cannot open video: {file_name}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
+ fps = round(cap.get(cv2.CAP_PROP_FPS))
+ if target_fps is not None:
+ frame_no = round(target_fps * frame_no /fps)
+
# Handle out of bounds
if frame_no >= total_frames or frame_no < 0:
if return_last_if_missing:
@@ -175,10 +186,15 @@ def remove_background(img, session=None):
def convert_image_to_tensor(image):
return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0)
-def convert_tensor_to_image(t, frame_no = 0):
+def convert_tensor_to_image(t, frame_no = 0, mask_levels = False):
if len(t.shape) == 4:
t = t[:, frame_no]
- return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
+ if t.shape[0]== 1:
+ t = t.expand(3,-1,-1)
+ if mask_levels:
+ return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy())
+ else:
+ return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy())
def save_image(tensor_image, name, frame_no = -1):
convert_tensor_to_image(tensor_image, frame_no).save(name)
@@ -257,7 +273,7 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi
image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS)
return image, new_height, new_width
-def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ):
+def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ):
if rm_background:
session = new_session()
@@ -266,7 +282,7 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
for i, img in enumerate(img_list):
width, height = img.size
resized_mask = None
- if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2:
+ if any_background_ref == 1 and i==0 or any_background_ref == 2:
if outpainting_dims is not None and background_ref_outpainted:
resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True)
elif img.size != (budget_width, budget_height):
@@ -291,7 +307,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg
if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) :
# resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB')
- output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200,
+ if return_tensor:
+ output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1))
+ else:
+ output_list.append(resized_image)
output_mask_list.append(resized_mask)
return output_list, output_mask_list
@@ -346,47 +365,46 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu
return ref_img.to(device), canvas
-def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ):
+def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"):
src_videos, src_masks = [], []
- inpaint_color = guide_inpaint_color/127.5 - 1
- prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0
+ inpaint_color_compressed = guide_inpaint_color/127.5 - 1
+ prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0
for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)):
- src_video = src_mask = None
- if cur_video_guide is not None:
- src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w
- if cur_video_mask is not None and any_mask:
- src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w
- if pre_video_guide is not None and not extract_guide_from_window_start:
+ src_video, src_mask = cur_video_guide, cur_video_mask
+ if pre_video_guide is not None:
src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
if any_mask:
src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1)
- if src_video is None:
- if any_guide_padding:
- src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu")
- if any_mask:
- src_mask = torch.zeros_like(src_video[0:1])
- elif src_video.shape[1] < current_video_length:
- if any_guide_padding:
- src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1)
- if cur_video_mask is not None and any_mask:
- src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
- else:
- new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
- src_video = src_video[:, :new_num_frames]
- if any_mask:
- src_mask = src_mask[:, :new_num_frames]
-
- for k, keep in enumerate(keep_video_guide_frames):
- if not keep:
- pos = prepend_count + k
- src_video[:, pos:pos+1] = inpaint_color
- src_mask[:, pos:pos+1] = 1
-
- for k, frame in enumerate(inject_frames):
- if frame != None:
- pos = prepend_count + k
- src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True)
+ if any_guide_padding:
+ if src_video is None:
+ src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = torch.float, device= device)
+ elif src_video.shape[1] < current_video_length:
+ src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1)
+ elif src_video is not None:
+ new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1
+ src_video = src_video[:, :new_num_frames]
+
+ if any_mask and src_video is not None:
+ if src_mask is None:
+ src_mask = torch.ones_like(src_video[:1])
+ elif src_mask.shape[1] < src_video.shape[1]:
+ src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1)
+ else:
+ src_mask = src_mask[:, :src_video.shape[1]]
+
+ if src_video is not None :
+ for k, keep in enumerate(keep_video_guide_frames):
+ if not keep:
+ pos = prepend_count + k
+ src_video[:, pos:pos+1] = inpaint_color_compressed
+ if any_mask: src_mask[:, pos:pos+1] = 1
+
+ for k, frame in enumerate(inject_frames):
+ if frame != None:
+ pos = prepend_count + k
+ src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask)
+ if any_mask: src_mask[:, pos:pos+1] = msk
src_videos.append(src_video)
src_masks.append(src_mask)
return src_videos, src_masks
diff --git a/wgp.py b/wgp.py
index 0e6a0b20e..6a7bca122 100644
--- a/wgp.py
+++ b/wgp.py
@@ -24,6 +24,7 @@
from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers
from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask
from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions
+from shared.utils.utils import has_video_file_extension, has_image_file_extension
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image
from shared.utils.audio_video import save_image_metadata, read_image_metadata
from shared.match_archi import match_nvidia_architecture
@@ -62,7 +63,7 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.61"
+WanGP_version = "8.7"
settings_version = 2.35
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -1942,7 +1943,8 @@ def get_model_min_frames_and_step(model_type):
mode_def = get_model_def(model_type)
frames_minimum = mode_def.get("frames_minimum", 5)
frames_steps = mode_def.get("frames_steps", 4)
- return frames_minimum, frames_steps
+ latent_size = mode_def.get("latent_size", frames_steps)
+ return frames_minimum, frames_steps, latent_size
def get_model_fps(model_type):
mode_def = get_model_def(model_type)
@@ -3459,7 +3461,7 @@ def check(src, cond):
if len(video_other_prompts) >0 :
values += [video_other_prompts]
labels += ["Other Prompts"]
- if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"):
+ if len(video_outpainting) >0:
values += [video_outpainting]
labels += ["Outpainting"]
video_sample_solver = configs.get("sample_solver", "")
@@ -3532,6 +3534,11 @@ def convert_image(image):
return cast(Image, ImageOps.exif_transpose(image))
def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='torch'):
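+    # A still image (file path or PIL.Image) is returned as a single-frame uint8
+    # tensor so image guides can go through the same code path as videos.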
+ if isinstance(video_in, str) and has_image_file_extension(video_in):
+ video_in = Image.open(video_in)
+ if isinstance(video_in, Image.Image):
+ return torch.from_numpy(np.array(video_in).astype(np.uint8)).unsqueeze(0)
+
from shared.utils.utils import resample
import decord
@@ -3653,19 +3660,22 @@ def get_preprocessor(process_type, inpaint_color):
def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2) :
if not items:
return []
- max_workers = 11
+
import concurrent.futures
start_time = time.time()
# print(f"Preprocessus:{process_type} started")
if process_type in ["prephase", "upsample"]:
if wrap_in_list :
items = [ [img] for img in items]
- with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
- futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)}
- results = [None] * len(items)
- for future in concurrent.futures.as_completed(futures):
- idx = futures[future]
- results[idx] = future.result()
+ if max_workers == 1:
+ results = [image_processor(img) for img in items]
+ else:
+ with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+ futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)}
+ results = [None] * len(items)
+ for future in concurrent.futures.as_completed(futures):
+ idx = futures[future]
+ results[idx] = future.result()
if wrap_in_list:
results = [ img[0] for img in results]
@@ -3677,55 +3687,6 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis
return results
-def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, fit_crop = False, block_size= 16, expand_scale = 2, outpainting_dims = None, inpaint_color = 127):
- frame_width, frame_height = input_image.size
-
- if fit_crop:
- input_image = rescale_and_crop(input_image, width, height)
- if input_mask is not None:
- input_mask = rescale_and_crop(input_mask, width, height)
- return input_image, input_mask
-
- if outpainting_dims != None:
- if fit_canvas != None:
- frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims)
- else:
- frame_height, frame_width = height, width
-
- if fit_canvas != None:
- height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size)
-
- if outpainting_dims != None:
- final_height, final_width = height, width
- height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1)
-
- if fit_canvas != None or outpainting_dims != None:
- input_image = input_image.resize((width, height), resample=Image.Resampling.LANCZOS)
- if input_mask is not None:
- input_mask = input_mask.resize((width, height), resample=Image.Resampling.LANCZOS)
-
- if expand_scale != 0 and input_mask is not None:
- kernel_size = abs(expand_scale)
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
- op_expand = cv2.dilate if expand_scale > 0 else cv2.erode
- input_mask = np.array(input_mask)
- input_mask = op_expand(input_mask, kernel, iterations=3)
- input_mask = Image.fromarray(input_mask)
-
- if outpainting_dims != None:
- inpaint_color = inpaint_color / 127.5-1
- image = convert_image_to_tensor(input_image)
- full_frame= torch.full( (image.shape[0], final_height, final_width), inpaint_color, dtype= torch.float, device= image.device)
- full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = image
- input_image = convert_tensor_to_image(full_frame)
-
- if input_mask is not None:
- mask = convert_image_to_tensor(input_mask)
- full_frame= torch.full( (mask.shape[0], final_height, final_width), 1, dtype= torch.float, device= mask.device)
- full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = mask
- input_mask = convert_tensor_to_image(full_frame)
-
- return input_image, input_mask
def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512):
if not input_video_path or max_frames <= 0:
return None, None
@@ -3780,6 +3741,8 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr
save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None)
return face_tensor
+def get_default_workers():
+    return max(1, os.cpu_count() // 2)
def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1):
@@ -3906,8 +3869,8 @@ def prep_prephase(frame_idx):
return (target_frame, frame, mask)
else:
return (target_frame, None, None)
-
- proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False)
+ max_workers = get_default_workers()
+ proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False, max_workers=max_workers)
proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists)
for frame_idx, frame_group in enumerate(proc_lists):
proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group
@@ -3916,11 +3879,11 @@ def prep_prephase(frame_idx):
mask_video = None
if preproc2 != None:
- proc_list2 = process_images_multithread(preproc2, proc_list, process_type2)
+ proc_list2 = process_images_multithread(preproc2, proc_list, process_type2, max_workers=max_workers)
#### to be finished ...or not
- proc_list = process_images_multithread(preproc, proc_list, process_type)
+ proc_list = process_images_multithread(preproc, proc_list, process_type, max_workers=max_workers)
if any_mask:
- proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask)
+ proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask, max_workers=max_workers)
else:
proc_list_outside = proc_mask = len(proc_list) * [None]
@@ -3938,7 +3901,7 @@ def prep_prephase(frame_idx):
full_frame= torch.full( (final_height, final_width, mask.shape[-1]), 255, dtype= torch.uint8, device= mask.device)
full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask
mask = full_frame
- masks.append(mask)
+ masks.append(mask[:, :, 0:1].clone())
else:
masked_frame = processed_img
@@ -3958,13 +3921,13 @@ def prep_prephase(frame_idx):
proc_list[frame_no] = proc_list_outside[frame_no] = proc_mask[frame_no] = None
- if args.save_masks:
- from preprocessing.dwpose.pose import save_one_video
- saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
- save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
- if any_mask:
- saved_masks = [mask.cpu().numpy() for mask in masks ]
- save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
+ # if args.save_masks:
+ # from preprocessing.dwpose.pose import save_one_video
+ # saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ]
+ # save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None)
+ # if any_mask:
+ # saved_masks = [mask.cpu().numpy() for mask in masks ]
+ # save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None)
preproc = None
preproc_outside = None
gc.collect()
@@ -3972,8 +3935,10 @@ def prep_prephase(frame_idx):
if pad_frames > 0:
masked_frames = masked_frames[0] * pad_frames + masked_frames
if any_mask: masked_frames = masks[0] * pad_frames + masks
+ masked_frames = torch.stack(masked_frames).permute(-1,0,1,2).float().div_(127.5).sub_(1.)
+ masks = torch.stack(masks).permute(-1,0,1,2).float().div_(255) if any_mask else None
- return torch.stack(masked_frames), torch.stack(masks) if any_mask else None
+ return masked_frames, masks
def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16):
@@ -4102,7 +4067,7 @@ def perform_spatial_upsampling(sample, spatial_upsampling):
frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ]
def upsample_frames(frame):
return resize_lanczos(frame, h, w).unsqueeze(1)
- sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1)
+ sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False, max_workers=get_default_workers()), dim=1)
frames_to_upsample = None
return sample
@@ -4609,17 +4574,13 @@ def remove_temp_filenames(temp_filenames_list):
batch_size = 1
temp_filenames_list = []
- convert_image_guide_to_video = model_def.get("convert_image_guide_to_video", False)
- if convert_image_guide_to_video:
- if image_guide is not None and isinstance(image_guide, Image.Image):
- video_guide = convert_image_to_video(image_guide)
- temp_filenames_list.append(video_guide)
- image_guide = None
+ if image_guide is not None and isinstance(image_guide, Image.Image):
+ video_guide = image_guide
+ image_guide = None
- if image_mask is not None and isinstance(image_mask, Image.Image):
- video_mask = convert_image_to_video(image_mask)
- temp_filenames_list.append(video_mask)
- image_mask = None
+ if image_mask is not None and isinstance(image_mask, Image.Image):
+ video_mask = image_mask
+ image_mask = None
if model_def.get("no_background_removal", False): remove_background_images_ref = 0
@@ -4711,22 +4672,12 @@ def remove_temp_filenames(temp_filenames_list):
device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576
guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5)
extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False)
- i2v = test_class_i2v(model_type)
- diffusion_forcing = "diffusion_forcing" in model_filename
- t2v = base_model_type in ["t2v"]
- ltxv = "ltxv" in model_filename
- vace = test_vace_module(base_model_type)
- hunyuan_t2v = "hunyuan_video_720" in model_filename
- hunyuan_i2v = "hunyuan_video_i2v" in model_filename
hunyuan_custom = "hunyuan_video_custom" in model_filename
hunyuan_custom_audio = hunyuan_custom and "audio" in model_filename
hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename
hunyuan_avatar = "hunyuan_video_avatar" in model_filename
fantasy = base_model_type in ["fantasy"]
multitalk = model_def.get("multitalk_class", False)
- standin = model_def.get("standin_class", False)
- infinitetalk = base_model_type in ["infinitetalk"]
- animate = base_model_type in ["animate"]
if "B" in audio_prompt_type or "X" in audio_prompt_type:
from models.wan.multitalk.multitalk import parse_speakers_locations
@@ -4763,9 +4714,9 @@ def remove_temp_filenames(temp_filenames_list):
sliding_window_size = current_video_length
reuse_frames = 0
- _, latent_size = get_model_min_frames_and_step(model_type)
- if diffusion_forcing: latent_size = 4
+ _, _, latent_size = get_model_min_frames_and_step(model_type)
original_image_refs = image_refs
+ image_refs = None if image_refs is None else [] + image_refs # work on a copy as it is going to be modified
# image_refs = None
# nb_frames_positions= 0
# Output Video Ratio Priorities:
@@ -4889,6 +4840,7 @@ def remove_temp_filenames(temp_filenames_list):
initial_total_windows = 0
discard_last_frames = sliding_window_discard_last_frames
default_requested_frames_to_generate = current_video_length
+ nb_frames_positions = 0
if sliding_window:
initial_total_windows= compute_sliding_window_no(default_requested_frames_to_generate, sliding_window_size, discard_last_frames, reuse_frames)
current_video_length = sliding_window_size
@@ -4907,7 +4859,7 @@ def remove_temp_filenames(temp_filenames_list):
if repeat_no >= total_generation: break
repeat_no +=1
gen["repeat_no"] = repeat_no
- src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None
+ src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None
prefix_video = pre_video_frame = None
source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window
source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before )
@@ -4963,7 +4915,6 @@ def remove_temp_filenames(temp_filenames_list):
return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) )
refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {}
- src_ref_images, src_ref_masks = image_refs, None
image_start_tensor = image_end_tensor = None
if window_no == 1 and (video_source is not None or image_start is not None):
if image_start is not None:
@@ -5020,7 +4971,7 @@ def remove_temp_filenames(temp_filenames_list):
if len(pos) > 0:
if pos in ["L", "l"]:
cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length
- if cur_end_pos >= last_frame_no and not joker_used:
+ if cur_end_pos >= last_frame_no-1 and not joker_used:
joker_used = True
cur_end_pos = last_frame_no -1
project_window_no += 1
@@ -5036,141 +4987,53 @@ def remove_temp_filenames(temp_filenames_list):
frames_to_inject[pos] = image_refs[i]
+ video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
if video_guide is not None:
keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
if len(error) > 0:
raise gr.Error(f"invalid keep frames {keep_frames_video_guide}")
guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame
+ extra_control_frames = model_def.get("extra_control_frames", 0)
+ if extra_control_frames > 0 and aligned_guide_start_frame >= extra_control_frames: guide_frames_extract_start -= extra_control_frames
+
keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else []
keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ]
guide_frames_extract_count = len(keep_frames_parsed)
+                # Extract faces from the control video
if "B" in video_prompt_type:
send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")])
src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps)
if src_faces is not None and src_faces.shape[1] < current_video_length:
src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1)
- if vace or animate:
- video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
- context_scale = [ control_net_weight]
- if "V" in video_prompt_type:
- process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
- preprocess_type, preprocess_type2 = "raw", None
- for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")):
- if process_num == 0:
- preprocess_type = process_map_video_guide.get(process_letter, "raw")
- else:
- preprocess_type2 = process_map_video_guide.get(process_letter, None)
- status_info = "Extracting " + processes_names[preprocess_type]
- extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
- if len(extra_process_list) == 1:
- status_info += " and " + processes_names[extra_process_list[0]]
- elif len(extra_process_list) == 2:
- status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]]
- if preprocess_type2 is not None:
- context_scale = [ control_net_weight /2, control_net_weight2 /2]
- send_cmd("progress", [0, get_latest_status(state, status_info)])
- inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color
- video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color )
- if preprocess_type2 != None:
- video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 )
-
- if video_guide_processed != None:
- if sample_fit_canvas != None:
- image_size = video_guide_processed.shape[-3: -1]
- sample_fit_canvas = None
- refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy())
- if video_guide_processed2 != None:
- refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())]
- if video_mask_processed != None:
- refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy())
-
- frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
-
- if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)):
- any_mask = True
- any_guide_padding = model_def.get("pad_guide_video", False)
- from shared.utils.utils import prepare_video_guide_and_mask
- src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2],
- [video_mask_processed, video_mask_processed2],
- pre_video_guide, image_size, current_video_length, latent_size,
- any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start,
- keep_frames_parsed, frames_to_inject_parsed , outpainting_dims)
-
- src_video, src_video2 = src_videos
- src_mask, src_mask2 = src_masks
- if src_video is None:
- abort = True
- break
- if src_faces is not None:
- if src_faces.shape[1] < src_video.shape[1]:
- src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1)
- else:
- src_faces = src_faces[:, :src_video.shape[1]]
- if args.save_masks:
- save_video( src_video, "masked_frames.mp4", fps)
- if src_video2 is not None:
- save_video( src_video2, "masked_frames2.mp4", fps)
- if any_mask:
- save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
-
- elif ltxv:
- preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw")
- status_info = "Extracting " + processes_names[preprocess_type]
- send_cmd("progress", [0, get_latest_status(state, status_info)])
- # start one frame ealier to facilitate latents merging later
- src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size )
- if src_video != None:
- src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ]
- refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
- refresh_preview["video_mask"] = None
- src_video = src_video.permute(3, 0, 1, 2)
- src_video = src_video.float().div_(127.5).sub_(1.) # c, f, h, w
- if sample_fit_canvas != None:
- image_size = src_video.shape[-2:]
- sample_fit_canvas = None
-
- elif hunyuan_custom_edit:
- if "P" in video_prompt_type:
- progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")]
- else:
- progress_args = [0, get_latest_status(state,"Extracting Video and Mask")]
-
- send_cmd("progress", progress_args)
- src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0)
- refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy())
- if src_mask != None:
- refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy())
-
- elif "R" in video_prompt_type: # sparse video to video
- src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True)
- src_image, _, _ = calculate_dimensions_and_resize_image(src_image, image_size[0], image_size[1 ], sample_fit_canvas, fit_crop, block_size = block_size)
- refresh_preview["video_guide"] = src_image
- src_video = convert_image_to_tensor(src_image).unsqueeze(1)
- if sample_fit_canvas != None:
- image_size = src_video.shape[-2:]
- sample_fit_canvas = None
-
- else: # video to video
- video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size)
- if video_guide_processed is None:
- src_video = pre_video_guide
+ # Sparse Video to Video
+ sparse_video_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, target_fps = fps, return_PIL = True) if "R" in video_prompt_type else None
+
+ # Generic Video Preprocessing
+ process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
+ preprocess_type, preprocess_type2 = "raw", None
+ for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")):
+ if process_num == 0:
+ preprocess_type = process_map_video_guide.get(process_letter, "raw")
else:
- if sample_fit_canvas != None:
- image_size = video_guide_processed.shape[-3: -1]
- sample_fit_canvas = None
- src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2)
- if pre_video_guide != None:
- src_video = torch.cat( [pre_video_guide, src_video], dim=1)
- elif image_guide is not None:
- new_image_guide, new_image_mask = preprocess_image_with_mask(image_guide, image_mask, image_size[0], image_size[1], fit_canvas = sample_fit_canvas, fit_crop= fit_crop, block_size= block_size, expand_scale = mask_expand, outpainting_dims=outpainting_dims)
- if sample_fit_canvas is not None:
- image_size = (new_image_guide.size[1], new_image_guide.size[0])
+ preprocess_type2 = process_map_video_guide.get(process_letter, None)
+ status_info = "Extracting " + processes_names[preprocess_type]
+ extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask])
+ if len(extra_process_list) == 1:
+ status_info += " and " + processes_names[extra_process_list[0]]
+ elif len(extra_process_list) == 2:
+ status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]]
+ context_scale = [control_net_weight /2, control_net_weight2 /2] if preprocess_type2 is not None else [control_net_weight]
+ if not (preprocess_type == "identity" and preprocess_type2 is None and video_mask is None):send_cmd("progress", [0, get_latest_status(state, status_info)])
+ inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask == "inpaint" else guide_inpaint_color
+ video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size )
+ if preprocess_type2 != None:
+ video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size )
+
+ if video_guide_processed is not None and sample_fit_canvas is not None:
+ image_size = video_guide_processed.shape[-2:]
sample_fit_canvas = None
- refresh_preview["image_guide"] = new_image_guide
- if new_image_mask is not None:
- refresh_preview["image_mask"] = new_image_mask
if window_no == 1 and image_refs is not None and len(image_refs) > 0:
if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) :
@@ -5192,45 +5055,68 @@ def remove_temp_filenames(temp_filenames_list):
image_refs[i] = rescale_and_crop(img, default_image_size[1], default_image_size[0])
refresh_preview["image_refs"] = image_refs
- if len(image_refs) > nb_frames_positions:
+ if len(image_refs) > nb_frames_positions:
+ src_ref_images = image_refs[nb_frames_positions:]
if remove_background_images_ref > 0:
send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
- # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested
- image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0],
+
+ src_ref_images, src_ref_masks = resize_and_remove_background(src_ref_images , image_size[1], image_size[0],
remove_background_images_ref > 0, any_background_ref,
- fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1,
+ fit_into_canvas= model_def.get("fit_into_canvas_image_refs", 1),
block_size=block_size,
outpainting_dims =outpainting_dims,
- background_ref_outpainted = model_def.get("background_ref_outpainted", True) )
- refresh_preview["image_refs"] = image_refs
-
+ background_ref_outpainted = model_def.get("background_ref_outpainted", True),
+ return_tensor= model_def.get("return_image_refs_tensor", False) )
+
- if vace :
- image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications
-
- src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2],
- [video_mask_processed] if video_guide_processed2 == None else [video_mask_processed, video_mask_processed2],
- [image_refs_copy] if video_guide_processed2 == None else [image_refs_copy, image_refs_copy],
- current_video_length, image_size = image_size, device ="cpu",
- keep_video_guide_frames=keep_frames_parsed,
- pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide],
- inject_frames= frames_to_inject_parsed,
- outpainting_dims = outpainting_dims,
- any_background_ref = any_background_ref
- )
- if len(frames_to_inject_parsed) or any_background_ref:
- new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + 0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
- if any_background_ref:
- new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:]
+ frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
+ if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False):
+ any_mask = video_mask is not None or model_def.get("forced_guide_mask_inputs", False)
+ any_guide_padding = model_def.get("pad_guide_video", False)
+ from shared.utils.utils import prepare_video_guide_and_mask
+ src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]),
+ [video_mask_processed] + ([] if video_mask_processed2 is None else [video_mask_processed2]),
+ None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide,
+ image_size, current_video_length, latent_size,
+ any_mask, any_guide_padding, guide_inpaint_color,
+ keep_frames_parsed, frames_to_inject_parsed , outpainting_dims)
+ video_guide_processed = video_guide_processed2 = video_mask_processed = video_mask_processed2 = None
+ if len(src_videos) == 1:
+ src_video, src_video2, src_mask, src_mask2 = src_videos[0], None, src_masks[0], None
+ else:
+ src_video, src_video2 = src_videos
+ src_mask, src_mask2 = src_masks
+ src_videos = src_masks = None
+ if src_video is None:
+ abort = True
+ break
+ if src_faces is not None:
+ if src_faces.shape[1] < src_video.shape[1]:
+ src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1)
else:
- new_image_refs += image_refs[nb_frames_positions:]
- refresh_preview["image_refs"] = new_image_refs
- new_image_refs = None
-
- if sample_fit_canvas != None:
- image_size = src_video[0].shape[-2:]
- sample_fit_canvas = None
-
+ src_faces = src_faces[:, :src_video.shape[1]]
+ if video_guide is not None or len(frames_to_inject_parsed) > 0:
+ if args.save_masks:
+ if src_video is not None: save_video( src_video, "masked_frames.mp4", fps)
+ if src_video2 is not None: save_video( src_video2, "masked_frames2.mp4", fps)
+ if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
+ if video_guide is not None:
+ preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame)
+ refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no)
+ if src_video2 is not None:
+ refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)]
+ if src_mask is not None and video_mask is not None:
+ refresh_preview["video_mask"] = convert_tensor_to_image(src_mask, preview_frame_no, mask_levels = True)
+
+ if src_ref_images is not None or nb_frames_positions:
+ if len(frames_to_inject_parsed):
+ new_image_refs = [convert_tensor_to_image(src_video, frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame)) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject]
+ else:
+ new_image_refs = []
+ if src_ref_images is not None:
+ new_image_refs += [convert_tensor_to_image(img) if torch.is_tensor(img) else img for img in src_ref_images ]
+ refresh_preview["image_refs"] = new_image_refs
+ new_image_refs = None
if len(refresh_preview) > 0:
new_inputs= locals()
@@ -5339,8 +5225,6 @@ def set_header_text(txt):
pre_video_frame = pre_video_frame,
original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [],
image_refs_relative_size = image_refs_relative_size,
- image_guide= new_image_guide,
- image_mask= new_image_mask,
outpainting_dims = outpainting_dims,
)
except Exception as e:
@@ -6320,8 +6204,11 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
pop += ["image_refs_relative_size"]
if not vace:
- pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"]
+ pop += ["frames_positions", "control_net_weight", "control_net_weight2"]
+ if model_def.get("video_guide_outpainting", None) is None:
+ pop += ["video_guide_outpainting"]
+
if not (vace or t2v):
pop += ["min_frames_if_references"]
@@ -6506,13 +6393,6 @@ def eject_video_from_gallery(state, input_file_list, choice):
choice = min(choice, len(file_list))
return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0)
-def has_video_file_extension(filename):
- extension = os.path.splitext(filename)[-1].lower()
- return extension in [".mp4"]
-
-def has_image_file_extension(filename):
- extension = os.path.splitext(filename)[-1].lower()
- return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"]
def add_videos_to_gallery(state, input_file_list, choice, files_to_load):
gen = get_gen_info(state)
if files_to_load == None:
@@ -7881,7 +7761,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
elif recammaster:
video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", get_max_frames(81)), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True)
else:
- min_frames, frames_step = get_model_min_frames_and_step(base_model_type)
+ min_frames, frames_step, _ = get_model_min_frames_and_step(base_model_type)
current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97)
@@ -8059,7 +7939,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
MMAudio_neg_prompt = gr.Text(ui_defaults.get("MMAudio_neg_prompt", ""), label="Negative Prompt (1 or 2 keywords)")
- with gr.Column(visible = (t2v or vace) and not fantasy) as audio_prompt_type_remux_row:
+ with gr.Column(visible = any_control_video) as audio_prompt_type_remux_row:
gr.Markdown("You may transfer the existing audio tracks of a Control Video")
audio_prompt_type_remux = gr.Dropdown(
choices=[
@@ -8284,16 +8164,16 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info())
with gr.Row(**default_visibility) as video_buttons_row:
video_info_extract_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm")
- video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video )
video_info_to_video_source_btn = gr.Button("To Video Source", min_width= 1, size ="sm", visible = any_video_source)
+ video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video )
video_info_eject_video_btn = gr.Button("Eject Video", min_width= 1, size ="sm")
with gr.Row(**default_visibility) as image_buttons_row:
video_info_extract_image_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm")
video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", min_width= 1, visible = any_start_image )
video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", min_width= 1, visible = any_end_image)
- video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image )
video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask and False)
video_info_to_reference_image_btn = gr.Button("To Reference Image", min_width= 1, size ="sm", visible = any_reference_image)
+ video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image )
video_info_eject_image_btn = gr.Button("Eject Image", min_width= 1, size ="sm")
with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab:
with gr.Group(elem_classes= "postprocess"):
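As a side note on the control-video refactor in this patch: the generic preprocessing path picks one or two preprocessors from the letters found in `video_prompt_type` and halves the ControlNet weights when a second pass is active. The sketch below mirrors only that selection logic; the letter and name maps are illustrative stand-ins for the real `process_map_video_guide` / `processes_names` tables in wgp.py.

```python
# Minimal sketch of the letter-based preprocessor selection (assumed maps).
def filter_letters(s: str, allowed: str) -> str:
    """Keep only the characters of s that belong to `allowed`, preserving order."""
    return "".join(ch for ch in s if ch in allowed)

# Hypothetical subsets of process_map_video_guide / processes_names.
PROCESS_MAP = {"P": "pose", "D": "depth", "E": "canny", "U": "identity"}
PROCESS_NAMES = {"raw": "Raw Frames", "pose": "Open Pose", "depth": "Depth Maps",
                 "canny": "Canny Edges", "identity": "Identity"}

def select_preprocessing(video_prompt_type: str, weight1: float, weight2: float):
    """Return (primary, secondary, context_scale, status) for the control video."""
    preprocess_type, preprocess_type2 = "raw", None
    for i, letter in enumerate(filter_letters(video_prompt_type, "PEDSLCMU")):
        if i == 0:
            preprocess_type = PROCESS_MAP.get(letter, "raw")
        else:
            preprocess_type2 = PROCESS_MAP.get(letter, None)
    # Two preprocessors share the ControlNet budget, a single one keeps it whole.
    context_scale = [weight1 / 2, weight2 / 2] if preprocess_type2 is not None else [weight1]
    status = "Extracting " + PROCESS_NAMES[preprocess_type]
    if preprocess_type2 is not None:
        status += " and " + PROCESS_NAMES[preprocess_type2]
    return preprocess_type, preprocess_type2, context_scale, status

# "PDV" asks for pose as the primary pass and depth as the secondary one.
print(select_preprocessing("PDV", 1.0, 1.0))
```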
From 356e10ce71dc969d11eec42f05a0dfdfada20ad4 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 23 Sep 2025 23:28:43 +0200
Subject: [PATCH 031/155] fixed default sound
---
models/wan/wan_handler.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index f2a8fb459..12ddfed45 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -481,7 +481,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
"video_prompt_type": "PVBXAKI",
"mask_expand": 20,
- "audio_prompt_type_value": "R",
+ "audio_prompt_type": "R",
})
if text_oneframe_overlap(base_model_type):
From 625b50aefd6cb48f09da5bb102fbe900acb9be84 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Wed, 24 Sep 2025 00:12:22 +0200
Subject: [PATCH 032/155] fixed lucy edit fast wan, lora missing
---
README.md | 2 ++
defaults/lucy_edit_fastwan.json | 2 +-
wgp.py | 2 +-
3 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 6c338ac7c..8cd8e9c6a 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,8 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages:
- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion tranfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*.
+In order to use Wan 2.2 Animate you will need first to stop by the *Mat Anyone* embedded tool, to extract the Video Mask of the person from which you want to extract the motion.
+
- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and asks it to change it (it is specialized in clothes changing) and voila ! The nice thing about it is that is it based on the *Wan 2.2 5B* model and therefore is very fast especially if you the *FastWan* finetune that is also part of the package.
diff --git a/defaults/lucy_edit_fastwan.json b/defaults/lucy_edit_fastwan.json
index c67c795d7..d5d47c814 100644
--- a/defaults/lucy_edit_fastwan.json
+++ b/defaults/lucy_edit_fastwan.json
@@ -5,7 +5,7 @@
"description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.",
"URLs": "lucy_edit",
"group": "wan2_2",
- "loras": "ti2v_2_2"
+ "loras": "ti2v_2_2_fastwan"
},
"prompt": "change the clothes to red",
"video_length": 81,
diff --git a/wgp.py b/wgp.py
index 6a7bca122..43f5e8006 100644
--- a/wgp.py
+++ b/wgp.py
@@ -63,7 +63,7 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.7"
+WanGP_version = "8.71"
settings_version = 2.35
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
From 2cbcb9523eab5af691f814be399657856742ebe3 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Wed, 24 Sep 2025 08:16:24 +0200
Subject: [PATCH 033/155] fix sparse video error
---
README.md | 7 ++++---
docs/INSTALLATION.md | 2 +-
requirements.txt | 2 +-
wgp.py | 2 +-
4 files changed, 7 insertions(+), 6 deletions(-)
diff --git a/README.md b/README.md
index 8cd8e9c6a..8cee8346b 100644
--- a/README.md
+++ b/README.md
@@ -23,12 +23,13 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena !
So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages:
-- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion tranfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*.
+- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can either *Replace* a person in a Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX i2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*.
-In order to use Wan 2.2 Animate you will need first to stop by the *Mat Anyone* embedded tool, to extract the Video Mask of the person from which you want to extract the motion.
+In order to use Wan 2.2 Animate, you will first need to stop by the *Mat Anyone* embedded tool to extract the *Video Mask* of the person whose motion you want to extract.
-- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and asks it to change it (it is specialized in clothes changing) and voila ! The nice thing about it is that is it based on the *Wan 2.2 5B* model and therefore is very fast especially if you the *FastWan* finetune that is also part of the package.
+- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in clothes changing) and voila! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.
+*Update 8.71*: fixed Fast Lucy Edit that didn't contain the lora
### September 15 2025: WanGP v8.6 - Attack of the Clones
diff --git a/docs/INSTALLATION.md b/docs/INSTALLATION.md
index 9f66422a1..361f26684 100644
--- a/docs/INSTALLATION.md
+++ b/docs/INSTALLATION.md
@@ -27,7 +27,7 @@ conda activate wan2gp
### Step 2: Install PyTorch
```shell
-# Install PyTorch 2.7.0 with CUDA 12.4
+# Install PyTorch 2.7.0 with CUDA 12.8
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
```
diff --git a/requirements.txt b/requirements.txt
index e6c85fda7..d6f75f74b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -52,7 +52,7 @@ matplotlib
# Utilities
ftfy
piexif
-pynvml
+nvidia-ml-py
misaki
# Optional / commented out
diff --git a/wgp.py b/wgp.py
index 43f5e8006..d48c7fa29 100644
--- a/wgp.py
+++ b/wgp.py
@@ -4859,7 +4859,7 @@ def remove_temp_filenames(temp_filenames_list):
if repeat_no >= total_generation: break
repeat_no +=1
gen["repeat_no"] = repeat_no
- src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None
+ src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = sparse_video_image = None
prefix_video = pre_video_frame = None
source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window
source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before )
From 0a58ef6981e3343cbca6b3e12380148de281d91f Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Wed, 24 Sep 2025 16:28:05 +0200
Subject: [PATCH 034/155] fixed sparse crash
---
README.md | 2 +-
wgp.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 6c338ac7c..6b005d8dc 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena !
So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages:
-- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion tranfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*.
+- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model to either *Replace* a person in a Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*.
- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and asks it to change it (it is specialized in clothes changing) and voila ! The nice thing about it is that is it based on the *Wan 2.2 5B* model and therefore is very fast especially if you the *FastWan* finetune that is also part of the package.
diff --git a/wgp.py b/wgp.py
index 6a7bca122..11da9679d 100644
--- a/wgp.py
+++ b/wgp.py
@@ -4859,7 +4859,7 @@ def remove_temp_filenames(temp_filenames_list):
if repeat_no >= total_generation: break
repeat_no +=1
gen["repeat_no"] = repeat_no
- src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None
+ src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = sparse_video_image = None
prefix_video = pre_video_frame = None
source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window
source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before )
From 76f86c43ad12914626cf6d3ad623ba1f6f551640 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Wed, 24 Sep 2025 23:38:27 +0200
Subject: [PATCH 035/155] add qwen edit plus support
---
README.md | 6 +-
defaults/qwen_image_edit_plus_20B.json | 17 +++
models/flux/flux_handler.py | 2 +-
models/qwen/pipeline_qwenimage.py | 163 ++++++++++++++-----------
models/qwen/qwen_handler.py | 21 +++-
models/qwen/qwen_main.py | 21 ++--
preprocessing/matanyone/app.py | 4 +-
shared/utils/utils.py | 4 +-
wgp.py | 35 ++++--
9 files changed, 171 insertions(+), 102 deletions(-)
create mode 100644 defaults/qwen_image_edit_plus_20B.json
diff --git a/README.md b/README.md
index 8cee8346b..b8c92b6e2 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
-### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena !
+### September 24 2025: WanGP v8.72 - Here Are ~~Two~~Three New Contenders in the Vace Arena !
So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages:
- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can either *Replace* a person in a Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX i2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*.
@@ -29,7 +29,11 @@ In order to use Wan 2.2 Animate you will need first to stop by the *Mat Anyone*
 - **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in clothes changing) and voila! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.
+Also because I wanted to spoil you:
+- **Qwen Edit Plus**: also known as the *Qwen Edit 25th September Update*, which is specialized in combining multiple Objects / People. There is also new support for *Pose transfer* & *Recolorisation*. All of this is made easy to use in WanGP. For now you will find only the quantized version, since HF crashes when uploading the unquantized one.
+
 *Update 8.71*: fixed Fast Lucy Edit that didn't contain the lora
+*Update 8.72*: shadow drop of Qwen Edit Plus
### September 15 2025: WanGP v8.6 - Attack of the Clones
diff --git a/defaults/qwen_image_edit_plus_20B.json b/defaults/qwen_image_edit_plus_20B.json
new file mode 100644
index 000000000..e10deb24b
--- /dev/null
+++ b/defaults/qwen_image_edit_plus_20B.json
@@ -0,0 +1,17 @@
+{
+ "model": {
+ "name": "Qwen Image Edit Plus 20B",
+ "architecture": "qwen_image_edit_plus_20B",
+ "description": "Qwen Image Edit Plus is a generative model that can generate very high quality images with long texts in it. Best results will be at 720p. This model is optimized to combine multiple Subjects & Objects.",
+ "URLs": [
+ "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_plus_20B_quanto_bf16_int8.safetensors"
+ ],
+ "preload_URLs": "qwen_image_edit_20B",
+ "attention": {
+ "<89": "sdpa"
+ }
+ },
+ "prompt": "add a hat",
+ "resolution": "1024x1024",
+ "batch_size": 1
+}
\ No newline at end of file
diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py
index 83de7c3fc..471c339d4 100644
--- a/models/flux/flux_handler.py
+++ b/models/flux/flux_handler.py
@@ -28,7 +28,7 @@ def query_model_def(base_model_type, model_def):
extra_model_def["any_image_refs_relative_size"] = True
extra_model_def["no_background_removal"] = True
extra_model_def["image_ref_choices"] = {
- "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"),
+ "choices":[("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"),
("Up to two Images are Style Images", "KIJ")],
"default": "KI",
"letters_filter": "KIJ",
diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py
index 14728862d..85934b79f 100644
--- a/models/qwen/pipeline_qwenimage.py
+++ b/models/qwen/pipeline_qwenimage.py
@@ -200,7 +200,8 @@ def __init__(
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.tokenizer_max_length = 1024
if processor is not None:
- self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+ # self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.prompt_template_encode_start_idx = 64
else:
self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
@@ -232,6 +233,21 @@ def _get_qwen_prompt_embeds(
txt = [template.format(e) for e in prompt]
if self.processor is not None and image is not None:
+ img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
+ if isinstance(image, list):
+ base_img_prompt = ""
+ for i, img in enumerate(image):
+ base_img_prompt += img_prompt_template.format(i + 1)
+ elif image is not None:
+ base_img_prompt = img_prompt_template.format(1)
+ else:
+ base_img_prompt = ""
+
+ template = self.prompt_template_encode
+
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(base_img_prompt + e) for e in prompt]
+
model_inputs = self.processor(
text=txt,
images=image,
@@ -464,7 +480,7 @@ def disable_vae_tiling(self):
def prepare_latents(
self,
- image,
+ images,
batch_size,
num_channels_latents,
height,
@@ -482,24 +498,30 @@ def prepare_latents(
shape = (batch_size, num_channels_latents, 1, height, width)
image_latents = None
- if image is not None:
- image = image.to(device=device, dtype=dtype)
- if image.shape[1] != self.latent_channels:
- image_latents = self._encode_vae_image(image=image, generator=generator)
- else:
- image_latents = image
- if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
- # expand init_latents for batch_size
- additional_image_per_prompt = batch_size // image_latents.shape[0]
- image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
- elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
- raise ValueError(
- f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
- )
- else:
- image_latents = torch.cat([image_latents], dim=0)
+ if images is not None and len(images ) > 0:
+ if not isinstance(images, list):
+ images = [images]
+ all_image_latents = []
+ for image in images:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
- image_latents = self._pack_latents(image_latents)
+ image_latents = self._pack_latents(image_latents)
+ all_image_latents.append(image_latents)
+ image_latents = torch.cat(all_image_latents, dim=1)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -568,6 +590,7 @@ def __call__(
joint_pass= True,
lora_inpaint = False,
outpainting_dims = None,
+ qwen_edit_plus = False,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -683,61 +706,54 @@ def __call__(
batch_size = prompt_embeds.shape[0]
device = "cuda"
- prompt_image = None
+ condition_images = []
+ vae_image_sizes = []
+ vae_images = []
image_mask_latents = None
- if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
- image = image[0] if isinstance(image, list) else image
- image_height, image_width = self.image_processor.get_default_height_width(image)
- aspect_ratio = image_width / image_height
- if False :
- _, image_width, image_height = min(
- (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
- )
- image_width = image_width // multiple_of * multiple_of
- image_height = image_height // multiple_of * multiple_of
- ref_height, ref_width = 1568, 672
-
- if image_mask is None:
- if height * width < ref_height * ref_width: ref_height , ref_width = height , width
- if image_height * image_width > ref_height * ref_width:
- image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
- if (image_width,image_height) != image.size:
- image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS)
- elif not lora_inpaint:
- # _, image_width, image_height = min(
- # (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS
- # )
- image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of)
- # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of)
- height, width = image_height, image_width
- image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 8, height // 8), resample=Image.Resampling.LANCZOS))
- image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
- image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0)
- # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
- image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1)
- image_mask_latents = self._pack_latents(image_mask_latents)
-
- prompt_image = image
- if image.size != (image_width, image_height):
- image = image.resize((image_width, image_height), resample=Image.Resampling.LANCZOS)
-
- image = convert_image_to_tensor(image)
- if lora_inpaint:
- image_mask_rebuilt = torch.where(convert_image_to_tensor(image_mask)>-0.5, 1., 0. )[0:1]
- image_mask_latents = None
- green = torch.tensor([-1.0, 1.0, -1.0]).to(image)
- green_image = green[:, None, None] .expand_as(image)
- image = torch.where(image_mask_rebuilt > 0, green_image, image)
- prompt_image = convert_tensor_to_image(image)
- image = image.unsqueeze(0).unsqueeze(2)
- # image.save("nnn.png")
+ ref_size = 1024
+ ref_text_encoder_size = 384 if qwen_edit_plus else 1024
+ if image is not None:
+ if not isinstance(image, list): image = [image]
+ if height * width < ref_size * ref_size: ref_size = round(math.sqrt(height * width))
+ for ref_no, img in enumerate(image):
+ image_width, image_height = img.size
+ any_mask = ref_no == 0 and image_mask is not None
+ if (image_height * image_width > ref_size * ref_size) and not any_mask:
+ vae_height, vae_width =calculate_new_dimensions(ref_size, ref_size, image_height, image_width, False, block_size=multiple_of)
+ else:
+ vae_height, vae_width = image_height, image_width
+ vae_width = vae_width // multiple_of * multiple_of
+ vae_height = vae_height // multiple_of * multiple_of
+ vae_image_sizes.append((vae_width, vae_height))
+ condition_height, condition_width =calculate_new_dimensions(ref_text_encoder_size, ref_text_encoder_size, image_height, image_width, False, block_size=multiple_of)
+ condition_images.append(img.resize((condition_width, condition_height), resample=Image.Resampling.LANCZOS) )
+ if img.size != (vae_width, vae_height):
+ img = img.resize((vae_width, vae_height), resample=Image.Resampling.LANCZOS)
+ if any_mask :
+ if lora_inpaint:
+ image_mask_rebuilt = torch.where(convert_image_to_tensor(image_mask)>-0.5, 1., 0. )[0:1]
+ img = convert_image_to_tensor(img)
+ green = torch.tensor([-1.0, 1.0, -1.0]).to(img)
+ green_image = green[:, None, None] .expand_as(img)
+ img = torch.where(image_mask_rebuilt > 0, green_image, img)
+ img = convert_tensor_to_image(img)
+ else:
+ image_mask_latents = convert_image_to_tensor(image_mask.resize((vae_width // 8, vae_height // 8), resample=Image.Resampling.LANCZOS))
+ image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1]
+ image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0)
+ # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png")
+ image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1)
+ image_mask_latents = self._pack_latents(image_mask_latents)
+ # img.save("nnn.png")
+ vae_images.append( convert_image_to_tensor(img).unsqueeze(0).unsqueeze(2) )
+
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
- image=prompt_image,
+ image=condition_images,
prompt=prompt,
prompt_embeds=prompt_embeds,
prompt_embeds_mask=prompt_embeds_mask,
@@ -747,7 +763,7 @@ def __call__(
)
if do_true_cfg:
negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
- image=prompt_image,
+ image=condition_images,
prompt=negative_prompt,
prompt_embeds=negative_prompt_embeds,
prompt_embeds_mask=negative_prompt_embeds_mask,
@@ -763,7 +779,7 @@ def __call__(
# 4. Prepare latent variables
num_channels_latents = self.transformer.in_channels // 4
latents, image_latents = self.prepare_latents(
- image,
+ vae_images,
batch_size * num_images_per_prompt,
num_channels_latents,
height,
@@ -779,7 +795,12 @@ def __call__(
img_shapes = [
[
(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
- (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2),
+ # (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2),
+ *[
+ (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
+ for vae_width, vae_height in vae_image_sizes
+ ],
+
]
] * batch_size
else:
@@ -971,7 +992,7 @@ def cfg_predictions( noise_pred, neg_noise_pred, guidance, t):
latents = latents / latents_std + latents_mean
output_image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
if image_mask is not None and not lora_inpaint : #not (lora_inpaint and outpainting_dims is not None):
- output_image = image.squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(image) * image_mask_rebuilt
+ output_image = vae_images[0].squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(vae_images[0] ) * image_mask_rebuilt
return output_image
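A hedged recap of the multi-image handling introduced above: each reference image gets a numbered `Picture N: <|vision_start|><|image_pad|><|vision_end|>` tag prepended to the prompt, and two target sizes are computed per image, a smaller one for the text-encoder condition copy (384px budget for Edit Plus) and an area-capped one for the VAE copy. `fit_area` below is a simplified stand-in for the pipeline's `calculate_new_dimensions`; only the overall flow is meant to match.

```python
import math
from PIL import Image

IMG_TAG = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"

def fit_area(img_w: int, img_h: int, ref: int, multiple_of: int = 16):
    """Simplified stand-in: keep the aspect ratio, cap the area at ref*ref,
    and round both sides down to a multiple of `multiple_of`."""
    scale = min(1.0, math.sqrt((ref * ref) / (img_w * img_h)))
    w = max(multiple_of, int(img_w * scale) // multiple_of * multiple_of)
    h = max(multiple_of, int(img_h * scale) // multiple_of * multiple_of)
    return w, h

def prepare_edit_inputs(prompt: str, images: list, qwen_edit_plus: bool = True):
    """Return the tagged prompt and (condition_size, vae_size) per reference image."""
    ref_vae = 1024                              # VAE-side area budget
    ref_cond = 384 if qwen_edit_plus else 1024  # text-encoder-side area budget
    tagged_prompt = "".join(IMG_TAG.format(i + 1) for i in range(len(images))) + prompt
    sizes = [(fit_area(*img.size, ref_cond), fit_area(*img.size, ref_vae)) for img in images]
    return tagged_prompt, sizes

# Two dummy reference images: they occupy the "Picture 1" and "Picture 2" slots.
imgs = [Image.new("RGB", (1536, 1024)), Image.new("RGB", (640, 640))]
tagged, sizes = prepare_edit_inputs("make the jacket red", imgs)
print(sizes)
```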
diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py
index cc6a76484..4fcaa3ba8 100644
--- a/models/qwen/qwen_handler.py
+++ b/models/qwen/qwen_handler.py
@@ -20,7 +20,7 @@ def query_model_def(base_model_type, model_def):
"fit_into_canvas_image_refs": 0,
}
- if base_model_type in ["qwen_image_edit_20B"]:
+ if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]:
extra_model_def["inpaint_support"] = True
extra_model_def["image_ref_choices"] = {
"choices": [
@@ -42,11 +42,20 @@ def query_model_def(base_model_type, model_def):
"image_modes" : [2],
}
+ if base_model_type in ["qwen_image_edit_plus_20B"]:
+ extra_model_def["guide_preprocessing"] = {
+ "selection": ["", "PV", "SV", "CV"],
+ }
+
+ extra_model_def["mask_preprocessing"] = {
+ "selection": ["", "A"],
+ "visible": False,
+ }
return extra_model_def
@staticmethod
def query_supported_types():
- return ["qwen_image_20B", "qwen_image_edit_20B"]
+ return ["qwen_image_20B", "qwen_image_edit_20B", "qwen_image_edit_plus_20B"]
@staticmethod
def query_family_maps():
@@ -113,9 +122,15 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
"denoising_strength" : 1.,
"model_mode" : 0,
})
+ elif base_model_type in ["qwen_image_edit_plus_20B"]:
+ ui_defaults.update({
+ "video_prompt_type": "I",
+ "denoising_strength" : 1.,
+ "model_mode" : 0,
+ })
def validate_generative_settings(base_model_type, model_def, inputs):
- if base_model_type in ["qwen_image_edit_20B"]:
+ if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]:
model_mode = inputs["model_mode"]
denoising_strength= inputs["denoising_strength"]
video_guide_outpainting= inputs["video_guide_outpainting"]
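For context, the `guide_preprocessing` / `mask_preprocessing` dictionaries a handler declares are what the UI later inspects to decide which dropdowns to show; the `visible: False` flag added here is read back in wgp.py's `refresh_video_prompt_type_video_guide`. A minimal sketch of that lookup, assuming `model_def` is a plain dict:

```python
def mask_selector_visible(model_def: dict) -> bool:
    """Hide the mask-preprocessing dropdown when the handler marks it invisible."""
    mask_preprocessing = model_def.get("mask_preprocessing", None)
    if mask_preprocessing is not None:
        return mask_preprocessing.get("visible", True)
    return True

# Qwen Edit Plus declares mask choices but keeps the selector hidden in the UI.
qwen_edit_plus_def = {
    "guide_preprocessing": {"selection": ["", "PV", "SV", "CV"]},
    "mask_preprocessing": {"selection": ["", "A"], "visible": False},
}
assert mask_selector_visible(qwen_edit_plus_def) is False
assert mask_selector_visible({}) is True
```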
diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py
index abd5c5cfe..8c4514fa2 100644
--- a/models/qwen/qwen_main.py
+++ b/models/qwen/qwen_main.py
@@ -51,10 +51,10 @@ def __init__(
transformer_filename = model_filename[0]
processor = None
tokenizer = None
- if base_model_type == "qwen_image_edit_20B":
+ if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]:
processor = Qwen2VLProcessor.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct"))
tokenizer = AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct"))
-
+ self.base_model_type = base_model_type
base_config_file = "configs/qwen_image_20B.json"
with open(base_config_file, 'r', encoding='utf-8') as f:
@@ -173,7 +173,7 @@ def generate(
self.vae.tile_latent_min_height = VAE_tile_size[1]
self.vae.tile_latent_min_width = VAE_tile_size[1]
-
+ qwen_edit_plus = self.base_model_type in ["qwen_image_edit_plus_20B"]
self.vae.enable_slicing()
# width, height = aspect_ratios["16:9"]
@@ -182,17 +182,19 @@ def generate(
image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True)
if input_frames is not None:
- input_ref_images = [convert_tensor_to_image(input_frames) ]
- elif input_ref_images is not None:
+ input_ref_images = [convert_tensor_to_image(input_frames) ] + ([] if input_ref_images is None else input_ref_images )
+
+ if input_ref_images is not None:
# image stiching method
stiched = input_ref_images[0]
if "K" in video_prompt_type :
w, h = input_ref_images[0].size
height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
- for new_img in input_ref_images[1:]:
- stiched = stitch_images(stiched, new_img)
- input_ref_images = [stiched]
+ if not qwen_edit_plus:
+ for new_img in input_ref_images[1:]:
+ stiched = stitch_images(stiched, new_img)
+ input_ref_images = [stiched]
image = self.pipeline(
prompt=input_prompt,
@@ -212,7 +214,8 @@ def generate(
generator=torch.Generator(device="cuda").manual_seed(seed),
lora_inpaint = image_mask is not None and model_mode == 1,
outpainting_dims = outpainting_dims,
- )
+ qwen_edit_plus = qwen_edit_plus,
+ )
if image is None: return None
return image.transpose(0, 1)
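The `if not qwen_edit_plus` branch above is the key behavioural difference between the two edit models: classic Qwen Edit folds every reference image into a single stitched canvas before encoding, while Edit Plus keeps them as a list so each image gets its own `Picture N` slot. A rough sketch, with a naive side-by-side paste standing in for the project's `stitch_images` helper:

```python
from PIL import Image

def stitch_side_by_side(a: Image.Image, b: Image.Image) -> Image.Image:
    """Naive stand-in for stitch_images: paste two images next to each other."""
    canvas = Image.new("RGB", (a.width + b.width, max(a.height, b.height)))
    canvas.paste(a, (0, 0))
    canvas.paste(b, (a.width, 0))
    return canvas

def collapse_ref_images(ref_images: list, qwen_edit_plus: bool) -> list:
    """Edit Plus keeps references separate; classic Edit folds them into one image."""
    if qwen_edit_plus or len(ref_images) <= 1:
        return ref_images
    stitched = ref_images[0]
    for img in ref_images[1:]:
        stitched = stitch_side_by_side(stitched, img)
    return [stitched]

refs = [Image.new("RGB", (256, 256)), Image.new("RGB", (128, 256))]
print(len(collapse_ref_images(refs, qwen_edit_plus=False)))  # 1 stitched canvas
print(len(collapse_ref_images(refs, qwen_edit_plus=True)))   # 2 separate references
```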
diff --git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py
index d40811d51..df71b40a9 100644
--- a/preprocessing/matanyone/app.py
+++ b/preprocessing/matanyone/app.py
@@ -21,6 +21,7 @@
from .matanyone.inference.inference_core import InferenceCore
from .matanyone_wrapper import matanyone
from shared.utils.audio_video import save_video, save_image
+from mmgp import offload
arg_device = "cuda"
arg_sam_model_type="vit_h"
@@ -539,7 +540,7 @@ def video_matting(video_state,video_input, end_slider, matting_type, interactive
file_name = ".".join(file_name.split(".")[:-1])
from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files
- source_audio_tracks, audio_metadata = extract_audio_tracks(video_input)
+ source_audio_tracks, audio_metadata = extract_audio_tracks(video_input, verbose= offload.default_verboseLevel )
output_fg_path = f"./mask_outputs/{file_name}_fg.mp4"
output_fg_temp_path = f"./mask_outputs/{file_name}_fg_tmp.mp4"
if len(source_audio_tracks) == 0:
@@ -679,7 +680,6 @@ def load_unload_models(selected):
}
# os.path.join('.')
- from mmgp import offload
# sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".")
sam_checkpoint = None
diff --git a/shared/utils/utils.py b/shared/utils/utils.py
index bb2d5ffb6..00178aad1 100644
--- a/shared/utils/utils.py
+++ b/shared/utils/utils.py
@@ -321,7 +321,7 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu
ref_width, ref_height = ref_img.size
if (ref_height, ref_width) == image_size and outpainting_dims == None:
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
- canvas = torch.zeros_like(ref_img) if return_mask else None
+ canvas = torch.zeros_like(ref_img[:1]) if return_mask else None
else:
if outpainting_dims != None:
final_height, final_width = image_size
@@ -374,7 +374,7 @@ def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, im
if pre_video_guide is not None:
src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1)
if any_mask:
- src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1)
+ src_mask = torch.zeros_like(pre_video_guide[:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[:1]), src_mask], dim=1)
if any_guide_padding:
if src_video is None:
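A short illustration of the `fit_image_into_canvas` fix above: the mask canvas returned alongside a guide image should be single-channel, so the zero tensor has to be built from a one-channel slice of the reference (`ref_img[:1]`) rather than the full three-channel tensor. A minimal sketch, assuming the C, F, H, W layout used elsewhere in the code:

```python
import torch

# A fake reference image tensor after fit_image_into_canvas: (C, F, H, W) with one frame.
ref = torch.zeros(3, 1, 8, 8)

three_channel = torch.zeros_like(ref)      # old behaviour: a 3-channel "mask"
one_channel = torch.zeros_like(ref[:1])    # fixed behaviour: a single-channel mask

print(three_channel.shape)  # torch.Size([3, 1, 8, 8])
print(one_channel.shape)    # torch.Size([1, 1, 8, 8])

# Downstream, masks are concatenated frame-wise with other single-channel masks,
# which only works when the channel counts agree.
existing_mask = torch.ones(1, 4, 8, 8)
merged = torch.cat([one_channel, existing_mask], dim=1)
print(merged.shape)         # torch.Size([1, 5, 8, 8])
```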
diff --git a/wgp.py b/wgp.py
index d48c7fa29..f92cf3334 100644
--- a/wgp.py
+++ b/wgp.py
@@ -63,7 +63,7 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.71"
+WanGP_version = "8.72"
settings_version = 2.35
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -4987,7 +4987,7 @@ def remove_temp_filenames(temp_filenames_list):
frames_to_inject[pos] = image_refs[i]
- video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None
+ video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = sparse_video_image = None
if video_guide is not None:
keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate)
if len(error) > 0:
@@ -6566,11 +6566,11 @@ def switch_image_mode(state):
inpaint_support = model_def.get("inpaint_support", False)
if inpaint_support:
if image_mode == 1:
- video_prompt_type = del_in_sequence(video_prompt_type, "VAG")
+ video_prompt_type = del_in_sequence(video_prompt_type, "VAG" + all_guide_processes)
video_prompt_type = add_to_sequence(video_prompt_type, "KI")
elif image_mode == 2:
+ video_prompt_type = del_in_sequence(video_prompt_type, "KI" + all_guide_processes)
video_prompt_type = add_to_sequence(video_prompt_type, "VAG")
- video_prompt_type = del_in_sequence(video_prompt_type, "KI")
ui_defaults["video_prompt_type"] = video_prompt_type
return str(time.time())
@@ -6965,10 +6965,11 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
return video_prompt_type
+all_guide_processes ="PDESLCMUVB"
def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
old_video_prompt_type = video_prompt_type
- video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUVB")
+ video_prompt_type = del_in_sequence(video_prompt_type, all_guide_processes)
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
visible = "V" in video_prompt_type
model_type = state["model_type"]
@@ -6978,8 +6979,12 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt
image_outputs = image_mode > 0
keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False)
image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value )
-
- return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible)
+ mask_preprocessing = model_def.get("mask_preprocessing", None)
+ if mask_preprocessing is not None:
+ mask_selector_visible = mask_preprocessing.get("visible", True)
+ else:
+ mask_selector_visible = True
+ return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible)
def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode):
@@ -7391,7 +7396,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
any_start_image = any_end_image = any_reference_image = any_image_mask = False
v2i_switch_supported = (vace or t2v or standin) and not image_outputs
ti2v_2_2 = base_model_type in ["ti2v_2_2"]
- gallery_height = 350
+ gallery_height = 550
def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ):
with gr.Row(visible = visible) as gallery_row:
gallery_amg = AdvancedMediaGallery(media_mode="image", height=gallery_height, columns=4, label=label, initial = value , single_image_mode = single_image_mode )
@@ -7459,8 +7464,12 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VL"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" )
any_control_video = any_control_image = False
- guide_preprocessing = model_def.get("guide_preprocessing", None)
- mask_preprocessing = model_def.get("mask_preprocessing", None)
+ if image_mode_value ==2:
+ guide_preprocessing = { "selection": ["V", "VG"]}
+ mask_preprocessing = { "selection": ["A"]}
+ else:
+ guide_preprocessing = model_def.get("guide_preprocessing", None)
+ mask_preprocessing = model_def.get("mask_preprocessing", None)
guide_custom_choices = model_def.get("guide_custom_choices", None)
image_ref_choices = model_def.get("image_ref_choices", None)
@@ -7504,7 +7513,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image")
video_prompt_type_video_guide = gr.Dropdown(
guide_preprocessing_choices,
- value=filter_letters(video_prompt_type_value, "PDESLCMUVB", guide_preprocessing.get("default", "") ),
+ value=filter_letters(video_prompt_type_value, all_guide_processes, guide_preprocessing.get("default", "") ),
label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True,
)
any_control_video = True
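The hunks above replace the hard-coded letter string "PDESLCMUVB" with the shared `all_guide_processes` constant, both when clearing guide-process letters and when pre-selecting the dropdown value. A minimal sketch of how the two helpers involved might interact; `filter_letters` and `del_in_sequence` are wgp.py helpers and their implementations here are assumptions inferred only from how these hunks call them:

```python
# Hedged sketch: not the project's actual code, only the assumed behaviour.
all_guide_processes = "PDESLCMUVB"

def filter_letters(value: str, allowed: str, default: str = "") -> str:
    # Keep only letters from the allowed set, falling back to a default selection.
    kept = "".join(ch for ch in value if ch in allowed)
    return kept if kept else default

def del_in_sequence(value: str, letters: str) -> str:
    # Drop every occurrence of the given letters.
    return "".join(ch for ch in value if ch not in letters)

# Example: clear all guide-process letters before re-adding the user's new choice.
video_prompt_type = del_in_sequence("PVKI", all_guide_processes)  # -> "KI"
```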
@@ -7590,8 +7599,8 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
if image_guide_value is None:
image_mask_guide_value = None
else:
- image_mask_value = rgb_bw_to_rgba_mask(image_mask_value)
- image_mask_guide_value = { "background" : image_guide_value, "composite" : None, "layers": [image_mask_value] }
+ image_mask_guide_value = { "background" : image_guide_value, "composite" : None}
+ image_mask_guide_value["layers"] = [] if image_mask_value is None else [rgb_bw_to_rgba_mask(image_mask_value)]
image_mask_guide = gr.ImageEditor(
label="Control Image to be Inpainted" if image_mode_value == 2 else "Control Image and Mask",
From b44fcc21df1cb0157381433d66c7fb1adeb77735 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Wed, 24 Sep 2025 23:45:44 +0200
Subject: [PATCH 036/155] restored original gallery height
---
wgp.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/wgp.py b/wgp.py
index f92cf3334..18ccd3441 100644
--- a/wgp.py
+++ b/wgp.py
@@ -7396,7 +7396,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
any_start_image = any_end_image = any_reference_image = any_image_mask = False
v2i_switch_supported = (vace or t2v or standin) and not image_outputs
ti2v_2_2 = base_model_type in ["ti2v_2_2"]
- gallery_height = 550
+ gallery_height = 350
def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ):
with gr.Row(visible = visible) as gallery_row:
gallery_amg = AdvancedMediaGallery(media_mode="image", height=gallery_height, columns=4, label=label, initial = value , single_image_mode = single_image_mode )
From 654bb0483bf7d1bf32ba0c56e1392ddd761a014f Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 25 Sep 2025 00:50:03 +0200
Subject: [PATCH 037/155] fixed missing audio input for Hunyuan Avatar and
Custom Audio
---
models/hyvideo/hunyuan_handler.py | 10 +++++++++-
wgp.py | 2 +-
2 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py
index 435234102..ebe07d26d 100644
--- a/models/hyvideo/hunyuan_handler.py
+++ b/models/hyvideo/hunyuan_handler.py
@@ -173,8 +173,14 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
video_prompt_type = video_prompt_type.replace("M","")
ui_defaults["video_prompt_type"] = video_prompt_type
+ if settings_version < 2.36:
+ if base_model_type in ["hunyuan_avatar", "hunyuan_custom_audio"]:
+ audio_prompt_type= ui_defaults["audio_prompt_type"]
+ if "A" not in audio_prompt_type:
+ audio_prompt_type += "A"
+ ui_defaults["audio_prompt_type"] = audio_prompt_type
+
- pass
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
@@ -197,6 +203,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
"guidance_scale": 7.5,
"flow_shift": 13,
"video_prompt_type": "I",
+ "audio_prompt_type": "A",
})
elif base_model_type in ["hunyuan_custom_edit"]:
ui_defaults.update({
@@ -213,4 +220,5 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
"skip_steps_start_step_perc": 25,
"video_length": 129,
"video_prompt_type": "KI",
+ "audio_prompt_type": "A",
})
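The settings migration above is idempotent: presets for Hunyuan Avatar and Hunyuan Custom Audio saved before `settings_version` 2.36 get an "A" appended to `audio_prompt_type` only if it is not already there, while new defaults receive it directly. A minimal sketch of that migration step, with names taken from the hunk (not the full `fix_settings` function):

```python
# Hedged sketch of the audio flag migration performed in fix_settings above.
def migrate_audio_prompt_type(ui_defaults: dict, settings_version: float, base_model_type: str) -> None:
    if settings_version < 2.36 and base_model_type in ("hunyuan_avatar", "hunyuan_custom_audio"):
        audio_prompt_type = ui_defaults.get("audio_prompt_type", "")
        if "A" not in audio_prompt_type:
            # Re-running the migration leaves an already-migrated preset unchanged.
            ui_defaults["audio_prompt_type"] = audio_prompt_type + "A"
```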
diff --git a/wgp.py b/wgp.py
index 18ccd3441..e7264d295 100644
--- a/wgp.py
+++ b/wgp.py
@@ -64,7 +64,7 @@
target_mmgp_version = "3.6.0"
WanGP_version = "8.72"
-settings_version = 2.35
+settings_version = 2.36
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
From ee0bb89ee94618eac574f1290dd45532b3b59220 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 25 Sep 2025 02:16:57 +0200
Subject: [PATCH 038/155] Added Qwen Preview mode
---
README.md | 3 +-
models/qwen/pipeline_qwenimage.py | 5 +--
models/qwen/qwen_handler.py | 7 ++++
models/wan/any2video.py | 54 ++++++++++++++-----------------
models/wan/wan_handler.py | 4 ++-
shared/RGB_factors.py | 4 +--
6 files changed, 41 insertions(+), 36 deletions(-)
diff --git a/README.md b/README.md
index 30a97e66b..9b8f70d45 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
-### September 24 2025: WanGP v8.72 - Here Are ~~Two~~Three New Contenders in the Vace Arena !
+### September 25 2025: WanGP v8.73 - Here Are ~~Two~~Three New Contenders in the Vace Arena !
So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages:
- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*.
@@ -34,6 +34,7 @@ Also because I wanted to spoil you:
*Update 8.71*: fixed Fast Lucy Edit that didnt contain the lora
*Update 8.72*: shadow drop of Qwen Edit Plus
+*Update 8.73*: Qwen Preview & InfiniteTalk Start image
### September 15 2025: WanGP v8.6 - Attack of the Clones
diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py
index 85934b79f..134cc51b5 100644
--- a/models/qwen/pipeline_qwenimage.py
+++ b/models/qwen/pipeline_qwenimage.py
@@ -971,8 +971,9 @@ def cfg_predictions( noise_pred, neg_noise_pred, guidance, t):
latents = latents.to(latents_dtype)
if callback is not None:
- # preview = unpack_latent(img).transpose(0,1)
- callback(i, None, False)
+ preview = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ preview = preview.squeeze(0)
+ callback(i, preview, False)
self._current_timestep = None
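With this change the Qwen pipeline hands an unpacked latent to the step callback instead of `None`, which is what enables the preview mode named in the commit subject. A minimal sketch of the callback contract implied here and elsewhere in the series, i.e. `(step, preview, force_refresh)`; the consumer function is an assumption, not the callback wgp.py actually installs:

```python
# Hedged sketch of a consumer for the (step, preview, force_refresh) contract.
def make_preview_callback(show_preview):
    def callback(step, preview, force_refresh):
        if step == -1 and force_refresh:
            return  # sampling is (re)starting; nothing to draw yet
        if preview is not None:
            # preview is a latent tensor; it still needs a latent->RGB mapping
            # (see the RGB-factor sketch further below) before it can be shown.
            show_preview(step, preview)
    return callback
```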
diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py
index 4fcaa3ba8..99864a5a7 100644
--- a/models/qwen/qwen_handler.py
+++ b/models/qwen/qwen_handler.py
@@ -129,6 +129,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
"model_mode" : 0,
})
+ @staticmethod
def validate_generative_settings(base_model_type, model_def, inputs):
if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]:
model_mode = inputs["model_mode"]
@@ -141,3 +142,9 @@ def validate_generative_settings(base_model_type, model_def, inputs):
gr.Info("Denoising Strength will be ignored while using Lora Inpainting")
if outpainting_dims is not None and model_mode == 0 :
return "Outpainting is not supported with Masked Denoising "
+
+ @staticmethod
+ def get_rgb_factors(base_model_type ):
+ from shared.RGB_factors import get_rgb_factors
+ latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("qwen")
+ return latent_rgb_factors, latent_rgb_factors_bias
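`get_rgb_factors` gives the UI a cheap latent-to-RGB approximation so previews do not need a full VAE decode; the Qwen handler now reuses the shared Wan factors. A minimal sketch, assuming the conventional ComfyUI-style mapping credited in `shared/RGB_factors.py` (one 3-vector per latent channel plus a bias), of how such factors would be applied to a preview latent; shapes are assumptions, not the project's exact code:

```python
import torch

# Hedged sketch of applying latent RGB factors to produce an approximate preview.
def latent_to_preview_rgb(latent: torch.Tensor, rgb_factors, rgb_bias) -> torch.Tensor:
    # latent: (C, H, W); rgb_factors: C rows of 3 values; rgb_bias: 3 values
    factors = torch.tensor(rgb_factors, dtype=latent.dtype, device=latent.device)  # (C, 3)
    bias = torch.tensor(rgb_bias, dtype=latent.dtype, device=latent.device)        # (3,)
    rgb = torch.einsum("chw,cr->rhw", latent, factors) + bias[:, None, None]
    return rgb.clamp(-1, 1).add(1).div(2)  # approximate image in [0, 1] for display
```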
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index 41d6d6316..6b4ae629d 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -443,38 +443,32 @@ def generate(self,
# image2video
if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]:
any_end_frame = False
- if image_start is None:
- if infinitetalk:
- new_shot = "Q" in video_prompt_type
- if input_frames is not None:
- image_ref = input_frames[:, 0]
- else:
- if input_ref_images is None:
- if pre_video_frame is None: raise Exception("Missing Reference Image")
- input_ref_images, new_shot = [pre_video_frame], False
- new_shot = new_shot and window_no <= len(input_ref_images)
- image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
- if new_shot or input_video is None:
- input_video = image_ref.unsqueeze(1)
- else:
- color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot
- _ , preframes_count, height, width = input_video.shape
- input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype)
- if infinitetalk:
- image_start = image_ref.to(input_video)
- control_pre_frames_count = 1
- control_video = image_start.unsqueeze(1)
+ if infinitetalk:
+ new_shot = "Q" in video_prompt_type
+ if input_frames is not None:
+ image_ref = input_frames[:, 0]
else:
- image_start = input_video[:, -1]
- control_pre_frames_count = preframes_count
- control_video = input_video
-
- color_reference_frame = image_start.unsqueeze(1).clone()
+ if input_ref_images is None:
+ if pre_video_frame is None: raise Exception("Missing Reference Image")
+ input_ref_images, new_shot = [pre_video_frame], False
+ new_shot = new_shot and window_no <= len(input_ref_images)
+ image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ])
+ if new_shot or input_video is None:
+ input_video = image_ref.unsqueeze(1)
+ else:
+ color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot
+ _ , preframes_count, height, width = input_video.shape
+ input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype)
+ if infinitetalk:
+ image_start = image_ref.to(input_video)
+ control_pre_frames_count = 1
+ control_video = image_start.unsqueeze(1)
else:
- preframes_count = control_pre_frames_count = 1
- height, width = image_start.shape[1:]
- control_video = image_start.unsqueeze(1).to(self.device)
- color_reference_frame = control_video.clone()
+ image_start = input_video[:, -1]
+ control_pre_frames_count = preframes_count
+ control_video = input_video
+
+ color_reference_frame = image_start.unsqueeze(1).clone()
any_end_frame = image_end is not None
add_frames_for_end_image = any_end_frame and model_type == "i2v"
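The rewritten block drops the outer `if image_start is None:` guard, so the InfiniteTalk path now always selects its reference frame per sliding window and only treats a window as a "new shot" while unused reference images remain; windows past the last reference keep reusing the final one. A minimal sketch of that selection logic with names lifted from the hunk; the helper boundaries are an assumption, not the shape of the actual function:

```python
# Hedged sketch of the per-window reference selection performed above.
def pick_window_reference(input_ref_images, window_no, request_new_shot, convert_image_to_tensor):
    # A new shot only makes sense while there are still unused references.
    new_shot = request_new_shot and window_no <= len(input_ref_images)
    # Clamp the window index so later windows reuse the last reference image.
    idx = min(window_no, len(input_ref_images)) - 1
    image_ref = convert_image_to_tensor(input_ref_images[idx])
    return image_ref, new_shot
```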
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index 12ddfed45..574c990a0 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -245,8 +245,10 @@ def query_model_def(base_model_type, model_def):
"visible" : False,
}
- if vace_class or base_model_type in ["infinitetalk", "animate"]:
+ if vace_class or base_model_type in ["animate"]:
image_prompt_types_allowed = "TVL"
+ elif base_model_type in ["infinitetalk"]:
+ image_prompt_types_allowed = "TSVL"
elif base_model_type in ["ti2v_2_2"]:
image_prompt_types_allowed = "TSVL"
elif base_model_type in ["lucy_edit"]:
diff --git a/shared/RGB_factors.py b/shared/RGB_factors.py
index 6e865fa7b..8a870b4b6 100644
--- a/shared/RGB_factors.py
+++ b/shared/RGB_factors.py
@@ -1,6 +1,6 @@
# thanks Comfyui for the rgb factors (https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py)
def get_rgb_factors(model_family, model_type = None):
- if model_family == "wan":
+ if model_family in ["wan", "qwen"]:
if model_type =="ti2v_2_2":
latent_channels = 48
latent_dimensions = 3
@@ -261,7 +261,7 @@ def get_rgb_factors(model_family, model_type = None):
[ 0.0249, -0.0469, -0.1703]
]
- latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
+ latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
else:
latent_rgb_factors_bias = latent_rgb_factors = None
return latent_rgb_factors, latent_rgb_factors_bias
\ No newline at end of file
From 5cfedca74480099e11c8a5fc6a69c1e8abe701cc Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 25 Sep 2025 02:17:59 +0200
Subject: [PATCH 039/155] Added Qwen Preview mode
---
wgp.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/wgp.py b/wgp.py
index e7264d295..5267ca35b 100644
--- a/wgp.py
+++ b/wgp.py
@@ -63,7 +63,7 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.72"
+WanGP_version = "8.73"
settings_version = 2.36
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
From 14d68bbc9152f35c00a55750cc15327680a03dd5 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 25 Sep 2025 09:40:51 +0200
Subject: [PATCH 040/155] fixed double Vace controlnets with no mask
---
models/wan/diffusion_forcing copy.py | 479 ------------------
models/wan/text2video fuse attempt.py | 698 --------------------------
wgp.py | 11 +-
3 files changed, 7 insertions(+), 1181 deletions(-)
delete mode 100644 models/wan/diffusion_forcing copy.py
delete mode 100644 models/wan/text2video fuse attempt.py
diff --git a/models/wan/diffusion_forcing copy.py b/models/wan/diffusion_forcing copy.py
deleted file mode 100644
index 753fd4561..000000000
--- a/models/wan/diffusion_forcing copy.py
+++ /dev/null
@@ -1,479 +0,0 @@
-import math
-import os
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import logging
-import numpy as np
-import torch
-from diffusers.image_processor import PipelineImageInput
-from diffusers.utils.torch_utils import randn_tensor
-from diffusers.video_processor import VideoProcessor
-from tqdm import tqdm
-from .modules.model import WanModel
-from .modules.t5 import T5EncoderModel
-from .modules.vae import WanVAE
-from wan.modules.posemb_layers import get_rotary_pos_embed
-from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
- get_sampling_sigmas, retrieve_timesteps)
-from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
-
-class DTT2V:
-
-
- def __init__(
- self,
- config,
- checkpoint_dir,
- rank=0,
- model_filename = None,
- text_encoder_filename = None,
- quantizeTransformer = False,
- dtype = torch.bfloat16,
- ):
- self.device = torch.device(f"cuda")
- self.config = config
- self.rank = rank
- self.dtype = dtype
- self.num_train_timesteps = config.num_train_timesteps
- self.param_dtype = config.param_dtype
-
- self.text_encoder = T5EncoderModel(
- text_len=config.text_len,
- dtype=config.t5_dtype,
- device=torch.device('cpu'),
- checkpoint_path=text_encoder_filename,
- tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
- shard_fn= None)
-
- self.vae_stride = config.vae_stride
- self.patch_size = config.patch_size
-
-
- self.vae = WanVAE(
- vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
- device=self.device)
-
- logging.info(f"Creating WanModel from {model_filename}")
- from mmgp import offload
-
- self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False, forcedConfigPath="config.json")
- # offload.load_model_data(self.model, "recam.ckpt")
- # self.model.cpu()
- # offload.save_model(self.model, "recam.safetensors")
- if self.dtype == torch.float16 and not "fp16" in model_filename:
- self.model.to(self.dtype)
- # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
- if self.dtype == torch.float16:
- self.vae.model.to(self.dtype)
- self.model.eval().requires_grad_(False)
-
- self.scheduler = FlowUniPCMultistepScheduler()
-
- @property
- def do_classifier_free_guidance(self) -> bool:
- return self._guidance_scale > 1
-
- def encode_image(
- self, image: PipelineImageInput, height: int, width: int, num_frames: int, tile_size = 0, causal_block_size = 0
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-
- # prefix_video
- prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1)
- prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1)
- if prefix_video.dtype == torch.uint8:
- prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0
- prefix_video = prefix_video.to(self.device)
- prefix_video = [self.vae.encode(prefix_video.unsqueeze(0), tile_size = tile_size)[0]] # [(c, f, h, w)]
- if prefix_video[0].shape[1] % causal_block_size != 0:
- truncate_len = prefix_video[0].shape[1] % causal_block_size
- print("the length of prefix video is truncated for the casual block size alignment.")
- prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
- predix_video_latent_length = prefix_video[0].shape[1]
- return prefix_video, predix_video_latent_length
-
- def prepare_latents(
- self,
- shape: Tuple[int],
- dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- ) -> torch.Tensor:
- return randn_tensor(shape, generator, device=device, dtype=dtype)
-
- def generate_timestep_matrix(
- self,
- num_frames,
- step_template,
- base_num_frames,
- ar_step=5,
- num_pre_ready=0,
- casual_block_size=1,
- shrink_interval_with_mask=False,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
- step_matrix, step_index = [], []
- update_mask, valid_interval = [], []
- num_iterations = len(step_template) + 1
- num_frames_block = num_frames // casual_block_size
- base_num_frames_block = base_num_frames // casual_block_size
- if base_num_frames_block < num_frames_block:
- infer_step_num = len(step_template)
- gen_block = base_num_frames_block
- min_ar_step = infer_step_num / gen_block
- assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
- # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
- step_template = torch.cat(
- [
- torch.tensor([999], dtype=torch.int64, device=step_template.device),
- step_template.long(),
- torch.tensor([0], dtype=torch.int64, device=step_template.device),
- ]
- ) # to handle the counter in row works starting from 1
- pre_row = torch.zeros(num_frames_block, dtype=torch.long)
- if num_pre_ready > 0:
- pre_row[: num_pre_ready // casual_block_size] = num_iterations
-
- while torch.all(pre_row >= (num_iterations - 1)) == False:
- new_row = torch.zeros(num_frames_block, dtype=torch.long)
- for i in range(num_frames_block):
- if i == 0 or pre_row[i - 1] >= (
- num_iterations - 1
- ): # the first frame or the last frame is completely denoised
- new_row[i] = pre_row[i] + 1
- else:
- new_row[i] = new_row[i - 1] - ar_step
- new_row = new_row.clamp(0, num_iterations)
-
- update_mask.append(
- (new_row != pre_row) & (new_row != num_iterations)
- ) # False: no need to update, True: need to update
- step_index.append(new_row)
- step_matrix.append(step_template[new_row])
- pre_row = new_row
-
- # for long video we split into several sequences, base_num_frames is set to the model max length (for training)
- terminal_flag = base_num_frames_block
- if shrink_interval_with_mask:
- idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
- update_mask = update_mask[0]
- update_mask_idx = idx_sequence[update_mask]
- last_update_idx = update_mask_idx[-1].item()
- terminal_flag = last_update_idx + 1
- # for i in range(0, len(update_mask)):
- for curr_mask in update_mask:
- if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
- terminal_flag += 1
- valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
-
- step_update_mask = torch.stack(update_mask, dim=0)
- step_index = torch.stack(step_index, dim=0)
- step_matrix = torch.stack(step_matrix, dim=0)
-
- if casual_block_size > 1:
- step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
- step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
- step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
- valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
-
- return step_matrix, step_index, step_update_mask, valid_interval
-
- @torch.no_grad()
- def generate(
- self,
- prompt: Union[str, List[str]],
- negative_prompt: Union[str, List[str]] = "",
- image: PipelineImageInput = None,
- height: int = 480,
- width: int = 832,
- num_frames: int = 97,
- num_inference_steps: int = 50,
- shift: float = 1.0,
- guidance_scale: float = 5.0,
- seed: float = 0.0,
- overlap_history: int = 17,
- addnoise_condition: int = 0,
- base_num_frames: int = 97,
- ar_step: int = 5,
- causal_block_size: int = 1,
- causal_attention: bool = False,
- fps: int = 24,
- VAE_tile_size = 0,
- joint_pass = False,
- callback = None,
- ):
- generator = torch.Generator(device=self.device)
- generator.manual_seed(seed)
- # if base_num_frames > base_num_frames:
- # causal_block_size = 0
- self._guidance_scale = guidance_scale
-
- i2v_extra_kwrags = {}
- prefix_video = None
- predix_video_latent_length = 0
- if image:
- frame_width, frame_height = image.size
- scale = min(height / frame_height, width / frame_width)
- height = (int(frame_height * scale) // 16) * 16
- width = (int(frame_width * scale) // 16) * 16
-
- prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames, tile_size=VAE_tile_size, causal_block_size=causal_block_size)
-
- latent_length = (num_frames - 1) // 4 + 1
- latent_height = height // 8
- latent_width = width // 8
-
- prompt_embeds = self.text_encoder([prompt], self.device)
- prompt_embeds = [u.to(self.dtype).to(self.device) for u in prompt_embeds]
- if self.do_classifier_free_guidance:
- negative_prompt_embeds = self.text_encoder([negative_prompt], self.device)
- negative_prompt_embeds = [u.to(self.dtype).to(self.device) for u in negative_prompt_embeds]
-
-
-
- self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
- init_timesteps = self.scheduler.timesteps
- fps_embeds = [fps] * prompt_embeds[0].shape[0]
- fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
- transformer_dtype = self.dtype
- # with torch.cuda.amp.autocast(dtype=self.dtype), torch.no_grad():
- if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames:
- # short video generation
- latent_shape = [16, latent_length, latent_height, latent_width]
- latents = self.prepare_latents(
- latent_shape, dtype=torch.float32, device=self.device, generator=generator
- )
- latents = [latents]
- if prefix_video is not None:
- latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
- base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
- step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
- latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size
- )
- sample_schedulers = []
- for _ in range(latent_length):
- sample_scheduler = FlowUniPCMultistepScheduler(
- num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
- )
- sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
- sample_schedulers.append(sample_scheduler)
- sample_schedulers_counter = [0] * latent_length
-
- if callback != None:
- callback(-1, None, True)
-
- freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False)
- for i, timestep_i in enumerate(tqdm(step_matrix)):
- update_mask_i = step_update_mask[i]
- valid_interval_i = valid_interval[i]
- valid_interval_start, valid_interval_end = valid_interval_i
- timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
- latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
- if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
- noise_factor = 0.001 * addnoise_condition
- timestep_for_noised_condition = addnoise_condition
- latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
- latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor)
- + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length])
- * noise_factor
- )
- timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
- kwrags = {
- "x" : torch.stack([latent_model_input[0]]),
- "t" : timestep,
- "freqs" :freqs,
- "fps" : fps_embeds,
- # "causal_block_size" : causal_block_size,
- "callback" : callback,
- "pipeline" : self
- }
- kwrags.update(i2v_extra_kwrags)
-
-
- if not self.do_classifier_free_guidance:
- noise_pred = self.model(
- context=prompt_embeds,
- **kwrags,
- )[0]
- if self._interrupt:
- return None
- noise_pred= noise_pred.to(torch.float32)
- else:
- if joint_pass:
- noise_pred_cond, noise_pred_uncond = self.model(
- context=prompt_embeds,
- context2=negative_prompt_embeds,
- **kwrags,
- )
- if self._interrupt:
- return None
- else:
- noise_pred_cond = self.model(
- context=prompt_embeds,
- **kwrags,
- )[0]
- if self._interrupt:
- return None
- noise_pred_uncond = self.model(
- context=negative_prompt_embeds,
- **kwrags,
- )[0]
- if self._interrupt:
- return None
- noise_pred_cond= noise_pred_cond.to(torch.float32)
- noise_pred_uncond= noise_pred_uncond.to(torch.float32)
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
- del noise_pred_cond, noise_pred_uncond
- for idx in range(valid_interval_start, valid_interval_end):
- if update_mask_i[idx].item():
- latents[0][:, idx] = sample_schedulers[idx].step(
- noise_pred[:, idx - valid_interval_start],
- timestep_i[idx],
- latents[0][:, idx],
- return_dict=False,
- generator=generator,
- )[0]
- sample_schedulers_counter[idx] += 1
- if callback is not None:
- callback(i, latents[0], False)
-
- x0 = latents[0].unsqueeze(0)
- videos = self.vae.decode(x0, tile_size= VAE_tile_size)
- videos = (videos / 2 + 0.5).clamp(0, 1)
- videos = [video for video in videos]
- videos = [video.permute(1, 2, 3, 0) * 255 for video in videos]
- videos = [video.cpu().numpy().astype(np.uint8) for video in videos]
- return videos
- else:
- # long video generation
- base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
- overlap_history_frames = (overlap_history - 1) // 4 + 1
- n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
- print(f"n_iter:{n_iter}")
- output_video = None
- for i in range(n_iter):
- if output_video is not None: # i !=0
- prefix_video = output_video[:, -overlap_history:].to(self.device)
- prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)]
- if prefix_video[0].shape[1] % causal_block_size != 0:
- truncate_len = prefix_video[0].shape[1] % causal_block_size
- print("the length of prefix video is truncated for the casual block size alignment.")
- prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len]
- predix_video_latent_length = prefix_video[0].shape[1]
- finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
- left_frame_num = latent_length - finished_frame_num
- base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
- else: # i == 0
- base_num_frames_iter = base_num_frames
- latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
- latents = self.prepare_latents(
- latent_shape, dtype=torch.float32, device=self.device, generator=generator
- )
- latents = [latents]
- if prefix_video is not None:
- latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
- step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
- base_num_frames_iter,
- init_timesteps,
- base_num_frames_iter,
- ar_step,
- predix_video_latent_length,
- causal_block_size,
- )
- sample_schedulers = []
- for _ in range(base_num_frames_iter):
- sample_scheduler = FlowUniPCMultistepScheduler(
- num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
- )
- sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift)
- sample_schedulers.append(sample_scheduler)
- sample_schedulers_counter = [0] * base_num_frames_iter
- if callback != None:
- callback(-1, None, True)
-
- freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False)
- for i, timestep_i in enumerate(tqdm(step_matrix)):
- update_mask_i = step_update_mask[i]
- valid_interval_i = valid_interval[i]
- valid_interval_start, valid_interval_end = valid_interval_i
- timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
- latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
- if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
- noise_factor = 0.001 * addnoise_condition
- timestep_for_noised_condition = addnoise_condition
- latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
- latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
- * (1.0 - noise_factor)
- + torch.randn_like(
- latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
- )
- * noise_factor
- )
- timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
- kwrags = {
- "x" : torch.stack([latent_model_input[0]]),
- "t" : timestep,
- "freqs" :freqs,
- "fps" : fps_embeds,
- "causal_block_size" : causal_block_size,
- "causal_attention" : causal_attention,
- "callback" : callback,
- "pipeline" : self
- }
- kwrags.update(i2v_extra_kwrags)
-
- if not self.do_classifier_free_guidance:
- noise_pred = self.model(
- context=prompt_embeds,
- **kwrags,
- )[0]
- if self._interrupt:
- return None
- noise_pred= noise_pred.to(torch.float32)
- else:
- if joint_pass:
- noise_pred_cond, noise_pred_uncond = self.model(
- context=prompt_embeds,
- context2=negative_prompt_embeds,
- **kwrags,
- )
- if self._interrupt:
- return None
- else:
- noise_pred_cond = self.model(
- context=prompt_embeds,
- **kwrags,
- )[0]
- if self._interrupt:
- return None
- noise_pred_uncond = self.model(
- context=negative_prompt_embeds,
- )[0]
- if self._interrupt:
- return None
- noise_pred_cond= noise_pred_cond.to(torch.float32)
- noise_pred_uncond= noise_pred_uncond.to(torch.float32)
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
- del noise_pred_cond, noise_pred_uncond
- for idx in range(valid_interval_start, valid_interval_end):
- if update_mask_i[idx].item():
- latents[0][:, idx] = sample_schedulers[idx].step(
- noise_pred[:, idx - valid_interval_start],
- timestep_i[idx],
- latents[0][:, idx],
- return_dict=False,
- generator=generator,
- )[0]
- sample_schedulers_counter[idx] += 1
- if callback is not None:
- callback(i, latents[0].squeeze(0), False)
-
- x0 = latents[0].unsqueeze(0)
- videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]]
- if output_video is None:
- output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w
- else:
- output_video = torch.cat(
- [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1
- ) # c, f, h, w
- return output_video
diff --git a/models/wan/text2video fuse attempt.py b/models/wan/text2video fuse attempt.py
deleted file mode 100644
index 8af94583d..000000000
--- a/models/wan/text2video fuse attempt.py
+++ /dev/null
@@ -1,698 +0,0 @@
-# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
-import gc
-import logging
-import math
-import os
-import random
-import sys
-import types
-from contextlib import contextmanager
-from functools import partial
-from mmgp import offload
-import torch
-import torch.nn as nn
-import torch.cuda.amp as amp
-import torch.distributed as dist
-from tqdm import tqdm
-from PIL import Image
-import torchvision.transforms.functional as TF
-import torch.nn.functional as F
-from .distributed.fsdp import shard_model
-from .modules.model import WanModel
-from .modules.t5 import T5EncoderModel
-from .modules.vae import WanVAE
-from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
- get_sampling_sigmas, retrieve_timesteps)
-from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
-from wan.modules.posemb_layers import get_rotary_pos_embed
-from .utils.vace_preprocessor import VaceVideoProcessor
-
-
-def optimized_scale(positive_flat, negative_flat):
-
- # Calculate dot production
- dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
-
- # Squared norm of uncondition
- squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
-
- # st_star = v_cond^T * v_uncond / ||v_uncond||^2
- st_star = dot_product / squared_norm
-
- return st_star
-
-
-class WanT2V:
-
- def __init__(
- self,
- config,
- checkpoint_dir,
- rank=0,
- model_filename = None,
- text_encoder_filename = None,
- quantizeTransformer = False,
- dtype = torch.bfloat16
- ):
- self.device = torch.device(f"cuda")
- self.config = config
- self.rank = rank
- self.dtype = dtype
- self.num_train_timesteps = config.num_train_timesteps
- self.param_dtype = config.param_dtype
-
- self.text_encoder = T5EncoderModel(
- text_len=config.text_len,
- dtype=config.t5_dtype,
- device=torch.device('cpu'),
- checkpoint_path=text_encoder_filename,
- tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
- shard_fn= None)
-
- self.vae_stride = config.vae_stride
- self.patch_size = config.patch_size
-
-
- self.vae = WanVAE(
- vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
- device=self.device)
-
- logging.info(f"Creating WanModel from {model_filename}")
- from mmgp import offload
-
- self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False)
- # offload.load_model_data(self.model, "recam.ckpt")
- # self.model.cpu()
- # offload.save_model(self.model, "recam.safetensors")
- if self.dtype == torch.float16 and not "fp16" in model_filename:
- self.model.to(self.dtype)
- # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True)
- if self.dtype == torch.float16:
- self.vae.model.to(self.dtype)
- self.model.eval().requires_grad_(False)
-
-
- self.sample_neg_prompt = config.sample_neg_prompt
-
- if "Vace" in model_filename:
- self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]),
- min_area=480*832,
- max_area=480*832,
- min_fps=config.sample_fps,
- max_fps=config.sample_fps,
- zero_start=True,
- seq_len=32760,
- keep_last=True)
-
- self.adapt_vace_model()
-
- self.scheduler = FlowUniPCMultistepScheduler()
-
- def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0):
- if ref_images is None:
- ref_images = [None] * len(frames)
- else:
- assert len(frames) == len(ref_images)
-
- if masks is None:
- latents = self.vae.encode(frames, tile_size = tile_size)
- else:
- inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
- reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
- inactive = self.vae.encode(inactive, tile_size = tile_size)
- reactive = self.vae.encode(reactive, tile_size = tile_size)
- latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
-
- cat_latents = []
- for latent, refs in zip(latents, ref_images):
- if refs is not None:
- if masks is None:
- ref_latent = self.vae.encode(refs, tile_size = tile_size)
- else:
- ref_latent = self.vae.encode(refs, tile_size = tile_size)
- ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
- assert all([x.shape[1] == 1 for x in ref_latent])
- latent = torch.cat([*ref_latent, latent], dim=1)
- cat_latents.append(latent)
- return cat_latents
-
- def vace_encode_masks(self, masks, ref_images=None):
- if ref_images is None:
- ref_images = [None] * len(masks)
- else:
- assert len(masks) == len(ref_images)
-
- result_masks = []
- for mask, refs in zip(masks, ref_images):
- c, depth, height, width = mask.shape
- new_depth = int((depth + 3) // self.vae_stride[0])
- height = 2 * (int(height) // (self.vae_stride[1] * 2))
- width = 2 * (int(width) // (self.vae_stride[2] * 2))
-
- # reshape
- mask = mask[0, :, :, :]
- mask = mask.view(
- depth, height, self.vae_stride[1], width, self.vae_stride[1]
- ) # depth, height, 8, width, 8
- mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
- mask = mask.reshape(
- self.vae_stride[1] * self.vae_stride[2], depth, height, width
- ) # 8*8, depth, height, width
-
- # interpolation
- mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0)
-
- if refs is not None:
- length = len(refs)
- mask_pad = torch.zeros_like(mask[:, :length, :, :])
- mask = torch.cat((mask_pad, mask), dim=1)
- result_masks.append(mask)
- return result_masks
-
- def vace_latent(self, z, m):
- return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
-
- def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None):
- image_sizes = []
- trim_video = len(keep_frames)
-
- for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)):
- prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1]
- num_frames = total_frames - prepend_count
- if sub_src_mask is not None and sub_src_video is not None:
- src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
- # src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255])
- # src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
- src_video[i] = src_video[i].to(device)
- src_mask[i] = src_mask[i].to(device)
- if prepend_count > 0:
- src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
- src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
- src_video_shape = src_video[i].shape
- if src_video_shape[1] != total_frames:
- src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
- src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
- src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
- image_sizes.append(src_video[i].shape[2:])
- elif sub_src_video is None:
- if prepend_count > 0:
- src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1)
- src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1)
- else:
- src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
- src_mask[i] = torch.ones_like(src_video[i], device=device)
- image_sizes.append(image_size)
- else:
- src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame)
- src_video[i] = src_video[i].to(device)
- src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
- if prepend_count > 0:
- src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1)
- src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1)
- src_video_shape = src_video[i].shape
- if src_video_shape[1] != total_frames:
- src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
- src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
- image_sizes.append(src_video[i].shape[2:])
- for k, keep in enumerate(keep_frames):
- if not keep:
- src_video[i][:, k:k+1] = 0
- src_mask[i][:, k:k+1] = 1
-
- for i, ref_images in enumerate(src_ref_images):
- if ref_images is not None:
- image_size = image_sizes[i]
- for j, ref_img in enumerate(ref_images):
- if ref_img is not None:
- ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
- if ref_img.shape[-2:] != image_size:
- canvas_height, canvas_width = image_size
- ref_height, ref_width = ref_img.shape[-2:]
- white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
- scale = min(canvas_height / ref_height, canvas_width / ref_width)
- new_height = int(ref_height * scale)
- new_width = int(ref_width * scale)
- resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1)
- top = (canvas_height - new_height) // 2
- left = (canvas_width - new_width) // 2
- white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image
- ref_img = white_canvas
- src_ref_images[i][j] = ref_img.to(device)
- return src_video, src_mask, src_ref_images
-
- def decode_latent(self, zs, ref_images=None, tile_size= 0 ):
- if ref_images is None:
- ref_images = [None] * len(zs)
- else:
- assert len(zs) == len(ref_images)
-
- trimed_zs = []
- for z, refs in zip(zs, ref_images):
- if refs is not None:
- z = z[:, len(refs):, :, :]
- trimed_zs.append(z)
-
- return self.vae.decode(trimed_zs, tile_size= tile_size)
-
- def generate_timestep_matrix(
- self,
- num_frames,
- step_template,
- base_num_frames,
- ar_step=5,
- num_pre_ready=0,
- casual_block_size=1,
- shrink_interval_with_mask=False,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
- step_matrix, step_index = [], []
- update_mask, valid_interval = [], []
- num_iterations = len(step_template) + 1
- num_frames_block = num_frames // casual_block_size
- base_num_frames_block = base_num_frames // casual_block_size
- if base_num_frames_block < num_frames_block:
- infer_step_num = len(step_template)
- gen_block = base_num_frames_block
- min_ar_step = infer_step_num / gen_block
- assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
- # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
- step_template = torch.cat(
- [
- torch.tensor([999], dtype=torch.int64, device=step_template.device),
- step_template.long(),
- torch.tensor([0], dtype=torch.int64, device=step_template.device),
- ]
- ) # to handle the counter in row works starting from 1
- pre_row = torch.zeros(num_frames_block, dtype=torch.long)
- if num_pre_ready > 0:
- pre_row[: num_pre_ready // casual_block_size] = num_iterations
-
- while torch.all(pre_row >= (num_iterations - 1)) == False:
- new_row = torch.zeros(num_frames_block, dtype=torch.long)
- for i in range(num_frames_block):
- if i == 0 or pre_row[i - 1] >= (
- num_iterations - 1
- ): # the first frame or the last frame is completely denoised
- new_row[i] = pre_row[i] + 1
- else:
- new_row[i] = new_row[i - 1] - ar_step
- new_row = new_row.clamp(0, num_iterations)
-
- update_mask.append(
- (new_row != pre_row) & (new_row != num_iterations)
- ) # False: no need to update, True: need to update
- step_index.append(new_row)
- step_matrix.append(step_template[new_row])
- pre_row = new_row
-
- # for long video we split into several sequences, base_num_frames is set to the model max length (for training)
- terminal_flag = base_num_frames_block
- if shrink_interval_with_mask:
- idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
- update_mask = update_mask[0]
- update_mask_idx = idx_sequence[update_mask]
- last_update_idx = update_mask_idx[-1].item()
- terminal_flag = last_update_idx + 1
- # for i in range(0, len(update_mask)):
- for curr_mask in update_mask:
- if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
- terminal_flag += 1
- valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
-
- step_update_mask = torch.stack(update_mask, dim=0)
- step_index = torch.stack(step_index, dim=0)
- step_matrix = torch.stack(step_matrix, dim=0)
-
- if casual_block_size > 1:
- step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
- step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
- step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
- valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
-
- return step_matrix, step_index, step_update_mask, valid_interval
-
- def generate(self,
- input_prompt,
- input_frames= None,
- input_masks = None,
- input_ref_images = None,
- source_video=None,
- target_camera=None,
- context_scale=1.0,
- size=(1280, 720),
- frame_num=81,
- shift=5.0,
- sample_solver='unipc',
- sampling_steps=50,
- guide_scale=5.0,
- n_prompt="",
- seed=-1,
- offload_model=True,
- callback = None,
- enable_RIFLEx = None,
- VAE_tile_size = 0,
- joint_pass = False,
- slg_layers = None,
- slg_start = 0.0,
- slg_end = 1.0,
- cfg_star_switch = True,
- cfg_zero_step = 5,
- ):
- r"""
- Generates video frames from text prompt using diffusion process.
-
- Args:
- input_prompt (`str`):
- Text prompt for content generation
- size (tupele[`int`], *optional*, defaults to (1280,720)):
- Controls video resolution, (width,height).
- frame_num (`int`, *optional*, defaults to 81):
- How many frames to sample from a video. The number should be 4n+1
- shift (`float`, *optional*, defaults to 5.0):
- Noise schedule shift parameter. Affects temporal dynamics
- sample_solver (`str`, *optional*, defaults to 'unipc'):
- Solver used to sample the video.
- sampling_steps (`int`, *optional*, defaults to 40):
- Number of diffusion sampling steps. Higher values improve quality but slow generation
- guide_scale (`float`, *optional*, defaults 5.0):
- Classifier-free guidance scale. Controls prompt adherence vs. creativity
- n_prompt (`str`, *optional*, defaults to ""):
- Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
- seed (`int`, *optional*, defaults to -1):
- Random seed for noise generation. If -1, use random seed.
- offload_model (`bool`, *optional*, defaults to True):
- If True, offloads models to CPU during generation to save VRAM
-
- Returns:
- torch.Tensor:
- Generated video frames tensor. Dimensions: (C, N H, W) where:
- - C: Color channels (3 for RGB)
- - N: Number of frames (81)
- - H: Frame height (from size)
- - W: Frame width from size)
- """
- # preprocess
-
- if n_prompt == "":
- n_prompt = self.sample_neg_prompt
- seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
- seed_g = torch.Generator(device=self.device)
- seed_g.manual_seed(seed)
-
- frame_num = max(17, frame_num) # must match causal_block_size for value of 5
- frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 )
- num_frames = frame_num
- addnoise_condition = 20
- causal_attention = True
- fps = 16
- ar_step = 5
-
-
-
- context = self.text_encoder([input_prompt], self.device)
- context_null = self.text_encoder([n_prompt], self.device)
- if target_camera != None:
- size = (source_video.shape[2], source_video.shape[1])
- source_video = source_video.to(dtype=self.dtype , device=self.device)
- source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.)
- source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device)
- del source_video
- # Process target camera (recammaster)
- from wan.utils.cammmaster_tools import get_camera_embedding
- cam_emb = get_camera_embedding(target_camera)
- cam_emb = cam_emb.to(dtype=self.dtype, device=self.device)
-
- if input_frames != None:
- # vace context encode
- input_frames = [u.to(self.device) for u in input_frames]
- input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images]
- input_masks = [u.to(self.device) for u in input_masks]
-
- z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size)
- m0 = self.vace_encode_masks(input_masks, input_ref_images)
- z = self.vace_latent(z0, m0)
-
- target_shape = list(z0[0].shape)
- target_shape[0] = int(target_shape[0] / 2)
- else:
- F = frame_num
- target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
- size[1] // self.vae_stride[1],
- size[0] // self.vae_stride[2])
-
- seq_len = math.ceil((target_shape[2] * target_shape[3]) /
- (self.patch_size[1] * self.patch_size[2]) *
- target_shape[1])
-
- context = [u.to(self.dtype) for u in context]
- context_null = [u.to(self.dtype) for u in context_null]
-
- noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ]
-
- # evaluation mode
-
- # if sample_solver == 'unipc':
- # sample_scheduler = FlowUniPCMultistepScheduler(
- # num_train_timesteps=self.num_train_timesteps,
- # shift=1,
- # use_dynamic_shifting=False)
- # sample_scheduler.set_timesteps(
- # sampling_steps, device=self.device, shift=shift)
- # timesteps = sample_scheduler.timesteps
- # elif sample_solver == 'dpm++':
- # sample_scheduler = FlowDPMSolverMultistepScheduler(
- # num_train_timesteps=self.num_train_timesteps,
- # shift=1,
- # use_dynamic_shifting=False)
- # sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
- # timesteps, _ = retrieve_timesteps(
- # sample_scheduler,
- # device=self.device,
- # sigmas=sampling_sigmas)
- # else:
- # raise NotImplementedError("Unsupported solver.")
-
- # sample videos
- latents = noise
- del noise
- batch_size =len(latents)
- if target_camera != None:
- shape = list(latents[0].shape[1:])
- shape[0] *= 2
- freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False)
- else:
- freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx)
- # arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback}
- # arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
- # arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback}
-
- i2v_extra_kwrags = {}
-
- if target_camera != None:
- recam_dict = {'cam_emb': cam_emb}
- i2v_extra_kwrags.update(recam_dict)
-
- if input_frames != None:
- vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale}
- i2v_extra_kwrags.update(vace_dict)
-
-
- latent_length = (num_frames - 1) // 4 + 1
- latent_height = height // 8
- latent_width = width // 8
- if ar_step == 0:
- causal_block_size = 1
- fps_embeds = [fps] #* prompt_embeds[0].shape[0]
- fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
-
- self.scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
- init_timesteps = self.scheduler.timesteps
- base_num_frames_iter = latent_length
- latent_shape = [16, base_num_frames_iter, latent_height, latent_width]
-
- prefix_video = None
- predix_video_latent_length = 0
-
- if prefix_video is not None:
- latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32)
- step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
- base_num_frames_iter,
- init_timesteps,
- base_num_frames_iter,
- ar_step,
- predix_video_latent_length,
- causal_block_size,
- )
- sample_schedulers = []
- for _ in range(base_num_frames_iter):
- sample_scheduler = FlowUniPCMultistepScheduler(
- num_train_timesteps=1000, shift=1, use_dynamic_shifting=False
- )
- sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift)
- sample_schedulers.append(sample_scheduler)
- sample_schedulers_counter = [0] * base_num_frames_iter
-
- updated_num_steps= len(step_matrix)
-
- if callback != None:
- callback(-1, None, True, override_num_inference_steps = updated_num_steps)
- if self.model.enable_teacache:
- self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier)
- # if callback != None:
- # callback(-1, None, True)
-
- for i, timestep_i in enumerate(tqdm(step_matrix)):
- update_mask_i = step_update_mask[i]
- valid_interval_i = valid_interval[i]
- valid_interval_start, valid_interval_end = valid_interval_i
- timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone()
- latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()]
- if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length:
- noise_factor = 0.001 * addnoise_condition
- timestep_for_noised_condition = addnoise_condition
- latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = (
- latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
- * (1.0 - noise_factor)
- + torch.randn_like(
- latent_model_input[0][:, valid_interval_start:predix_video_latent_length]
- )
- * noise_factor
- )
- timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition
- kwrags = {
- "x" : torch.stack([latent_model_input[0]]),
- "t" : timestep,
- "freqs" :freqs,
- "fps" : fps_embeds,
- "causal_block_size" : causal_block_size,
- "causal_attention" : causal_attention,
- "callback" : callback,
- "pipeline" : self,
- "current_step" : i,
- }
- kwrags.update(i2v_extra_kwrags)
-
- if not self.do_classifier_free_guidance:
- noise_pred = self.model(
- context=context,
- **kwrags,
- )[0]
- if self._interrupt:
- return None
- noise_pred= noise_pred.to(torch.float32)
- else:
- if joint_pass:
- noise_pred_cond, noise_pred_uncond = self.model(
- context=context,
- context2=context_null,
- **kwrags,
- )
- if self._interrupt:
- return None
- else:
- noise_pred_cond = self.model(
- context=context,
- **kwrags,
- )[0]
- if self._interrupt:
- return None
- noise_pred_uncond = self.model(
- context=context_null,
- )[0]
- if self._interrupt:
- return None
- noise_pred_cond= noise_pred_cond.to(torch.float32)
- noise_pred_uncond= noise_pred_uncond.to(torch.float32)
- noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)
- del noise_pred_cond, noise_pred_uncond
- for idx in range(valid_interval_start, valid_interval_end):
- if update_mask_i[idx].item():
- latents[0][:, idx] = sample_schedulers[idx].step(
- noise_pred[:, idx - valid_interval_start],
- timestep_i[idx],
- latents[0][:, idx],
- return_dict=False,
- generator=seed_g,
- )[0]
- sample_schedulers_counter[idx] += 1
- if callback is not None:
- callback(i, latents[0].squeeze(0), False)
-
- # for i, t in enumerate(tqdm(timesteps)):
- # if target_camera != None:
- # latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )]
- # else:
- # latent_model_input = latents
- # slg_layers_local = None
- # if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps):
- # slg_layers_local = slg_layers
- # timestep = [t]
- # offload.set_step_no_for_lora(self.model, i)
- # timestep = torch.stack(timestep)
-
- # if joint_pass:
- # noise_pred_cond, noise_pred_uncond = self.model(
- # latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both)
- # if self._interrupt:
- # return None
- # else:
- # noise_pred_cond = self.model(
- # latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0]
- # if self._interrupt:
- # return None
- # noise_pred_uncond = self.model(
- # latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0]
- # if self._interrupt:
- # return None
-
- # # del latent_model_input
-
- # # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/
- # noise_pred_text = noise_pred_cond
- # if cfg_star_switch:
- # positive_flat = noise_pred_text.view(batch_size, -1)
- # negative_flat = noise_pred_uncond.view(batch_size, -1)
-
- # alpha = optimized_scale(positive_flat,negative_flat)
- # alpha = alpha.view(batch_size, 1, 1, 1)
-
- # if (i <= cfg_zero_step):
- # noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred...
- # else:
- # noise_pred_uncond *= alpha
- # noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond)
- # del noise_pred_uncond
-
- # temp_x0 = sample_scheduler.step(
- # noise_pred[:, :target_shape[1]].unsqueeze(0),
- # t,
- # latents[0].unsqueeze(0),
- # return_dict=False,
- # generator=seed_g)[0]
- # latents = [temp_x0.squeeze(0)]
- # del temp_x0
-
- # if callback is not None:
- # callback(i, latents[0], False)
-
- x0 = latents
-
- if input_frames == None:
- videos = self.vae.decode(x0, VAE_tile_size)
- else:
- videos = self.decode_latent(x0, input_ref_images, VAE_tile_size)
-
- del latents
- del sample_scheduler
-
- return videos[0] if self.rank == 0 else None
-
- def adapt_vace_model(self):
- model = self.model
- modules_dict= { k: m for k, m in model.named_modules()}
- for model_layer, vace_layer in model.vace_layers_mapping.items():
- module = modules_dict[f"vace_blocks.{vace_layer}"]
- target = modules_dict[f"blocks.{model_layer}"]
- setattr(target, "vace", module )
- delattr(model, "vace_blocks")
-
-
\ No newline at end of file
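The denoising loop removed above combines conditional and unconditional predictions with standard classifier-free guidance; a minimal sketch of that combination follows (tensor names and shapes are illustrative, not the repository's API):

```python
import torch

def cfg_combine(noise_pred_cond: torch.Tensor,
                noise_pred_uncond: torch.Tensor,
                guide_scale: float) -> torch.Tensor:
    # Classifier-free guidance: start from the unconditional prediction and
    # move towards the conditional one, scaled by guide_scale.
    return noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond)

# Illustrative usage with random latents; guide_scale == 1.0 reduces to the
# conditional prediction, larger values push further away from the unconditional one.
cond = torch.randn(1, 16, 4, 8, 8)
uncond = torch.randn_like(cond)
noise_pred = cfg_combine(cond, uncond, guide_scale=5.0)
```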
diff --git a/wgp.py b/wgp.py
index 5267ca35b..d1eae3b58 100644
--- a/wgp.py
+++ b/wgp.py
@@ -5075,7 +5075,7 @@ def remove_temp_filenames(temp_filenames_list):
any_guide_padding = model_def.get("pad_guide_video", False)
from shared.utils.utils import prepare_video_guide_and_mask
src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]),
- [video_mask_processed] + ([] if video_mask_processed2 is None else [video_mask_processed2]),
+ [video_mask_processed] + ([] if video_guide_processed2 is None else [video_mask_processed2]),
None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide,
image_size, current_video_length, latent_size,
any_mask, any_guide_padding, guide_inpaint_color,
@@ -5097,9 +5097,12 @@ def remove_temp_filenames(temp_filenames_list):
src_faces = src_faces[:, :src_video.shape[1]]
if video_guide is not None or len(frames_to_inject_parsed) > 0:
if args.save_masks:
- if src_video is not None: save_video( src_video, "masked_frames.mp4", fps)
- if src_video2 is not None: save_video( src_video2, "masked_frames2.mp4", fps)
- if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
+ if src_video is not None:
+ save_video( src_video, "masked_frames.mp4", fps)
+ if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1))
+ if src_video2 is not None:
+ save_video( src_video2, "masked_frames2.mp4", fps)
+ if any_mask: save_video( src_mask2, "masks2.mp4", fps, value_range=(0, 1))
if video_guide is not None:
preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame)
refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no)
From ba764e8d60552fee29190898028e46ee67a61fd5 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 25 Sep 2025 23:21:51 +0200
Subject: [PATCH 041/155] more spoil
---
README.md | 10 ++++-
defaults/animate.json | 6 ++-
models/flux/sampling.py | 4 +-
models/qwen/pipeline_qwenimage.py | 4 +-
models/wan/any2video.py | 51 ++++++++++++++++-------
models/wan/modules/model.py | 15 +++----
models/wan/wan_handler.py | 32 +++++++++++----
shared/utils/loras_mutipliers.py | 5 ++-
wgp.py | 67 ++++++++++++++++++-------------
9 files changed, 129 insertions(+), 65 deletions(-)
diff --git a/README.md b/README.md
index 9b8f70d45..d9be0fdf7 100644
--- a/README.md
+++ b/README.md
@@ -20,21 +20,27 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
-### September 25 2025: WanGP v8.73 - Here Are ~~Two~~Three New Contenders in the Vace Arena !
+### September 25 2025: WanGP v8.73 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release
-So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages:
+So in ~~today's~~ this release you will find two Wannabe Vace models, each of which covers only a subset of Vace features but offers some interesting advantages:
- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model to either *Replace* a person in a Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusivity, you will find support for *Outpainting*.
In order to use Wan 2.2 Animate you will first need to stop by the *Mat Anyone* embedded tool to extract the *Video Mask* of the person whose motion you want to transfer.
+With version WanGP 8.74, there is an extra option that allows you to apply *Relighting* when Replacing a person. Also, you can now Animate a person without providing a Video Mask to target the source of the motion (at the risk of it being less precise).
+
- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in clothes changing) and voila! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.
Also because I wanted to spoil you:
- **Qwen Edit Plus**: also known as the *Qwen Edit 25th September Update*, which is specialized in combining multiple Objects / People. There is also new support for *Pose transfer* & *Recolorisation*. All of this is made easy to use in WanGP. For now you will find only the quantized version, since HF crashes when uploading the unquantized version.
+- **T2V Video 2 Video Masking**: ever wanted to apply a Lora, a process (for instance Upsampling) or a Text Prompt to only a (moving) part of a Source Video? Look no further, I have added *Masked Video 2 Video* (which also works in image2image) to the *Text 2 Video* models. As usual you just need to use *Matanyone* to create the mask (see the latent-masking sketch further below).
+
+
*Update 8.71*: fixed Fast Lucy Edit that didn't contain the lora
*Update 8.72*: shadow drop of Qwen Edit Plus
*Update 8.73*: Qwen Preview & InfiniteTalk Start image
+*Update 8.74*: Animate Relighting / Nomask mode , t2v Masked Video to Video
### September 15 2025: WanGP v8.6 - Attack of the Clones
diff --git a/defaults/animate.json b/defaults/animate.json
index fee26d76a..bf45f4a92 100644
--- a/defaults/animate.json
+++ b/defaults/animate.json
@@ -2,12 +2,16 @@
"model": {
"name": "Wan2.2 Animate",
"architecture": "animate",
- "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'animation' or 'replacement' mode.",
+ "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'Animation' or 'Replacement' mode. Sliding Window of 81 frames at least are recommeded to obtain the best Style continuity.",
"URLs": [
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors"
],
+ "preload_URLs" :
+ [
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_relighting_lora.safetensors"
+ ],
"group": "wan2_2"
}
}
\ No newline at end of file
diff --git a/models/flux/sampling.py b/models/flux/sampling.py
index 92ee59099..939543c64 100644
--- a/models/flux/sampling.py
+++ b/models/flux/sampling.py
@@ -292,7 +292,7 @@ def denoise(
callback(-1, None, True)
original_image_latents = None if img_cond_seq is None else img_cond_seq.clone()
-
+ original_timesteps = timesteps
morph, first_step = False, 0
if img_msk_latents is not None:
randn = torch.randn_like(original_image_latents)
@@ -309,7 +309,7 @@ def denoise(
updated_num_steps= len(timesteps) -1
if callback != None:
from shared.utils.loras_mutipliers import update_loras_slists
- update_loras_slists(model, loras_slists, updated_num_steps)
+ update_loras_slists(model, loras_slists, len(original_timesteps))
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
from mmgp import offload
# this is ignored for schnell
diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py
index 134cc51b5..02fd473b4 100644
--- a/models/qwen/pipeline_qwenimage.py
+++ b/models/qwen/pipeline_qwenimage.py
@@ -825,7 +825,7 @@ def __call__(
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
-
+ original_timesteps = timesteps
# handle guidance
if self.transformer.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
@@ -860,7 +860,7 @@ def __call__(
updated_num_steps= len(timesteps)
if callback != None:
from shared.utils.loras_mutipliers import update_loras_slists
- update_loras_slists(self.transformer, loras_slists, updated_num_steps)
+ update_loras_slists(self.transformer, loras_slists, len(original_timesteps))
callback(-1, None, True, override_num_inference_steps = updated_num_steps)
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index 6b4ae629d..285340d39 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -34,6 +34,7 @@
from shared.utils.basic_flowmatch import FlowMatchScheduler
from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor, fit_image_into_canvas
from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask
+from shared.utils.audio_video import save_video
from mmgp import safetensors2
from shared.utils.audio_video import save_video
@@ -434,7 +435,7 @@ def generate(self,
start_step_no = 0
ref_images_count = 0
trim_frames = 0
- extended_overlapped_latents = clip_image_start = clip_image_end = None
+ extended_overlapped_latents = clip_image_start = clip_image_end = image_mask_latents = None
no_noise_latents_injection = infinitetalk
timestep_injection = False
lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1
@@ -585,7 +586,7 @@ def generate(self,
kwargs['cam_emb'] = cam_emb
# Video 2 Video
- if denoising_strength < 1. and input_frames != None:
+ if "G" in video_prompt_type and input_frames != None:
height, width = input_frames.shape[-2:]
source_latents = self.vae.encode([input_frames])[0].unsqueeze(0)
injection_denoising_step = 0
@@ -616,6 +617,14 @@ def generate(self,
if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:]
injection_denoising_step = 0
+ if input_masks is not None and not "U" in video_prompt_type:
+ image_mask_latents = torch.nn.functional.interpolate(input_masks, size= source_latents.shape[-2:], mode="nearest").unsqueeze(0)
+ if image_mask_latents.shape[2] !=1:
+ image_mask_latents = torch.cat([ image_mask_latents[:,:, :1], torch.nn.functional.interpolate(image_mask_latents, size= (source_latents.shape[-3]-1, *source_latents.shape[-2:]), mode="nearest") ], dim=2)
+ image_mask_latents = torch.where(image_mask_latents>=0.5, 1., 0. )[:1].to(self.device)
+ # save_video(image_mask_latents.squeeze(0), "mama.mp4", value_range=(0,1) )
+ # image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0)
+
# Phantom
if phantom:
lat_input_ref_images_neg = None
@@ -737,7 +746,7 @@ def generate(self,
denoising_extra = ""
from shared.utils.loras_mutipliers import update_loras_slists, get_model_switch_steps
- phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(timesteps, updated_num_steps, guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold )
+ phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(original_timesteps,guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold )
if len(phases_description) > 0: set_header_text(phases_description)
guidance_switch_done = guidance_switch2_done = False
if guide_phases > 1: denoising_extra = f"Phase 1/{guide_phases} High Noise" if self.model2 is not None else f"Phase 1/{guide_phases}"
@@ -748,8 +757,8 @@ def update_guidance(step_no, t, guide_scale, new_guide_scale, guidance_switch_do
denoising_extra = f"Phase {phase_no}/{guide_phases} {'Low Noise' if trans == self.model2 else 'High Noise'}" if self.model2 is not None else f"Phase {phase_no}/{guide_phases}"
callback(step_no-1, denoising_extra = denoising_extra)
return guide_scale, guidance_switch_done, trans, denoising_extra
- update_loras_slists(self.model, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
- if self.model2 is not None: update_loras_slists(self.model2, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
+ update_loras_slists(self.model, loras_slists, len(original_timesteps), phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
+ if self.model2 is not None: update_loras_slists(self.model2, loras_slists, len(original_timesteps), phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2)
callback(-1, None, True, override_num_inference_steps = updated_num_steps, denoising_extra = denoising_extra)
def clear():
@@ -762,6 +771,7 @@ def clear():
scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g}
# b, c, lat_f, lat_h, lat_w
latents = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
+ if "G" in video_prompt_type: randn = latents
if apg_switch != 0:
apg_momentum = -0.75
apg_norm_threshold = 55
@@ -784,23 +794,21 @@ def clear():
timestep = torch.full((target_shape[-3],), t, dtype=torch.int64, device=latents.device)
timestep[:source_latents.shape[2]] = 0
- kwargs.update({"t": timestep, "current_step": start_step_no + i})
+ kwargs.update({"t": timestep, "current_step_no": i, "real_step_no": start_step_no + i })
kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None
if denoising_strength < 1 and i <= injection_denoising_step:
sigma = t / 1000
- noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g)
if inject_from_start:
- new_latents = latents.clone()
- new_latents[:,:, :source_latents.shape[2] ] = noise[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents
+ noisy_image = latents.clone()
+ noisy_image[:,:, :source_latents.shape[2] ] = randn[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents
for latent_no, keep_latent in enumerate(latent_keep_frames):
if not keep_latent:
- new_latents[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1]
- latents = new_latents
- new_latents = None
+ noisy_image[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1]
+ latents = noisy_image
+ noisy_image = None
else:
- latents = noise * sigma + (1 - sigma) * source_latents
- noise = None
+ latents = randn * sigma + (1 - sigma) * source_latents
if extended_overlapped_latents != None:
if no_noise_latents_injection:
@@ -940,6 +948,13 @@ def clear():
latents,
**scheduler_kwargs)[0]
+
+ if image_mask_latents is not None:
+ sigma = 0 if i == len(timesteps)-1 else timesteps[i+1]/1000
+ noisy_image = randn * sigma + (1 - sigma) * source_latents
+ latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents
+
+
if callback is not None:
latents_preview = latents
if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ]
@@ -999,3 +1014,11 @@ def adapt_animate_model(self, model):
setattr(target, "face_adapter_fuser_blocks", module )
delattr(model, "face_adapter")
+ def get_loras_transformer(self, get_model_recursive_prop, base_model_type, model_type, video_prompt_type, model_mode, **kwargs):
+ if base_model_type == "animate":
+ if "1" in video_prompt_type:
+ preloadURLs = get_model_recursive_prop(model_type, "preload_URLs")
+ if len(preloadURLs) > 0:
+ return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1]
+ return [], []
+
diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py
index c98087b5d..cd024704c 100644
--- a/models/wan/modules/model.py
+++ b/models/wan/modules/model.py
@@ -1220,7 +1220,8 @@ def forward(
y=None,
freqs = None,
pipeline = None,
- current_step = 0,
+ current_step_no = 0,
+ real_step_no = 0,
x_id= 0,
max_steps = 0,
slg_layers=None,
@@ -1310,9 +1311,9 @@ def forward(
del causal_mask
offload.shared_state["embed_sizes"] = grid_sizes
- offload.shared_state["step_no"] = current_step
+ offload.shared_state["step_no"] = real_step_no
offload.shared_state["max_steps"] = max_steps
- if current_step == 0 and x_id == 0: clear_caches()
+ if current_step_no == 0 and x_id == 0: clear_caches()
# arguments
kwargs = dict(
@@ -1336,7 +1337,7 @@ def forward(
if standin_ref is not None:
standin_cache_enabled = False
kwargs["standin_phase"] = 2
- if current_step == 0 or not standin_cache_enabled :
+ if current_step_no == 0 or not standin_cache_enabled :
standin_x = self.patch_embedding(standin_ref).to(modulation_dtype).flatten(2).transpose(1, 2)
standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)).to(modulation_dtype) )
standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype)
@@ -1401,7 +1402,7 @@ def forward(
skips_steps_cache = self.cache
if skips_steps_cache != None:
if skips_steps_cache.cache_type == "mag":
- if current_step <= skips_steps_cache.start_step:
+ if real_step_no <= skips_steps_cache.start_step:
should_calc = True
elif skips_steps_cache.one_for_all and x_id != 0: # not joint pass, not main pas, one for all
assert len(x_list) == 1
@@ -1410,7 +1411,7 @@ def forward(
x_should_calc = []
for i in range(1 if skips_steps_cache.one_for_all else len(x_list)):
cur_x_id = i if joint_pass else x_id
- cur_mag_ratio = skips_steps_cache.mag_ratios[current_step * 2 + cur_x_id] # conditional and unconditional in one list
+ cur_mag_ratio = skips_steps_cache.mag_ratios[real_step_no * 2 + cur_x_id] # conditional and unconditional in one list
skips_steps_cache.accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step
skips_steps_cache.accumulated_steps[cur_x_id] += 1 # skip steps plus 1
cur_skip_err = np.abs(1-skips_steps_cache.accumulated_ratio[cur_x_id]) # skip error of current steps
@@ -1430,7 +1431,7 @@ def forward(
if x_id != 0:
should_calc = skips_steps_cache.should_calc
else:
- if current_step <= skips_steps_cache.start_step or current_step == skips_steps_cache.num_steps-1:
+ if real_step_no <= skips_steps_cache.start_step or real_step_no == skips_steps_cache.num_steps-1:
should_calc = True
skips_steps_cache.accumulated_rel_l1_distance = 0
else:
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index 574c990a0..c5875445b 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -124,12 +124,20 @@ def query_model_def(base_model_type, model_def):
if base_model_type in ["t2v"]:
extra_model_def["guide_custom_choices"] = {
- "choices":[("Use Text Prompt Only", ""),("Video to Video guided by Text Prompt", "GUV")],
+ "choices":[("Use Text Prompt Only", ""),
+ ("Video to Video guided by Text Prompt", "GUV"),
+ ("Video to Video guided by Text Prompt and Restricted to the Area of the Video Mask", "GVA")],
"default": "",
- "letters_filter": "GUV",
+ "letters_filter": "GUVA",
"label": "Video to Video"
}
+ extra_model_def["mask_preprocessing"] = {
+ "selection":[ "", "A"],
+ "visible": False
+ }
+
+
if base_model_type in ["infinitetalk"]:
extra_model_def["no_background_removal"] = True
extra_model_def["all_image_refs_are_background_ref"] = True
@@ -160,14 +168,22 @@ def query_model_def(base_model_type, model_def):
if base_model_type in ["animate"]:
extra_model_def["guide_custom_choices"] = {
"choices":[
- ("Animate Person in Reference Image using Motion of Person in Control Video", "PVBXAKI"),
- ("Replace Person in Control Video Person in Reference Image", "PVBAI"),
+ ("Animate Person in Reference Image using Motion of Whole Control Video", "PVBKI"),
+ ("Animate Person in Reference Image using Motion of Targeted Person in Control Video", "PVBXAKI"),
+ ("Replace Person in Control Video by Person in Reference Image", "PVBAI"),
+ ("Replace Person in Control Video by Person in Reference Image and Apply Relighting Process", "PVBAI1"),
],
- "default": "KI",
- "letters_filter": "PVBXAKI",
+ "default": "PVBKI",
+ "letters_filter": "PVBXAKI1",
"label": "Type of Process",
"show_label" : False,
}
+
+ extra_model_def["mask_preprocessing"] = {
+ "selection":[ "", "A", "XA"],
+ "visible": False
+ }
+
extra_model_def["video_guide_outpainting"] = [0,1]
extra_model_def["keep_frames_video_guide_not_supported"] = True
extra_model_def["extract_guide_from_window_start"] = True
@@ -480,8 +496,8 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
"video_prompt_type": "UV",
})
elif base_model_type in ["animate"]:
- ui_defaults.update({
- "video_prompt_type": "PVBXAKI",
+ ui_defaults.update({
+ "video_prompt_type": "PVBKI",
"mask_expand": 20,
"audio_prompt_type": "R",
})
diff --git a/shared/utils/loras_mutipliers.py b/shared/utils/loras_mutipliers.py
index 58cc9a993..e0cec6a45 100644
--- a/shared/utils/loras_mutipliers.py
+++ b/shared/utils/loras_mutipliers.py
@@ -99,7 +99,7 @@ def is_float(element: any) -> bool:
return loras_list_mult_choices_nums, slists_dict, ""
-def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None ):
+def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None):
from mmgp import offload
sz = len(slists_dict["phase1"])
slists = [ expand_slist(slists_dict, i, num_inference_steps, phase_switch_step, phase_switch_step2 ) for i in range(sz) ]
@@ -108,7 +108,8 @@ def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_st
-def get_model_switch_steps(timesteps, total_num_steps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ):
+def get_model_switch_steps(timesteps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ):
+ total_num_steps = len(timesteps)
model_switch_step = model_switch_step2 = None
for i, t in enumerate(timesteps):
if guide_phases >=2 and model_switch_step is None and t <= switch_threshold: model_switch_step = i
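The reworked `get_model_switch_steps` above derives the phase-switch step indices directly from the timestep schedule. A minimal sketch of that lookup under the same convention (the first step whose timestep falls at or below the threshold), using an illustrative schedule rather than the repository's scheduler:

```python
def find_switch_step(timesteps, threshold):
    # Return the index of the first denoising step whose timestep is <= threshold,
    # i.e. the step at which the pipeline switches to the next phase / model.
    for i, t in enumerate(timesteps):
        if t <= threshold:
            return i
    return None

# Illustrative schedule: 10 steps from t=1000 down to t=100, switching at t <= 500.
timesteps = list(range(1000, 0, -100))
assert find_switch_step(timesteps, 500) == 5
```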
diff --git a/wgp.py b/wgp.py
index d1eae3b58..def856ff6 100644
--- a/wgp.py
+++ b/wgp.py
@@ -63,7 +63,7 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.73"
+WanGP_version = "8.74"
settings_version = 2.36
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -3699,13 +3699,15 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr
any_mask = input_mask_path != None
video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps)
if len(video) == 0: return None
+ frame_height, frame_width, _ = video[0].shape
+ num_frames = len(video)
if any_mask:
mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps)
- frame_height, frame_width, _ = video[0].shape
-
- num_frames = min(len(video), len(mask_video))
+ num_frames = min(num_frames, len(mask_video))
if num_frames == 0: return None
- video, mask_video = video[:num_frames], mask_video[:num_frames]
+ video = video[:num_frames]
+ if any_mask:
+ mask_video = mask_video[:num_frames]
from preprocessing.face_preprocessor import FaceProcessor
face_processor = FaceProcessor()
@@ -6970,39 +6972,49 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t
all_guide_processes ="PDESLCMUVB"
-def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
+def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
+ model_type = state["model_type"]
+ model_def = get_model_def(model_type)
old_video_prompt_type = video_prompt_type
- video_prompt_type = del_in_sequence(video_prompt_type, all_guide_processes)
+ if filter_type == "alt":
+ guide_custom_choices = model_def.get("guide_custom_choices",{})
+ letter_filter = guide_custom_choices.get("letters_filter","")
+ else:
+ letter_filter = all_guide_processes
+ video_prompt_type = del_in_sequence(video_prompt_type, letter_filter)
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
visible = "V" in video_prompt_type
- model_type = state["model_type"]
- model_def = get_model_def(model_type)
any_outpainting= image_mode in model_def.get("video_guide_outpainting", [])
mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type
image_outputs = image_mode > 0
keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False)
image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value )
+ # mask_video_input_visible = image_mode == 0 and mask_visible
mask_preprocessing = model_def.get("mask_preprocessing", None)
if mask_preprocessing is not None:
mask_selector_visible = mask_preprocessing.get("visible", True)
else:
mask_selector_visible = True
- return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible)
-
-
-def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode):
- model_def = get_model_def(state["model_type"])
- guide_custom_choices = model_def.get("guide_custom_choices",{})
- video_prompt_type = del_in_sequence(video_prompt_type, guide_custom_choices.get("letters_filter",""))
- video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt)
- control_video_visible = "V" in video_prompt_type
ref_images_visible = "I" in video_prompt_type
- denoising_strength_visible = "G" in video_prompt_type
- return video_prompt_type, gr.update(visible = control_video_visible and image_mode ==0), gr.update(visible = control_video_visible and image_mode >=1), gr.update(visible = ref_images_visible ), gr.update(visible = denoising_strength_visible )
-
-# def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide):
-# video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0]
-# return refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide)
+ return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible), gr.update(visible = ref_images_visible )
+
+
+# def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
+# old_video_prompt_type = video_prompt_type
+# model_def = get_model_def(state["model_type"])
+# guide_custom_choices = model_def.get("guide_custom_choices",{})
+# video_prompt_type = del_in_sequence(video_prompt_type, guide_custom_choices.get("letters_filter",""))
+# video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt)
+# image_outputs = image_mode > 0
+# control_video_visible = "V" in video_prompt_type
+# ref_images_visible = "I" in video_prompt_type
+# denoising_strength_visible = "G" in video_prompt_type
+# mask_expand_visible = control_video_visible and "A" in video_prompt_type and not "U" in video_prompt_type
+# mask_video_input_visible = image_mode == 0 and mask_expand_visible
+# image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value )
+# keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False)
+
+# return video_prompt_type, gr.update(visible = control_video_visible and image_mode ==0), gr.update(visible = control_video_visible and image_mode >=1), gr.update(visible = ref_images_visible ), gr.update(visible = denoising_strength_visible ), gr.update(visible = mask_video_input_visible ), gr.update(visible = mask_expand_visible), image_mask_guide, image_guide, image_mask, gr.update(visible = keep_frames_video_guide_visible)
def refresh_preview(state):
gen = get_gen_info(state)
@@ -8294,9 +8306,10 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
image_prompt_type_radio.change(fn=refresh_image_prompt_type_radio, inputs=[state, image_prompt_type, image_prompt_type_radio], outputs=[image_prompt_type, image_start_row, image_end_row, video_source, keep_frames_video_source, image_prompt_type_endcheckbox], show_progress="hidden" )
image_prompt_type_endcheckbox.change(fn=refresh_image_prompt_type_endcheckbox, inputs=[state, image_prompt_type, image_prompt_type_radio, image_prompt_type_endcheckbox], outputs=[image_prompt_type, image_end_row] )
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs,image_mode], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col], show_progress="hidden")
- video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand], show_progress="hidden")
- video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode], outputs = [video_prompt_type, video_guide, image_guide, image_refs_row, denoising_strength ], show_progress="hidden")
- video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_mask, image_mask_guide, image_guide, image_mask, mask_expand], show_progress="hidden")
+ video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State(""), video_prompt_type, video_prompt_type_video_guide, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row], show_progress="hidden")
+ video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State("alt"),video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row], show_progress="hidden")
+ # video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, image_refs_row, denoising_strength, video_mask, mask_expand, image_mask_guide, image_guide, image_mask, keep_frames_video_guide ], show_progress="hidden")
+ video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_mask, image_mask_guide, image_guide, image_mask, mask_expand], show_progress="hidden")
video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type])
multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt, image_end], show_progress="hidden")
video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" )
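A short note on the *Masked Video 2 Video* change in `models/wan/any2video.py` above: at every denoising step the freshly denoised latents are kept only where the mask is 1, while the region outside the mask is rebuilt by re-noising the source-video latents to the next step's noise level. A minimal self-contained sketch of that blend (variable names are illustrative; the real code operates on VAE latents and the scheduler's timestep/sigma schedule):

```python
import torch

def masked_blend(denoised, source_latents, randn, mask, sigma):
    # Outside the mask: re-noise the source latents to the current noise level
    # so they stay on the sampler's trajectory.
    noisy_source = randn * sigma + (1.0 - sigma) * source_latents
    # Inside the mask (mask == 1): keep the denoised result; outside: keep the source.
    return noisy_source * (1.0 - mask) + mask * denoised

# Illustrative usage: the mask is broadcast over channels and holds values in {0., 1.}.
denoised = torch.randn(1, 16, 4, 32, 32)
source = torch.randn_like(denoised)
randn = torch.randn_like(denoised)
mask = (torch.rand(1, 1, 4, 32, 32) > 0.5).float()
out = masked_blend(denoised, source, randn, mask, sigma=0.4)
```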
From bb7e84719dc148b7faef04c33836957f275e87a8 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Fri, 26 Sep 2025 00:17:11 +0200
Subject: [PATCH 042/155] better packaging of matanyone
---
preprocessing/matanyone/app.py | 3 ++-
wgp.py | 8 +++++++-
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py
index df71b40a9..a5d557099 100644
--- a/preprocessing/matanyone/app.py
+++ b/preprocessing/matanyone/app.py
@@ -697,7 +697,8 @@ def load_unload_models(selected):
model.samcontroler.sam_controler.model.to("cpu").to(torch.bfloat16).to(arg_device)
model_in_GPU = True
from .matanyone.model.matanyone import MatAnyone
- matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
+ # matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone")
+ matanyone_model = MatAnyone.from_pretrained("ckpts/mask")
# pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model }
# offload.profile(pipe)
matanyone_model = matanyone_model.to("cpu").eval()
diff --git a/wgp.py b/wgp.py
index def856ff6..e00660c89 100644
--- a/wgp.py
+++ b/wgp.py
@@ -2519,6 +2519,7 @@ def download_mmaudio():
}
process_files_def(**enhancer_def)
+download_shared_done = False
def download_models(model_filename = None, model_type= None, module_type = False, submodel_no = 1):
def computeList(filename):
if filename == None:
@@ -2536,7 +2537,7 @@ def computeList(filename):
"repoId" : "DeepBeepMeep/Wan2.1",
"sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "roformer", "pyannote", "det_align", "" ],
"fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"],
- ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"],
+ ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors", "model.safetensors", "config.json"],
["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"],
["config.json", "pytorch_model.bin", "preprocessor_config.json"],
["model_bs_roformer_ep_317_sdr_12.9755.ckpt", "model_bs_roformer_ep_317_sdr_12.9755.yaml", "download_checks.json"],
@@ -2562,6 +2563,9 @@ def computeList(filename):
process_files_def(**enhancer_def)
download_mmaudio()
+ global download_shared_done
+ download_shared_done = True
+
if model_filename is None: return
def download_file(url,filename):
@@ -9138,6 +9142,8 @@ def set_new_tab(tab_state, new_tab_no):
tab_state["tab_no"] = 0
return gr.Tabs(selected="video_gen")
else:
+ if not download_shared_done:
+ download_models()
vmc_event_handler(True)
tab_state["tab_no"] = new_tab_no
return gr.Tabs()
From 6dfd17315246f2408e3232260bcf6c5a66cec8b4 Mon Sep 17 00:00:00 2001
From: deepbeepmeep <84379123+deepbeepmeep@users.noreply.github.com>
Date: Sat, 27 Sep 2025 12:57:04 +0200
Subject: [PATCH 043/155] Update README.md
added docker info
---
README.md | 26 ++++++++++++++++++++++++++
1 file changed, 26 insertions(+)
diff --git a/README.md b/README.md
index d9be0fdf7..3ee2d88dc 100644
--- a/README.md
+++ b/README.md
@@ -148,6 +148,32 @@ git pull
pip install -r requirements.txt
```
+## 🐳 Docker:
+
+**For Debian-based systems (Ubuntu, Debian, etc.):**
+
+```bash
+./run-docker-cuda-deb.sh
+```
+
+This automated script will:
+
+- Detect your GPU model and VRAM automatically
+- Select optimal CUDA architecture for your GPU
+- Install NVIDIA Docker runtime if needed
+- Build a Docker image with all dependencies
+- Run WanGP with optimal settings for your hardware
+
+**Docker environment includes:**
+
+- NVIDIA CUDA 12.4.1 with cuDNN support
+- PyTorch 2.6.0 with CUDA 12.4 support
+- SageAttention compiled for your specific GPU architecture
+- Optimized environment variables for performance (TF32, threading, etc.)
+- Automatic cache directory mounting for faster subsequent runs
+- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally
+
+**Supported GPUs:** RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more.
## 📦 Installation
From 79df3aae64c2c3b28286d921a6e7b13ee02a375e Mon Sep 17 00:00:00 2001
From: deepbeepmeep <84379123+deepbeepmeep@users.noreply.github.com>
Date: Sat, 27 Sep 2025 13:00:08 +0200
Subject: [PATCH 044/155] Update README.md
---
README.md | 27 ---------------------------
1 file changed, 27 deletions(-)
diff --git a/README.md b/README.md
index 35fc0dc59..8261540ce 100644
--- a/README.md
+++ b/README.md
@@ -122,33 +122,6 @@ See full changelog: **[Changelog](docs/CHANGELOG.md)**
## 🚀 Quick Start
-### 🐳 Docker:
-
-**For Debian-based systems (Ubuntu, Debian, etc.):**
-
-```bash
-./run-docker-cuda-deb.sh
-```
-
-This automated script will:
-
-- Detect your GPU model and VRAM automatically
-- Select optimal CUDA architecture for your GPU
-- Install NVIDIA Docker runtime if needed
-- Build a Docker image with all dependencies
-- Run WanGP with optimal settings for your hardware
-
-**Docker environment includes:**
-
-- NVIDIA CUDA 12.4.1 with cuDNN support
-- PyTorch 2.6.0 with CUDA 12.4 support
-- SageAttention compiled for your specific GPU architecture
-- Optimized environment variables for performance (TF32, threading, etc.)
-- Automatic cache directory mounting for faster subsequent runs
-- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally
-
-**Supported GPUs:** RTX 50XX, RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more.
-
**One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/)
**Manual installation:**
From 5ce8fc3d535c9285d23c7ca06b3016527ba59743 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sat, 27 Sep 2025 15:14:55 +0200
Subject: [PATCH 045/155] various fixes
---
README.md | 7 ++++---
models/qwen/pipeline_qwenimage.py | 4 ++--
models/wan/any2video.py | 11 ++++++++++-
models/wan/modules/model.py | 16 +++++++++++-----
models/wan/wan_handler.py | 9 +++++++++
shared/attention.py | 13 +++++++++++++
shared/sage2_core.py | 6 ++++++
wgp.py | 16 +++++++++++-----
8 files changed, 66 insertions(+), 16 deletions(-)
diff --git a/README.md b/README.md
index 8261540ce..939cb7f04 100644
--- a/README.md
+++ b/README.md
@@ -122,7 +122,9 @@ See full changelog: **[Changelog](docs/CHANGELOG.md)**
## 🚀 Quick Start
-**One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/)
+**One-click installation:**
+- Get started instantly with [Pinokio App](https://pinokio.computer/)
+- Use Redtash1's [One Click Install with Sage](https://github.com/Redtash1/Wan2GP-Windows-One-Click-Install-With-Sage)
**Manual installation:**
```bash
@@ -136,8 +138,7 @@ pip install -r requirements.txt
**Run the application:**
```bash
-python wgp.py # Text-to-video (default)
-python wgp.py --i2v # Image-to-video
+python wgp.py
```
**Update the application:**
diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py
index 02fd473b4..c458852fe 100644
--- a/models/qwen/pipeline_qwenimage.py
+++ b/models/qwen/pipeline_qwenimage.py
@@ -232,7 +232,7 @@ def _get_qwen_prompt_embeds(
drop_idx = self.prompt_template_encode_start_idx
txt = [template.format(e) for e in prompt]
- if self.processor is not None and image is not None:
+ if self.processor is not None and image is not None and len(image) > 0:
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
if isinstance(image, list):
base_img_prompt = ""
@@ -972,7 +972,7 @@ def cfg_predictions( noise_pred, neg_noise_pred, guidance, t):
if callback is not None:
preview = self._unpack_latents(latents, height, width, self.vae_scale_factor)
- preview = preview.squeeze(0)
+ preview = preview.transpose(0,2).squeeze(0)
callback(i, preview, False)
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index 285340d39..7dc4b19c8 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -595,7 +595,7 @@ def generate(self,
color_reference_frame = input_frames[:, -1:].clone()
if prefix_frames_count > 0:
overlapped_frames_num = prefix_frames_count
- overlapped_latents_frames_num = (overlapped_latents_frames_num -1 // 4) + 1
+ overlapped_latents_frames_num = (overlapped_frames_num -1 // 4) + 1
# overlapped_latents_frames_num = overlapped_latents.shape[2]
# overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1
else:
@@ -735,11 +735,20 @@ def generate(self,
if callback != None:
callback(-1, None, True)
+
+ clear_caches()
offload.shared_state["_chipmunk"] = False
chipmunk = offload.shared_state.get("_chipmunk", False)
if chipmunk:
self.model.setup_chipmunk()
+ offload.shared_state["_radial"] = offload.shared_state["_attention"]=="radial"
+ radial = offload.shared_state.get("_radial", False)
+ if radial:
+ radial_cache = get_cache("radial")
+ from shared.radial_attention.attention import fill_radial_cache
+ fill_radial_cache(radial_cache, len(self.model.blocks), *target_shape[1:])
+
# init denoising
updated_num_steps= len(timesteps)
diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py
index cd024704c..82a313cec 100644
--- a/models/wan/modules/model.py
+++ b/models/wan/modules/model.py
@@ -315,6 +315,8 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks
x_ref_attn_map = None
chipmunk = offload.shared_state.get("_chipmunk", False)
+ radial = offload.shared_state.get("_radial", False)
+
if chipmunk and self.__class__ == WanSelfAttention:
q = q.transpose(1,2)
k = k.transpose(1,2)
@@ -322,12 +324,17 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks
attn_layers = offload.shared_state["_chipmunk_layers"]
x = attn_layers[self.block_no](q, k, v)
x = x.transpose(1,2)
+ elif radial and self.__class__ == WanSelfAttention:
+ qkv_list = [q,k,v]
+ del q,k,v
+ radial_cache = get_cache("radial")
+ no_step_no = offload.shared_state["step_no"]
+ x = radial_cache[self.block_no](qkv_list=qkv_list, timestep_no=no_step_no)
elif block_mask == None:
qkv_list = [q,k,v]
del q,k,v
- x = pay_attention(
- qkv_list,
- window_size=self.window_size)
+
+ x = pay_attention( qkv_list, window_size=self.window_size)
else:
with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
@@ -1311,9 +1318,8 @@ def forward(
del causal_mask
offload.shared_state["embed_sizes"] = grid_sizes
- offload.shared_state["step_no"] = real_step_no
+ offload.shared_state["step_no"] = current_step_no
offload.shared_state["max_steps"] = max_steps
- if current_step_no == 0 and x_id == 0: clear_caches()
# arguments
kwargs = dict(
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index c5875445b..1a30fffbc 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -184,6 +184,15 @@ def query_model_def(base_model_type, model_def):
"visible": False
}
+ # extra_model_def["image_ref_choices"] = {
+ # "choices": [("None", ""),
+ # ("People / Objects", "I"),
+ # ("Landscape followed by People / Objects (if any)", "KI"),
+ # ],
+ # "visible": False,
+ # "letters_filter": "KFI",
+ # }
+
extra_model_def["video_guide_outpainting"] = [0,1]
extra_model_def["keep_frames_video_guide_not_supported"] = True
extra_model_def["extract_guide_from_window_start"] = True
diff --git a/shared/attention.py b/shared/attention.py
index cc6ece02d..efaa22370 100644
--- a/shared/attention.py
+++ b/shared/attention.py
@@ -42,6 +42,11 @@ def sageattn_varlen_wrapper(
except ImportError:
sageattn_varlen_wrapper = None
+try:
+ from spas_sage_attn import block_sparse_sage2_attn_cuda
+except ImportError:
+ block_sparse_sage2_attn_cuda = None
+
try:
from .sage2_core import sageattn as sageattn2, is_sage2_supported
@@ -62,6 +67,8 @@ def sageattn2_wrapper(
return o
+from sageattn import sageattn_blackwell as sageattn3
+
try:
from sageattn import sageattn_blackwell as sageattn3
except ImportError:
@@ -144,6 +151,9 @@ def get_attention_modes():
ret.append("sage")
if sageattn2 != None and version("sageattention").startswith("2") :
ret.append("sage2")
+ if block_sparse_sage2_attn_cuda != None and version("sageattention").startswith("2") :
+ ret.append("radial")
+
if sageattn3 != None: # and version("sageattention").startswith("3") :
ret.append("sage3")
@@ -159,6 +169,8 @@ def get_supported_attention_modes():
if not sage2_supported:
if "sage2" in ret:
ret.remove("sage2")
+ if "radial" in ret:
+ ret.remove("radial")
if major < 7:
if "sage" in ret:
@@ -225,6 +237,7 @@ def pay_attention(
if attn == "chipmunk":
from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn
from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG
+ if attn == "radial": attn ="sage2"
if b > 1 and k_lens != None and attn in ("sage2", "sage3", "sdpa"):
assert attention_mask == None
diff --git a/shared/sage2_core.py b/shared/sage2_core.py
index 04b77f8bb..33cd98cbf 100644
--- a/shared/sage2_core.py
+++ b/shared/sage2_core.py
@@ -28,18 +28,24 @@
try:
from sageattention import _qattn_sm80
+ if not hasattr(_qattn_sm80, "qk_int8_sv_f16_accum_f32_attn"):
+ _qattn_sm80 = torch.ops.sageattention_qattn_sm80
SM80_ENABLED = True
except:
SM80_ENABLED = False
try:
from sageattention import _qattn_sm89
+ if not hasattr(_qattn_sm89, "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"):
+ _qattn_sm89 = torch.ops.sageattention_qattn_sm89
SM89_ENABLED = True
except:
SM89_ENABLED = False
try:
from sageattention import _qattn_sm90
+ if not hasattr(_qattn_sm90, "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"):
+ _qattn_sm90 = torch.ops.sageattention_qattn_sm90
SM90_ENABLED = True
except:
SM90_ENABLED = False
diff --git a/wgp.py b/wgp.py
index e00660c89..6a257e814 100644
--- a/wgp.py
+++ b/wgp.py
@@ -63,7 +63,7 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.74"
+WanGP_version = "8.75"
settings_version = 2.36
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -1394,6 +1394,12 @@ def _parse_args():
help="View form generation / refresh time"
)
+ parser.add_argument(
+ "--betatest",
+ action="store_true",
+ help="test unreleased features"
+ )
+
parser.add_argument(
"--vram-safety-coefficient",
type=float,
@@ -4367,7 +4373,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri
if image_start is None or not "I" in prompt_enhancer:
image_start = [None] * num_prompts
else:
- image_start = [img[0] for img in image_start]
+ image_start = [convert_image(img[0]) for img in image_start]
if len(image_start) == 1:
image_start = image_start * num_prompts
else:
@@ -8712,9 +8718,9 @@ def check(mode):
("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
- ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
- ("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3"),
- ],
+ ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2")]\
+ + ([("Radial" + check("radial")+ ": x? faster but ? quality - requires Sparge & Sage 2 Attn (usually complex to set up on Windows without WSL)", "radial")] if args.betatest else [])\
+ + [("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3")],
value= attention_mode,
label="Attention Type",
interactive= not lock_ui_attention
From d62748a226e4eab35cdf500231e1234893a5ea09 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sat, 27 Sep 2025 15:22:31 +0200
Subject: [PATCH 046/155] missed files
---
shared/radial_attention/attention.py | 49 ++
shared/radial_attention/attn_mask.py | 379 ++++++++++++++++
shared/radial_attention/inference.py | 41 ++
shared/radial_attention/sparse_transformer.py | 424 ++++++++++++++++++
shared/radial_attention/utils.py | 16 +
5 files changed, 909 insertions(+)
create mode 100644 shared/radial_attention/attention.py
create mode 100644 shared/radial_attention/attn_mask.py
create mode 100644 shared/radial_attention/inference.py
create mode 100644 shared/radial_attention/sparse_transformer.py
create mode 100644 shared/radial_attention/utils.py
diff --git a/shared/radial_attention/attention.py b/shared/radial_attention/attention.py
new file mode 100644
index 000000000..ab4a5a293
--- /dev/null
+++ b/shared/radial_attention/attention.py
@@ -0,0 +1,49 @@
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from .attn_mask import RadialAttention, MaskMap
+
+def fill_radial_cache(radial_cache, nb_layers, lat_t, lat_h, lat_w):
+ MaskMap._log_mask = None
+
+ for i in range(nb_layers):
+ radial_cache[i] = WanSparseAttnProcessor2_0(i, lat_t, lat_h, lat_w)
+
+class WanSparseAttnProcessor2_0:
+ mask_map = None
+ dense_timestep = 0
+ dense_block = 0
+ decay_factor = 1.0
+ sparse_type = "radial" # default to radial attention, can be changed to "dense" for dense attention
+ use_sage_attention = True
+
+ def __init__(self, layer_idx, lat_t, lat_h, lat_w):
+ self.layer_idx = layer_idx
+ self.mask_map = MaskMap(video_token_num=lat_t * lat_h * lat_w // 4 , num_frame=lat_t)
+ def __call__(
+ self,
+ qkv_list,
+ timestep_no = 0,
+ ) -> torch.Tensor:
+ query, key, value = qkv_list
+
+ batch_size = query.shape[0]
+ # transform (batch_size, seq_len, num_heads, head_dim) to (seq_len * batch_size, num_heads, head_dim)
+ query = rearrange(query, "b s h d -> (b s) h d")
+ key = rearrange(key, "b s h d -> (b s) h d")
+ value = rearrange(value, "b s h d -> (b s) h d")
+ if timestep_no < self.dense_timestep or self.layer_idx < self.dense_block or self.sparse_type == "dense":
+ hidden_states = RadialAttention(
+ query=query, key=key, value=value, mask_map=self.mask_map, sparsity_type="dense", block_size=128, decay_factor=self.decay_factor, model_type="wan", pre_defined_mask=None, use_sage_attention=self.use_sage_attention
+ )
+ else:
+ # apply radial attention
+ hidden_states = RadialAttention(
+ query=query, key=key, value=value, mask_map=self.mask_map, sparsity_type="radial", block_size=128, decay_factor=self.decay_factor, model_type="wan", pre_defined_mask=None, use_sage_attention=self.use_sage_attention
+ )
+ # transform back to (batch_size, num_heads, seq_len, head_dim)
+ hidden_states = rearrange(hidden_states, "(b s) h d -> b s h d", b=batch_size)
+
+ return hidden_states
diff --git a/shared/radial_attention/attn_mask.py b/shared/radial_attention/attn_mask.py
new file mode 100644
index 000000000..4c29f1c32
--- /dev/null
+++ b/shared/radial_attention/attn_mask.py
@@ -0,0 +1,379 @@
+import torch
+# import flashinfer
+import matplotlib.pyplot as plt
+# from sparse_sageattn import sparse_sageattn
+from einops import rearrange, repeat
+from sageattention import sageattn
+from spas_sage_attn import block_sparse_sage2_attn_cuda
+
+def get_cuda_arch_versions():
+ cuda_archs = []
+ for i in range(torch.cuda.device_count()):
+ major, minor = torch.cuda.get_device_capability(i)
+ cuda_archs.append(f"sm{major}{minor}")
+ return cuda_archs
+
+from spas_sage_attn import block_sparse_sage2_attn_cuda
+
+def sparge_mask_convert(mask: torch.Tensor, block_size: int = 128, arch="sm") -> torch.Tensor:
+ assert block_size in [128, 64], "Radial Attention only supports block size of 128 or 64"
+ assert mask.shape[0] == mask.shape[1], "Input mask must be square."
+
+ if block_size == 128:
+ if arch == "sm90":
+ new_mask = torch.repeat_interleave(mask, 2, dim=0)
+ else:
+ new_mask = torch.repeat_interleave(mask, 2, dim=1)
+
+ elif block_size == 64:
+ if arch == "sm90":
+ num_row, num_col = mask.shape
+ reshaped_mask = mask.view(num_row, num_col // 2, 2)
+ new_mask = torch.max(reshaped_mask, dim=2).values
+ else:
+ num_row, num_col = mask.shape
+ reshaped_mask = mask.view(num_row // 2, 2, num_col)
+ new_mask = torch.max(reshaped_mask, dim=1).values
+
+ return new_mask
+
+def get_indptr_from_mask(mask, query):
+ # query shows the device of the indptr
+ # indptr (torch.Tensor) - the block index pointer of the block-sparse matrix on row dimension,
+ # shape `(MB + 1,)`, where `MB` is the number of blocks in the row dimension.
+ # The first element is always 0, and the last element is the number of blocks in the row dimension.
+ # The rest of the elements are the number of blocks in each row.
+ # the mask is already a block sparse mask
+ indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32)
+ indptr[0] = 0
+ row_counts = mask.sum(dim=1).flatten() # Ensure 1D output [num_blocks_row]
+ indptr[1:] = torch.cumsum(row_counts, dim=0)
+ return indptr
+
+def get_indices_from_mask(mask, query):
+ # indices (torch.Tensor) - the block indices of the block-sparse matrix on column dimension,
+ # shape `(nnz,),` where `nnz` is the number of non-zero blocks.
+ # The elements in `indices` array should be less than `NB`: the number of blocks in the column dimension.
+ nonzero_indices = torch.nonzero(mask)
+ indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device)
+ return indices
+
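+# Collapse a token-level boolean mask into a block-level mask: a block is kept when
+# more than 60% of its non-empty columns have a fill density above 1/3.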
+def shrinkMaskStrict(mask, block_size=128):
+ seqlen = mask.shape[0]
+ block_num = seqlen // block_size
+ mask = mask[:block_num * block_size, :block_num * block_size].view(block_num, block_size, block_num, block_size)
+ col_densities = mask.sum(dim = 1) / block_size
+ # we want the minimum non-zero column density in the block
+ non_zero_densities = col_densities > 0
+ high_density_cols = col_densities > 1/3
+ frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
+ block_mask = frac_high_density_cols > 0.6
+ block_mask[0:0] = True
+ block_mask[-1:-1] = True
+ return block_mask
+
+def pad_qkv(input_tensor, block_size=128):
+ """
+ Pad the input tensor to be a multiple of the block size.
+ input shape: (seqlen, num_heads, hidden_dim)
+ """
+ seqlen, num_heads, hidden_dim = input_tensor.shape
+ # Calculate the necessary padding
+ padding_length = (block_size - (seqlen % block_size)) % block_size
+ # Create a padded tensor with zeros
+ padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype)
+ # Copy the original tensor into the padded tensor
+ padded_tensor[:seqlen, :, :] = input_tensor
+
+ return padded_tensor
+
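+# For distant frame pairs keep only every `split_factor`-th diagonal band; pairs whose
+# decay length still exceeds the block size are kept in full.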
+def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query):
+ assert(sparse_type in ["radial"])
+ dist = abs(i - j)
+ group = dist.bit_length()
+ threshold = 128 # hardcoded threshold for now, which is equal to block-size
+ decay_length = 2 ** token_per_frame.bit_length() / 2 ** group
+ if decay_length >= threshold:
+ return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
+
+ split_factor = int(threshold / decay_length)
+ modular = dist % split_factor
+ if modular == 0:
+ return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
+ else:
+ return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
+
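+# Attention window width between frames i and j: the nearest frames attend to the full
+# frame (model dependent), farther frames get a window that roughly halves as the
+# temporal distance doubles, floored at the block size.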
+def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
+ assert(sparse_type in ["radial"])
+ dist = abs(i - j)
+ if model_type == "wan":
+ if dist < 1:
+ return token_per_frame
+ if dist == 1:
+ return token_per_frame // 2
+ elif model_type == "hunyuan":
+ if dist <= 1:
+ return token_per_frame
+ else:
+ raise ValueError(f"Unknown model type: {model_type}")
+ group = dist.bit_length()
+ decay_length = 2 ** token_per_frame.bit_length() / 2 ** group * decay_factor
+ threshold = block_size
+ if decay_length >= threshold:
+ return decay_length
+ else:
+ return threshold
+
+def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
+ """
+    A more memory-friendly version: the attention mask for each frame pair is generated
+    one pair at a time, shrunk to block level, and written into the final result.
+ """
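+    # `s` is the padded sequence length; the mask has one entry per (block_row, block_col)
+    # pair. Rows/columns past the video tokens (i.e. text tokens) are always kept dense.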
+ final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool)
+ token_per_frame = video_token_num // num_frame
+ video_text_border = video_token_num // block_size
+
+ col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1)
+ row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1)
+ final_log_mask[video_text_border:] = True
+ final_log_mask[:, video_text_border:] = True
+ for i in range(num_frame):
+ for j in range(num_frame):
+ local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
+ if j == 0 and model_type == "wan": # this is attention sink
+ local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool)
+ else:
+ window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
+ local_mask = torch.abs(col_indices - row_indices) <= window_width
+ split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query)
+ local_mask = torch.logical_and(local_mask, split_mask)
+
+ remainder_row = (i * token_per_frame) % block_size
+ remainder_col = (j * token_per_frame) % block_size
+ # get the padded size
+ all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
+ all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
+ padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool)
+ padded_local_mask[remainder_row:remainder_row + token_per_frame, remainder_col:remainder_col + token_per_frame] = local_mask
+ # shrink the mask
+ block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
+ # set the block mask to the final log mask
+ block_row_start = (i * token_per_frame) // block_size
+ block_col_start = (j * token_per_frame) // block_size
+ block_row_end = block_row_start + block_mask.shape[0]
+ block_col_end = block_col_start + block_mask.shape[1]
+ final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(
+ final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask)
+ print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}")
+ return final_log_mask
+
+class MaskMap:
+ _log_mask = None
+
+ def __init__(self, video_token_num=25440, num_frame=16):
+ self.video_token_num = video_token_num
+ self.num_frame = num_frame
+
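+    # The block mask is cached at class level (MaskMap._log_mask) and shared by all
+    # layers; fill_radial_cache() clears it when the latent shape changes.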
+ def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None):
+ if MaskMap._log_mask is None:
+ MaskMap._log_mask = torch.ones((query.shape[0] // block_size, query.shape[0] // block_size), device=query.device, dtype=torch.bool)
+ MaskMap._log_mask = gen_log_mask_shrinked(query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size)
+ return MaskMap._log_mask
+
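+# SageAttention-based backend: the dense path runs plain sageattn over the video tokens
+# (plus a FlashInfer prefill for the text queries when a predefined mask is given);
+# the sparse path converts the block mask to the sparge layout and calls
+# block_sparse_sage2_attn_cuda.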
+def SpargeSageAttnBackend(query, key, value, mask_map=None, video_mask=None, pre_defined_mask=None, block_size=128):
+ if video_mask.all():
+ # dense case
+ kv_border = pre_defined_mask[0].sum() if pre_defined_mask is not None else key.shape[0]
+ output_video = sageattn(
+ query[:mask_map.video_token_num, :, :].unsqueeze(0),
+ key[:kv_border, :, :].unsqueeze(0),
+ value[:kv_border, :, :].unsqueeze(0),
+ tensor_layout="NHD",
+ )[0]
+
+ if pre_defined_mask is not None:
+ output_text = flashinfer.single_prefill_with_kv_cache(
+ q=query[mask_map.video_token_num:, :, :],
+ k=key[:pre_defined_mask[0].sum(), :, :],
+ v=value[:pre_defined_mask[0].sum(), :, :],
+ causal=False,
+ return_lse=False,
+ )
+ return torch.cat([output_video, output_text], dim=0)
+ else:
+ return output_video
+
+ # sparse-sageattention only supports (b, h, s, d) layout, need rearrange first
+ query_hnd = rearrange(query.unsqueeze(0), "b s h d -> b h s d")
+ key_hnd = rearrange(key.unsqueeze(0), "b s h d -> b h s d")
+ value_hnd = rearrange(value.unsqueeze(0), "b s h d -> b h s d")
+ arch = get_cuda_arch_versions()[query.device.index]
+ converted_mask = repeat(sparge_mask_convert(mask=video_mask, block_size=block_size, arch=arch), "s t -> b h s t", b=query_hnd.shape[0], h=query_hnd.shape[1])
+
+ converted_mask = converted_mask.to(torch.int8)
+ if pre_defined_mask is None:
+ # wan case
+ output = block_sparse_sage2_attn_cuda(
+ query_hnd[:, :, :mask_map.video_token_num, :],
+ key_hnd[:, :, :mask_map.video_token_num, :],
+ value_hnd[:, :, :mask_map.video_token_num, :],
+ mask_id=converted_mask,
+ tensor_layout="HND",
+ )
+
+ # rearrange back to (s, h, d), we know that b = 1
+ output = rearrange(output, "b h s d -> s (b h) d", b=1)
+ return output
+
+ query_video = query_hnd[:, :, :mask_map.video_token_num, :]
+ key_video = key_hnd
+ value_video = value_hnd
+ kv_border = (pre_defined_mask[0].sum() + 63) // 64
+ converted_mask[:, :, :, kv_border:] = False
+ output_video = block_sparse_sage2_attn_cuda(
+ query_video,
+ key_video,
+ value_video,
+ mask_id=converted_mask[:, :, :mask_map.video_token_num // block_size, :].contiguous(),
+ tensor_layout="HND",
+ )
+
+ # rearrange back to (s, h, d), we know that b = 1
+ output_video = rearrange(output_video, "b h s d -> s (b h) d", b=1)
+
+ output_text = flashinfer.single_prefill_with_kv_cache(
+ q=query[mask_map.video_token_num:, :, :],
+ k=key[:pre_defined_mask[0].sum(), :, :],
+ v=value[:pre_defined_mask[0].sum(), :, :],
+ causal=False,
+ return_lse=False,
+ )
+
+ return torch.cat([output_video, output_text], dim=0)
+
+
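+# FlashInfer backend: video-to-video attention goes through the pre-planned block-sparse
+# wrapper; video-to-text and text-query attention use dense prefill calls, and the two
+# video results are combined with merge_state via their log-sum-exp weights.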
+def FlashInferBackend(query, key, value, mask_map=None, pre_defined_mask=None, bsr_wrapper=None):
+ if pre_defined_mask is not None:
+ video_video_o, video_video_o_lse = bsr_wrapper.run(
+ query[:mask_map.video_token_num, :, :],
+ key[:mask_map.video_token_num, :, :],
+ value[:mask_map.video_token_num, :, :],
+ return_lse=True
+ )
+ # perform non-causal flashinfer on the text tokens
+ video_text_o, video_text_o_lse = flashinfer.single_prefill_with_kv_cache(
+ q=query[:mask_map.video_token_num, :, :],
+ k=key[mask_map.video_token_num:, :, :],
+ v=value[mask_map.video_token_num:, :, :],
+ causal=False,
+ return_lse=True,
+ custom_mask=pre_defined_mask[:mask_map.video_token_num, mask_map.video_token_num:]
+ )
+
+ # merge the two results
+ o_video, _ = flashinfer.merge_state(v_a=video_video_o, s_a=video_video_o_lse, v_b=video_text_o, s_b=video_text_o_lse)
+
+ o_text = flashinfer.single_prefill_with_kv_cache(
+ q=query[mask_map.video_token_num:, :, :],
+ k=key,
+ v=value,
+ causal=False,
+ return_lse=False,
+ custom_mask=pre_defined_mask[mask_map.video_token_num:, :]
+ )
+
+ return torch.cat([o_video, o_text], dim=0)
+ else:
+ o = bsr_wrapper.run(
+ query[:mask_map.video_token_num, :, :],
+ key[:mask_map.video_token_num, :, :],
+ value[:mask_map.video_token_num, :, :]
+ )
+ return o
+
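+# Entry point: build (or reuse) the block-level video mask, then dispatch to the sparse
+# SageAttention backend or to the FlashInfer block-sparse backend.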
+def RadialAttention(query, key, value, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_type=None, pre_defined_mask=None, use_sage_attention=False):
+ orig_seqlen, num_head, hidden_dim = query.shape
+
+ if sparsity_type == "dense":
+ video_mask = torch.ones((mask_map.video_token_num // block_size, mask_map.video_token_num // block_size), device=query.device, dtype=torch.bool)
+ else:
+ video_mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, decay_factor=decay_factor, model_type=model_type) if mask_map else None
+
+ backend = "sparse_sageattn" if use_sage_attention else "flashinfer"
+
+ if backend == "flashinfer":
+ video_mask = video_mask[:mask_map.video_token_num // block_size, :mask_map.video_token_num // block_size]
+ # perform block-sparse attention on the video tokens
+ workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8)
+ bsr_wrapper = flashinfer.BlockSparseAttentionWrapper(
+ workspace_buffer,
+ backend="fa2",
+ )
+
+ indptr = get_indptr_from_mask(video_mask, query)
+ indices = get_indices_from_mask(video_mask, query)
+
+ bsr_wrapper.plan(
+ indptr=indptr,
+ indices=indices,
+ M=mask_map.video_token_num,
+ N=mask_map.video_token_num,
+ R=block_size,
+ C=block_size,
+ num_qo_heads=num_head,
+ num_kv_heads=num_head,
+ head_dim=hidden_dim,
+ q_data_type=query.dtype,
+ kv_data_type=key.dtype,
+ o_data_type=query.dtype,
+ )
+
+ return FlashInferBackend(query, key, value, mask_map, pre_defined_mask, bsr_wrapper)
+ elif backend == "sparse_sageattn":
+ return SpargeSageAttnBackend(query, key, value, mask_map, video_mask, pre_defined_mask, block_size=block_size)
+
+if __name__ == "__main__":
+ query = torch.randn(1, 2, 4, 64).cuda()
+ # mask = torch.tensor([
+ # [True, False, True, False],
+ # [False, True, False, True],
+ # [True, False, False, True],
+ # [False, True, True, False]
+ # ], dtype=torch.bool)
+ # indices = get_indices_from_mask(mask, query)
+ # indptr = get_indptr_from_mask(mask, query)
+ # print("Indices: ", indices)
+ # print("Indptr: ", indptr)
+ video_token_num = 3840 * 30
+ num_frame = 30
+ token_per_frame = video_token_num / num_frame
+ padded_video_token_num = ((video_token_num + 1) // 128 + 1) * 128
+ print("padded: ", padded_video_token_num)
+ temporal_mask = gen_log_mask_shrinked(query, padded_video_token_num, video_token_num, num_frame, sparse_type="radial", decay_factor=1, model_type="hunyuan")
+ plt.figure(figsize=(10, 8), dpi=500)
+
+ plt.imshow(temporal_mask.cpu().numpy()[:, :], cmap='hot')
+ plt.colorbar()
+ plt.title("Temporal Mask")
+
+ plt.savefig("temporal_mask.png",
+ dpi=300,
+ bbox_inches='tight',
+ pad_inches=0.1)
+
+ plt.close()
+ # save the mask tensor
+ torch.save(temporal_mask, "temporal_mask.pt")
diff --git a/shared/radial_attention/inference.py b/shared/radial_attention/inference.py
new file mode 100644
index 000000000..04ef186ab
--- /dev/null
+++ b/shared/radial_attention/inference.py
@@ -0,0 +1,41 @@
+import torch
+from diffusers.models.attention_processor import Attention
+from diffusers.models.attention import AttentionModuleMixin
+from .attention import WanSparseAttnProcessor2_0 as WanSparseAttnProcessor
+from .attn_mask import MaskMap
+
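+# Configure the processor class attributes from the video geometry and install the
+# sparse processor on every self-attention (attn1) layer of the Wan transformer.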
+def setup_radial_attention(
+ pipe,
+ height,
+ width,
+ num_frames,
+ dense_layers=0,
+ dense_timesteps=0,
+ decay_factor=1.0,
+ sparsity_type="radial",
+ use_sage_attention=False,
+):
+
+ num_frames = 1 + num_frames // (pipe.vae_scale_factor_temporal * pipe.transformer.config.patch_size[0])
+ mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+ frame_size = int(height // mod_value) * int(width // mod_value)
+
+ AttnModule = WanSparseAttnProcessor
+ AttnModule.dense_block = dense_layers
+ AttnModule.dense_timestep = dense_timesteps
+ AttnModule.mask_map = MaskMap(video_token_num=frame_size * num_frames, num_frame=num_frames)
+ AttnModule.decay_factor = decay_factor
+ AttnModule.sparse_type = sparsity_type
+ AttnModule.use_sage_attention = use_sage_attention
+
+ print(f"Replacing Wan attention with {sparsity_type} attention")
+ print(f"video token num: {AttnModule.mask_map.video_token_num}, num frames: {num_frames}")
+ print(f"dense layers: {dense_layers}, dense timesteps: {dense_timesteps}, decay factor: {decay_factor}")
+
+ for layer_idx, m in enumerate(pipe.transformer.blocks):
+ m.attn1.processor.layer_idx = layer_idx
+
+ for _, m in pipe.transformer.named_modules():
+ if isinstance(m, AttentionModuleMixin) and hasattr(m.processor, 'layer_idx'):
+ layer_idx = m.processor.layer_idx
+ m.set_processor(AttnModule(layer_idx))
\ No newline at end of file
diff --git a/shared/radial_attention/sparse_transformer.py b/shared/radial_attention/sparse_transformer.py
new file mode 100644
index 000000000..96820d0f3
--- /dev/null
+++ b/shared/radial_attention/sparse_transformer.py
@@ -0,0 +1,424 @@
+# borrowed from svg-project/Sparse-VideoGen
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+
+from diffusers.models.transformers.transformer_wan import WanTransformerBlock, WanTransformer3DModel
+from diffusers import WanPipeline
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+logger = logging.get_logger(__name__)
+from typing import Any, Callable, Dict, List, Optional, Union
+from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
+from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
+import torch.distributed as dist
+
+try:
+ from xfuser.core.distributed import (
+ get_ulysses_parallel_world_size,
+ get_ulysses_parallel_rank,
+ get_sp_group
+ )
+except ImportError:
+    pass
+
+class WanTransformerBlock_Sparse(WanTransformerBlock):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ numeral_timestep: Optional[int] = None,
+ ) -> torch.Tensor:
+ if temb.ndim == 4:
+ # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table.unsqueeze(0) + temb.float()
+ ).chunk(6, dim=2)
+ # batch_size, seq_len, 1, inner_dim
+ shift_msa = shift_msa.squeeze(2)
+ scale_msa = scale_msa.squeeze(2)
+ gate_msa = gate_msa.squeeze(2)
+ c_shift_msa = c_shift_msa.squeeze(2)
+ c_scale_msa = c_scale_msa.squeeze(2)
+ c_gate_msa = c_gate_msa.squeeze(2)
+ else:
+ # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
+
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, numerical_timestep=numeral_timestep)
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states).contiguous()
+
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+ return hidden_states
+
+class WanTransformer3DModel_Sparse(WanTransformer3DModel):
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ numeral_timestep: Optional[int] = None,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ rotary_emb = self.rope(hidden_states)
+
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+ if timestep.ndim == 2:
+ ts_seq_len = timestep.shape[1]
+ timestep = timestep.flatten() # batch_size * seq_len
+ else:
+ ts_seq_len = None
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
+ )
+
+ if ts_seq_len is not None:
+ # batch_size, seq_len, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(2, (6, -1))
+ else:
+ # batch_size, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ if dist.is_initialized() and get_ulysses_parallel_world_size() > 1:
+ # split video latents on dim TS
+ hidden_states = torch.chunk(hidden_states, get_ulysses_parallel_world_size(), dim=-2)[get_ulysses_parallel_rank()]
+ rotary_emb = (
+ torch.chunk(rotary_emb[0], get_ulysses_parallel_world_size(), dim=1)[get_ulysses_parallel_rank()],
+ torch.chunk(rotary_emb[1], get_ulysses_parallel_world_size(), dim=1)[get_ulysses_parallel_rank()],
+ )
+
+ # 4. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.blocks:
+ hidden_states = self._gradient_checkpointing_func(
+ block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, numeral_timestep=numeral_timestep
+ )
+ else:
+ for block in self.blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ numeral_timestep=numeral_timestep,
+ )
+
+ # 5. Output norm, projection & unpatchify
+ if temb.ndim == 3:
+ # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+ shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift = shift.squeeze(2)
+ scale = scale.squeeze(2)
+ else:
+ # batch_size, inner_dim
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up
+ # on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+
+ if dist.is_initialized() and get_ulysses_parallel_world_size() > 1:
+ hidden_states = get_sp_group().all_gather(hidden_states, dim=-2)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
+class WanPipeline_Sparse(WanPipeline):
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the video generation. If not defined, you need to pass
+                `prompt_embeds` instead.
+            height (`int`, defaults to `480`):
+                The height in pixels of the generated video.
+            width (`int`, defaults to `832`):
+                The width in pixels of the generated video.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+                The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+            output_type (`str`, *optional*, defaults to `"np"`):
+                The output format of the generated video. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+                each denoising step during inference with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~WanPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned
+                where the only element is a list with the generated video frames.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ numeral_timestep=i,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ numeral_timestep=i,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return WanPipelineOutput(frames=video)
+
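+# Monkey-patch the diffusers Wan classes so the sparse forwards (which thread the
+# denoising step index `numeral_timestep` through to the self-attention call) are used.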
+def replace_sparse_forward():
+ WanTransformerBlock.forward = WanTransformerBlock_Sparse.forward
+ WanTransformer3DModel.forward = WanTransformer3DModel_Sparse.forward
+ WanPipeline.__call__ = WanPipeline_Sparse.__call__
\ No newline at end of file
diff --git a/shared/radial_attention/utils.py b/shared/radial_attention/utils.py
new file mode 100644
index 000000000..92b3c3e13
--- /dev/null
+++ b/shared/radial_attention/utils.py
@@ -0,0 +1,16 @@
+import os
+import random
+import numpy as np
+import torch
+
+def set_seed(seed):
+ """
+ Set the random seed for reproducibility.
+ """
+ random.seed(seed)
+ os.environ['PYTHONHASHSEED'] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
\ No newline at end of file
From a4d62fc1cd3e37b5f4dcb27aa189a4c155ca715e Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sat, 27 Sep 2025 16:30:13 +0200
Subject: [PATCH 047/155] removed test line that cause sage crash
---
shared/attention.py | 2 --
1 file changed, 2 deletions(-)
diff --git a/shared/attention.py b/shared/attention.py
index efaa22370..c31087c24 100644
--- a/shared/attention.py
+++ b/shared/attention.py
@@ -67,8 +67,6 @@ def sageattn2_wrapper(
return o
-from sageattn import sageattn_blackwell as sageattn3
-
try:
from sageattn import sageattn_blackwell as sageattn3
except ImportError:
From 63db8448f00a718e68a0605214bd044b08cbfbe9 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sun, 28 Sep 2025 18:13:09 +0200
Subject: [PATCH 048/155] Animate upgraded
---
models/wan/any2video.py | 35 ++-
models/wan/lynx/__init__.py | 0
models/wan/lynx/flash_attention.py | 60 ++++
models/wan/lynx/full_attention_processor.py | 305 +++++++++++++++++++
models/wan/lynx/inference_utils.py | 39 +++
models/wan/lynx/light_attention_processor.py | 194 ++++++++++++
models/wan/lynx/model_utils_wan.py | 264 ++++++++++++++++
models/wan/lynx/navit_utils.py | 16 +
models/wan/lynx/resampler.py | 185 +++++++++++
models/wan/modules/model.py | 78 ++++-
models/wan/wan_handler.py | 48 ++-
preprocessing/arc/face_encoder.py | 97 ++++++
preprocessing/arc/face_utils.py | 63 ++++
requirements.txt | 2 +
wgp.py | 108 +++++--
15 files changed, 1456 insertions(+), 38 deletions(-)
create mode 100644 models/wan/lynx/__init__.py
create mode 100644 models/wan/lynx/flash_attention.py
create mode 100644 models/wan/lynx/full_attention_processor.py
create mode 100644 models/wan/lynx/inference_utils.py
create mode 100644 models/wan/lynx/light_attention_processor.py
create mode 100644 models/wan/lynx/model_utils_wan.py
create mode 100644 models/wan/lynx/navit_utils.py
create mode 100644 models/wan/lynx/resampler.py
create mode 100644 preprocessing/arc/face_encoder.py
create mode 100644 preprocessing/arc/face_utils.py
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index 7dc4b19c8..b0fc24659 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -354,6 +354,7 @@ def generate(self,
pre_video_frame = None,
video_prompt_type= "",
original_input_ref_images = [],
+ face_arc_embeds = None,
**bbargs
):
@@ -428,6 +429,7 @@ def generate(self,
multitalk = model_type in ["multitalk", "infinitetalk", "vace_multitalk_14B", "i2v_2_2_multitalk"]
infinitetalk = model_type in ["infinitetalk"]
standin = model_type in ["standin", "vace_standin_14B"]
+ lynx = model_type in ["lynx_lite", "lynx_full", "vace_lynx_lite_14B", "vace_lynx_full_14B"]
recam = model_type in ["recam_1.3B"]
ti2v = model_type in ["ti2v_2_2", "lucy_edit"]
lucy_edit= model_type in ["lucy_edit"]
@@ -535,6 +537,7 @@ def generate(self,
pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0)
input_frames = input_frames * input_masks
if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames
+ # input_frames = input_frames[:, :1].expand(-1, input_frames.shape[1], -1, -1)
if prefix_frames_count > 0:
input_frames[:, :prefix_frames_count] = input_video
input_masks[:, :prefix_frames_count] = 1
@@ -549,7 +552,8 @@ def generate(self,
clip_image_start = image_ref.squeeze(1)
lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1)
y = torch.concat([msk, lat_y])
- kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)})
+ kwargs.update({ 'y': y, 'pose_latents': pose_latents})
+ face_pixel_values = input_faces.unsqueeze(0)
lat_y = msk = msk_control = msk_ref = pose_pixels = None
ref_images_before = True
ref_images_count = 1
@@ -699,6 +703,23 @@ def generate(self,
kwargs["freqs"] = freqs
+ # Lynx
+ if lynx:
+ from .lynx.resampler import Resampler
+ from accelerate import init_empty_weights
+ lynx_lite = model_type in ["lynx_lite", "vace_lynx_lite_14B"]
+ with init_empty_weights():
+ arc_resampler = Resampler( depth=4, dim=1280, dim_head=64, embedding_dim=512, ff_mult=4, heads=20, num_queries=16, output_dim=2048 if lynx_lite else 5120 )
+ offload.load_model_data(arc_resampler, os.path.join("ckpts", "wan2.1_lynx_lite_arc_resampler.safetensors" if lynx_lite else "wan2.1_lynx_full_arc_resampler.safetensors"))
+ arc_resampler.to(self.device)
+ arcface_embed = face_arc_embeds[None,None,:].to(device=self.device, dtype=torch.float)
+ ip_hidden_states = arc_resampler(arcface_embed).to(self.dtype)
+ ip_hidden_states_uncond = arc_resampler(torch.zeros_like(arcface_embed)).to(self.dtype)
+ arc_resampler = None
+ gc.collect()
+ torch.cuda.empty_cache()
+ kwargs["lynx_ip_scale"] = 1.
+
#Standin
if standin:
from preprocessing.face_preprocessor import FaceProcessor
@@ -848,6 +869,18 @@ def clear():
"context" : [context, context_null, context_null],
"audio_scale": [audio_scale, None, None ]
}
+ elif animate:
+ gen_args = {
+ "x" : [latent_model_input, latent_model_input],
+ "context" : [context, context_null],
+ "face_pixel_values": [face_pixel_values, None]
+ }
+ elif lynx:
+ gen_args = {
+ "x" : [latent_model_input, latent_model_input],
+ "context" : [context, context_null],
+ "lynx_ip_embeds": [ip_hidden_states, ip_hidden_states_uncond]
+ }
elif multitalk and audio_proj != None:
if guide_scale == 1:
gen_args = {
diff --git a/models/wan/lynx/__init__.py b/models/wan/lynx/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/models/wan/lynx/flash_attention.py b/models/wan/lynx/flash_attention.py
new file mode 100644
index 000000000..3bb807db8
--- /dev/null
+++ b/models/wan/lynx/flash_attention.py
@@ -0,0 +1,60 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+try:
+ from flash_attn_interface import flash_attn_varlen_func # flash attn 3
+except:
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func # flash attn 2
+
+def flash_attention(query, key, value, q_lens, kv_lens, causal=False):
+ """
+ Args:
+ query, key, value: [B, H, T, D_h]
+ q_lens: list[int] per sequence query length
+ kv_lens: list[int] per sequence key/value length
+ causal: whether to use causal mask
+
+ Returns:
+ output: [B, H, T_q, D_h]
+ """
+ B, H, T_q, D_h = query.shape
+ T_k = key.shape[2]
+ device = query.device
+
+ # Flatten: [B, H, T, D] -> [total_tokens, H, D]
+ q = query.permute(0, 2, 1, 3).reshape(B * T_q, H, D_h)
+ k = key.permute(0, 2, 1, 3).reshape(B * T_k, H, D_h)
+ v = value.permute(0, 2, 1, 3).reshape(B * T_k, H, D_h)
+
+ # Prepare cu_seqlens: prefix sum
+ q_lens_tensor = torch.tensor(q_lens, device=device, dtype=torch.int32)
+ kv_lens_tensor = torch.tensor(kv_lens, device=device, dtype=torch.int32)
+
+ cu_seqlens_q = torch.zeros(len(q_lens_tensor) + 1, device=device, dtype=torch.int32)
+ cu_seqlens_k = torch.zeros(len(kv_lens_tensor) + 1, device=device, dtype=torch.int32)
+
+ cu_seqlens_q[1:] = torch.cumsum(q_lens_tensor, dim=0)
+ cu_seqlens_k[1:] = torch.cumsum(kv_lens_tensor, dim=0)
+
+ max_seqlen_q = int(q_lens_tensor.max().item())
+ max_seqlen_k = int(kv_lens_tensor.max().item())
+
+ # Call FlashAttention varlen kernel
+ out = flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ causal=causal
+ )
+ if not torch.is_tensor(out): # flash attn 3
+ out = out[0]
+
+ # Restore shape: [total_q, H, D_h] -> [B, H, T_q, D_h]
+ out = out.view(B, T_q, H, D_h).permute(0, 2, 1, 3).contiguous()
+
+ return out
\ No newline at end of file
diff --git a/models/wan/lynx/full_attention_processor.py b/models/wan/lynx/full_attention_processor.py
new file mode 100644
index 000000000..fd0edb7a8
--- /dev/null
+++ b/models/wan/lynx/full_attention_processor.py
@@ -0,0 +1,305 @@
+# Copyright (c) 2025 The Wan Team and The HuggingFace Team.
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+#
+# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025.
+#
+# Original file was released under Apache License 2.0, with the full license text
+# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt.
+#
+# This modified file is released under the same license.
+
+
+from typing import Optional, List
+import torch
+import torch.nn as nn
+from diffusers.models.attention_processor import Attention
+from .flash_attention import flash_attention
+from .navit_utils import vector_to_list, list_to_vector, merge_token_lists
+
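+# Replace diffusers' Attention.forward so that extra kwargs (ip_hidden_states, q_lens,
+# kv_lens, ...) are forwarded to the installed processor.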
+def new_forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+) -> torch.Tensor:
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+Attention.forward = new_forward
+
+class WanIPAttnProcessor(nn.Module):
+ def __init__(
+ self,
+ cross_attention_dim: int,
+ dim: int,
+ n_registers: int,
+ bias: bool,
+ ):
+ super().__init__()
+ self.to_k_ip = nn.Linear(cross_attention_dim, dim, bias=bias)
+ self.to_v_ip = nn.Linear(cross_attention_dim, dim, bias=bias)
+ if n_registers > 0:
+ self.registers = nn.Parameter(torch.randn(1, n_registers, cross_attention_dim) / dim**0.5)
+ else:
+ self.registers = None
+
+ def forward(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ ip_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ q_lens: List[int] = None,
+ kv_lens: List[int] = None,
+ ip_lens: List[int] = None,
+ ip_scale: float = 1.0,
+ ):
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ encoder_hidden_states_img = encoder_hidden_states[:, :257]
+ encoder_hidden_states = encoder_hidden_states[:, 257:]
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if ip_hidden_states is not None:
+
+ if self.registers is not None:
+ ip_hidden_states_list = vector_to_list(ip_hidden_states, ip_lens, 1)
+ ip_hidden_states_list = merge_token_lists(ip_hidden_states_list, [self.registers] * len(ip_hidden_states_list), 1)
+ ip_hidden_states, ip_lens = list_to_vector(ip_hidden_states_list, 1)
+
+ ip_query = query
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+ if attn.norm_q is not None:
+ ip_query = attn.norm_q(ip_query)
+ if attn.norm_k is not None:
+ ip_key = attn.norm_k(ip_key)
+ ip_query = ip_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ ip_key = ip_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ ip_value = ip_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ ip_hidden_states = flash_attention(
+ ip_query, ip_key, ip_value,
+ q_lens=q_lens,
+ kv_lens=ip_lens
+ )
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).flatten(2, 3)
+ ip_hidden_states = ip_hidden_states.type_as(ip_query)
+
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+ x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+ x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+ return x_out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, rotary_emb)
+ key = apply_rotary_emb(key, rotary_emb)
+
+ hidden_states = flash_attention(
+ query, key, value,
+ q_lens=q_lens,
+ kv_lens=kv_lens,
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if ip_hidden_states is not None:
+ hidden_states = hidden_states + ip_scale * ip_hidden_states
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
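+# Attach a WanIPAttnProcessor to every cross-attention (attn2) block. "zero" init makes
+# the IP branch a no-op initially, "clone" copies the base to_k/to_v weights, "random"
+# keeps the default initialization.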
+def register_ip_adapter(
+ model,
+ cross_attention_dim=None,
+ n_registers=0,
+ init_method="zero",
+ dtype=torch.float32,
+):
+ attn_procs = {}
+ transformer_sd = model.state_dict()
+ for layer_idx, block in enumerate(model.blocks):
+ name = f"blocks.{layer_idx}.attn2.processor"
+ layer_name = name.split(".processor")[0]
+ dim = transformer_sd[layer_name + ".to_k.weight"].shape[1]
+ attn_procs[name] = WanIPAttnProcessor(
+ cross_attention_dim=dim if cross_attention_dim is None else cross_attention_dim,
+ dim=dim,
+ n_registers=n_registers,
+ bias=True,
+ )
+ if init_method == "zero":
+ torch.nn.init.zeros_(attn_procs[name].to_k_ip.weight)
+ torch.nn.init.zeros_(attn_procs[name].to_k_ip.bias)
+ torch.nn.init.zeros_(attn_procs[name].to_v_ip.weight)
+ torch.nn.init.zeros_(attn_procs[name].to_v_ip.bias)
+ elif init_method == "clone":
+ layer_name = name.split(".processor")[0]
+ weights = {
+ "to_k_ip.weight": transformer_sd[layer_name + ".to_k.weight"],
+ "to_k_ip.bias": transformer_sd[layer_name + ".to_k.bias"],
+ "to_v_ip.weight": transformer_sd[layer_name + ".to_v.weight"],
+ "to_v_ip.bias": transformer_sd[layer_name + ".to_v.bias"],
+ }
+ attn_procs[name].load_state_dict(weights)
+ elif init_method == "random":
+ pass
+ else:
+ raise ValueError(f"{init_method} is not supported.")
+ block.attn2.processor = attn_procs[name]
+ ip_layers = torch.nn.ModuleList(attn_procs.values())
+ ip_layers.to(device=model.device, dtype=dtype)
+ return model, ip_layers
+
+
+class WanRefAttnProcessor(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ bias: bool,
+ ):
+ super().__init__()
+ self.to_k_ref = nn.Linear(dim, dim, bias=bias)
+ self.to_v_ref = nn.Linear(dim, dim, bias=bias)
+
+ def forward(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ q_lens: List[int] = None,
+ kv_lens: List[int] = None,
+ ref_feature: Optional[tuple] = None,
+ ref_scale: float = 1.0,
+ ):
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ encoder_hidden_states_img = encoder_hidden_states[:, :257]
+ encoder_hidden_states = encoder_hidden_states[:, 257:]
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if ref_feature is None:
+ ref_hidden_states = None
+ else:
+ ref_hidden_states, ref_lens = ref_feature
+ ref_query = query
+ ref_key = self.to_k_ref(ref_hidden_states)
+ ref_value = self.to_v_ref(ref_hidden_states)
+ if attn.norm_q is not None:
+ ref_query = attn.norm_q(ref_query)
+ if attn.norm_k is not None:
+ ref_key = attn.norm_k(ref_key)
+ ref_query = ref_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ ref_key = ref_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ ref_value = ref_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ ref_hidden_states = flash_attention(
+ ref_query, ref_key, ref_value,
+ q_lens=q_lens,
+ kv_lens=ref_lens
+ )
+ ref_hidden_states = ref_hidden_states.transpose(1, 2).flatten(2, 3)
+ ref_hidden_states = ref_hidden_states.type_as(ref_query)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+ x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+ x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+ return x_out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, rotary_emb)
+ key = apply_rotary_emb(key, rotary_emb)
+
+ hidden_states = flash_attention(
+ query, key, value,
+ q_lens=q_lens,
+ kv_lens=kv_lens,
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if ref_hidden_states is not None:
+ hidden_states = hidden_states + ref_scale * ref_hidden_states
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
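+# Same idea as register_ip_adapter, but for the self-attention (attn1) blocks: the
+# reference branch reuses the layer's query and adds its output scaled by ref_scale.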
+def register_ref_adapter(
+ model,
+ init_method="zero",
+ dtype=torch.float32,
+):
+ attn_procs = {}
+ transformer_sd = model.state_dict()
+ for layer_idx, block in enumerate(model.blocks):
+ name = f"blocks.{layer_idx}.attn1.processor"
+ layer_name = name.split(".processor")[0]
+ dim = transformer_sd[layer_name + ".to_k.weight"].shape[0]
+ attn_procs[name] = WanRefAttnProcessor(
+ dim=dim,
+ bias=True,
+ )
+ if init_method == "zero":
+ torch.nn.init.zeros_(attn_procs[name].to_k_ref.weight)
+ torch.nn.init.zeros_(attn_procs[name].to_k_ref.bias)
+ torch.nn.init.zeros_(attn_procs[name].to_v_ref.weight)
+ torch.nn.init.zeros_(attn_procs[name].to_v_ref.bias)
+ elif init_method == "clone":
+ weights = {
+ "to_k_ref.weight": transformer_sd[layer_name + ".to_k.weight"],
+ "to_k_ref.bias": transformer_sd[layer_name + ".to_k.bias"],
+ "to_v_ref.weight": transformer_sd[layer_name + ".to_v.weight"],
+ "to_v_ref.bias": transformer_sd[layer_name + ".to_v.bias"],
+ }
+ attn_procs[name].load_state_dict(weights)
+ elif init_method == "random":
+ pass
+ else:
+ raise ValueError(f"{init_method} is not supported.")
+ block.attn1.processor = attn_procs[name]
+ ref_layers = torch.nn.ModuleList(attn_procs.values())
+ ref_layers.to(device=model.device, dtype=dtype)
+ return model, ref_layers
diff --git a/models/wan/lynx/inference_utils.py b/models/wan/lynx/inference_utils.py
new file mode 100644
index 000000000..c1129f3bf
--- /dev/null
+++ b/models/wan/lynx/inference_utils.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, Union
+
+import torch
+import numpy as np
+
+from PIL import Image
+from dataclasses import dataclass, field
+
+
+dtype_mapping = {
+ "bf16": torch.bfloat16,
+ "fp16": torch.float16,
+ "fp32": torch.float32
+}
+
+
+@dataclass
+class SubjectInfo:
+ name: str = ""
+ image_pil: Optional[Image.Image] = None
+ landmarks: Optional[Union[np.ndarray, torch.Tensor]] = None
+ face_embeds: Optional[Union[np.ndarray, torch.Tensor]] = None
+
+
+@dataclass
+class VideoStyleInfo: # key names should match those used in style.yaml file
+ style_name: str = 'none'
+ num_frames: int = 81
+ seed: int = -1
+ guidance_scale: float = 5.0
+ guidance_scale_i: float = 2.0
+ num_inference_steps: int = 50
+ width: int = 832
+ height: int = 480
+ prompt: str = ''
+ negative_prompt: str = ''
\ No newline at end of file
diff --git a/models/wan/lynx/light_attention_processor.py b/models/wan/lynx/light_attention_processor.py
new file mode 100644
index 000000000..a3c7516b1
--- /dev/null
+++ b/models/wan/lynx/light_attention_processor.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2025 The Wan Team and The HuggingFace Team.
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+#
+# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025.
+#
+# Original file was released under Apache License 2.0, with the full license text
+# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt.
+#
+# This modified file is released under the same license.
+
+from typing import Tuple, Union, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.models.normalization import RMSNorm
+
+
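+# Lightweight IP-Adapter registration for Wan: `layers` can be an explicit list of block
+# indices or an int stride; only the selected attn2 blocks receive an IPAWanAttnProcessor2_0.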
+def register_ip_adapter_wan(
+ model,
+ hidden_size=5120,
+ cross_attention_dim=2048,
+ dtype=torch.float32,
+ init_method="zero",
+ layers=None
+):
+ attn_procs = {}
+ transformer_sd = model.state_dict()
+
+ if layers is None:
+ layers = list(range(0, len(model.blocks)))
+ elif isinstance(layers, int): # Only interval provided
+ layers = list(range(0, len(model.blocks), layers))
+
+ for i, block in enumerate(model.blocks):
+ if i not in layers:
+ continue
+
+ name = f"blocks.{i}.attn2.processor"
+ attn_procs[name] = IPAWanAttnProcessor2_0(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim
+ )
+
+ if init_method == "zero":
+ torch.nn.init.zeros_(attn_procs[name].to_k_ip.weight)
+ torch.nn.init.zeros_(attn_procs[name].to_v_ip.weight)
+ elif init_method == "clone":
+ layer_name = name.split(".processor")[0]
+ weights = {
+ "to_k_ip.weight": transformer_sd[layer_name + ".to_k.weight"],
+ "to_v_ip.weight": transformer_sd[layer_name + ".to_v.weight"],
+ }
+ attn_procs[name].load_state_dict(weights)
+ else:
+ raise ValueError(f"{init_method} is not supported.")
+
+ block.attn2.processor = attn_procs[name]
+
+ ip_layers = torch.nn.ModuleList(attn_procs.values())
+ ip_layers.to(model.device, dtype=dtype)
+
+ return model, ip_layers
+
+
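+# Cross-attention processor with an extra image-prompt branch: image_embed is projected by
+# (zero-initialized) to_k_ip/to_v_ip, RMS-normalized, attended with the layer's query, and
+# the result is added to the text cross-attention output scaled by ip_scale.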
+class IPAWanAttnProcessor2_0(torch.nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ cross_attention_dim=None,
+ scale=1.0,
+ bias=False,
+ ):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("IPAWanAttnProcessor requires PyTorch 2.0; please upgrade PyTorch to use it.")
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=bias)
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=bias)
+
+ torch.nn.init.zeros_(self.to_k_ip.weight)
+ torch.nn.init.zeros_(self.to_v_ip.weight)
+ if bias:
+ torch.nn.init.zeros_(self.to_k_ip.bias)
+ torch.nn.init.zeros_(self.to_v_ip.bias)
+
+ self.norm_rms_k = RMSNorm(hidden_size, eps=1e-5, elementwise_affine=False)
+
+
+ def __call__(
+ self,
+ attn,
+ hidden_states: torch.Tensor,
+ image_embed: torch.Tensor = None,
+ ip_scale: float = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ encoder_hidden_states_img = encoder_hidden_states[:, :257]
+ encoder_hidden_states = encoder_hidden_states[:, 257:]
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ # =============================================================
+ batch_size = image_embed.size(0)
+
+ ip_hidden_states = image_embed
+ ip_query = query # attn.to_q(hidden_states.clone())
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ if attn.norm_q is not None:
+ ip_query = attn.norm_q(ip_query)
+ ip_key = self.norm_rms_k(ip_key)
+
+ ip_inner_dim = ip_key.shape[-1]
+ ip_head_dim = ip_inner_dim // attn.heads
+
+ ip_query = ip_query.view(batch_size, -1, attn.heads, ip_head_dim).transpose(1, 2)
+ ip_key = ip_key.view(batch_size, -1, attn.heads, ip_head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, ip_head_dim).transpose(1, 2)
+
+ ip_hidden_states = F.scaled_dot_product_attention(
+ ip_query, ip_key, ip_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * ip_head_dim)
+ ip_hidden_states = ip_hidden_states.to(ip_query.dtype)
+ # ==== end of IP-Adapter branch ====
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb_inner(hidden_states: torch.Tensor, freqs: torch.Tensor):
+ x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+ x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+ return x_out.type_as(hidden_states)
+
+ query = apply_rotary_emb_inner(query, rotary_emb)
+ key = apply_rotary_emb_inner(key, rotary_emb)
+
+ # I2V task
+ hidden_states_img = None
+ if encoder_hidden_states_img is not None:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ key_img = attn.norm_added_k(key_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+ key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ hidden_states_img = F.scaled_dot_product_attention(
+ query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+ hidden_states_img = hidden_states_img.type_as(query)
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if hidden_states_img is not None:
+ hidden_states = hidden_states + hidden_states_img
+
+ # Add IPA residual
+ ip_scale = ip_scale or self.scale
+ hidden_states = hidden_states + ip_scale * ip_hidden_states
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
\ No newline at end of file
diff --git a/models/wan/lynx/model_utils_wan.py b/models/wan/lynx/model_utils_wan.py
new file mode 100644
index 000000000..2fd117046
--- /dev/null
+++ b/models/wan/lynx/model_utils_wan.py
@@ -0,0 +1,264 @@
+# Copyright (c) 2025 The Wan Team and The HuggingFace Team.
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+#
+# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025.
+#
+# Original file was released under Apache License 2.0, with the full license text
+# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt.
+#
+# This modified file is released under the same license.
+
+import os
+import math
+import torch
+
+from typing import List, Optional, Tuple, Union
+
+from diffusers.models.embeddings import get_3d_rotary_pos_embed
+from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
+
+from diffusers import FlowMatchEulerDiscreteScheduler
+
+""" Utility functions for Wan2.1-T2I model """
+
+
+def cal_mean_and_std(vae, device, dtype):
+ # https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py#L483
+
+ # latents_mean
+ latents_mean = torch.tensor(vae.config.latents_mean).view(
+ 1, vae.config.z_dim, 1, 1, 1).to(device, dtype)
+
+ # reciprocal of latents_std (kept this way to match the official Wan VAE code)
+ latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(
+ 1, vae.config.z_dim, 1, 1, 1).to(device, dtype)
+
+ return latents_mean, latents_std
+
+
+def encode_video(vae, video):
+ # Input video is formatted as [B, F, C, H, W]
+ video = video.to(vae.device, dtype=vae.dtype)
+ video = video.unsqueeze(0) if len(video.shape) == 4 else video
+
+ # VAE takes [B, C, F, H, W]
+ video = video.permute(0, 2, 1, 3, 4)
+ latent_dist = vae.encode(video).latent_dist
+
+ mean, std = cal_mean_and_std(vae, vae.device, vae.dtype)
+
+ # NOTE: this looks odd, but it is borrowed from the official code. A standard
+ # normalization would be (x - mean) / std; however, cal_mean_and_std() returns
+ # `std` as the reciprocal (1 / vae.config.latents_std), so multiplying by it
+ # below is equivalent to dividing by the true latent std.
+
+ # Sample latent with scaling
+ video_latent = (latent_dist.sample() - mean) * std
+
+ # Return as [B, C, F, H, W] as required by Wan
+ return video_latent
+
+
+def decode_latent(vae, latent):
+ # Input latent is formatted as [B, C, F, H, W]
+ latent = latent.to(vae.device, dtype=vae.dtype)
+ latent = latent.unsqueeze(0) if len(latent.shape) == 4 else latent
+
+ mean, std = cal_mean_and_std(vae, latent.device, latent.dtype)
+ latent = latent / std + mean
+
+ # VAE takes [B, C, F, H, W]
+ frames = vae.decode(latent).sample
+
+ # Return as [B, F, C, H, W]
+ frames = frames.permute(0, 2, 1, 3, 4).float()
+
+ return frames
+
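+# Round-trip sketch (hypothetical `vae` and `video` objects, shapes as documented above):
+#   latent = encode_video(vae, video)    # video [B, F, C, H, W]  ->  normalized latent [B, C, F', H', W']
+#   frames = decode_latent(vae, latent)  # latent [B, C, F', H', W']  ->  frames [B, F, C, H, W]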
+
+def encode_prompt(
+ tokenizer,
+ text_encoder,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ text_input_ids=None,
+ prompt_attention_mask=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_attention_mask = text_inputs.attention_mask
+ else:
+ if text_input_ids is None:
+ raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
+
+ seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
+ prompt_embeds = text_encoder(text_input_ids.to(device), prompt_attention_mask.to(device)).last_hidden_state
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds, prompt_attention_mask
+
+
+def compute_prompt_embeddings(
+ tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
+):
+ if requires_grad:
+ prompt_embeds = encode_prompt(
+ tokenizer,
+ text_encoder,
+ prompt,
+ num_videos_per_prompt=1,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ else:
+ with torch.no_grad():
+ prompt_embeds = encode_prompt(
+ tokenizer,
+ text_encoder,
+ prompt,
+ num_videos_per_prompt=1,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ return prompt_embeds[0]
+
+
+def prepare_rotary_positional_embeddings(
+ height: int,
+ width: int,
+ num_frames: int,
+ vae_scale_factor_spatial: int = 8,
+ patch_size: int = 2,
+ attention_head_dim: int = 64,
+ device: Optional[torch.device] = None,
+ base_height: int = 480,
+ base_width: int = 720,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ grid_height = height // (vae_scale_factor_spatial * patch_size)
+ grid_width = width // (vae_scale_factor_spatial * patch_size)
+ base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
+ base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
+
+ grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+ embed_dim=attention_head_dim,
+ crops_coords=grid_crops_coords,
+ grid_size=(grid_height, grid_width),
+ temporal_size=num_frames,
+ )
+
+ freqs_cos = freqs_cos.to(device=device)
+ freqs_sin = freqs_sin.to(device=device)
+ return freqs_cos, freqs_sin
+
+
+def prepare_sigmas(
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ sigmas: torch.Tensor,
+ batch_size: int,
+ num_train_timesteps: int,
+ flow_weighting_scheme: str = "none",
+ flow_logit_mean: float = 0.0,
+ flow_logit_std: float = 1.0,
+ flow_mode_scale: float = 1.29,
+ device: torch.device = torch.device("cpu"),
+ generator: Optional[torch.Generator] = None,
+) -> torch.Tensor:
+ if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
+ weights = compute_density_for_timestep_sampling(
+ weighting_scheme=flow_weighting_scheme,
+ batch_size=batch_size,
+ logit_mean=flow_logit_mean,
+ logit_std=flow_logit_std,
+ mode_scale=flow_mode_scale,
+ device=device,
+ generator=generator,
+ )
+ indices = (weights * num_train_timesteps).long()
+ else:
+ raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
+
+ return sigmas[indices]
+
+
+def compute_density_for_timestep_sampling(
+ weighting_scheme: str,
+ batch_size: int,
+ logit_mean: float = None,
+ logit_std: float = None,
+ mode_scale: float = None,
+ device: torch.device = torch.device("cpu"),
+ generator: Optional[torch.Generator] = None,
+) -> torch.Tensor:
+ r"""
+ Compute the density for sampling the timesteps when doing SD3 training.
+
+ Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
+
+ SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
+ """
+ if weighting_scheme == "logit_normal":
+ # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+ u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
+ u = torch.nn.functional.sigmoid(u)
+ elif weighting_scheme == "mode":
+ u = torch.rand(size=(batch_size,), device=device, generator=generator)
+ u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+ else:
+ u = torch.rand(size=(batch_size,), device=device, generator=generator)
+ return u
+
+
+def expand_tensor_dims(tensor: torch.Tensor, ndim: int) -> torch.Tensor:
+ assert len(tensor.shape) <= ndim
+ return tensor.reshape(tensor.shape + (1,) * (ndim - len(tensor.shape)))
+
+
+def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+ r"""Forward process of flow matching."""
+ return (1.0 - t) * x0 + t * n
+
+
+def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
+ r"""Loss target for flow matching."""
+ return n - x0
+
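+# Note: with x_t = (1 - t) * x0 + t * n, the velocity d x_t / d t = n - x0 is exactly
+# flow_match_target(); the model is trained to predict this constant velocity along the
+# straight path between data x0 and noise n (the standard flow-matching objective).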
+
+def prepare_loss_weights(
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ sigmas: Optional[torch.Tensor] = None,
+ flow_weighting_scheme: str = "none",
+) -> torch.Tensor:
+ assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
+
+ from diffusers.training_utils import compute_loss_weighting_for_sd3
+
+ return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme)
\ No newline at end of file
diff --git a/models/wan/lynx/navit_utils.py b/models/wan/lynx/navit_utils.py
new file mode 100644
index 000000000..593e7f0d1
--- /dev/null
+++ b/models/wan/lynx/navit_utils.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+def vector_to_list(tensor, lens, dim):
+ return list(torch.split(tensor, lens, dim=dim))
+
+def list_to_vector(tensor_list, dim):
+ lens = [tensor.shape[dim] for tensor in tensor_list]
+ tensor = torch.cat(tensor_list, dim)
+ return tensor, lens
+
+def merge_token_lists(list1, list2, dim):
+ assert(len(list1) == len(list2))
+ return [torch.cat((t1, t2), dim) for t1, t2 in zip(list1, list2)]
\ No newline at end of file
diff --git a/models/wan/lynx/resampler.py b/models/wan/lynx/resampler.py
new file mode 100644
index 000000000..e225a062b
--- /dev/null
+++ b/models/wan/lynx/resampler.py
@@ -0,0 +1,185 @@
+
+# Copyright (c) 2022 Phil Wang
+# Copyright (c) 2023 Anas Awadalla, Irena Gao, Joshua Gardner, Jack Hessel, Yusuf Hanafy, Wanrong Zhu, Kalyani Marathe, Yonatan Bitton, Samir Gadre, Jenia Jitsev, Simon Kornblith, Pang Wei Koh, Gabriel Ilharco, Mitchell Wortsman, Ludwig Schmidt.
+# Copyright (c) 2023 Tencent AI Lab
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+
+# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025.
+# SPDX-License-Identifier: Apache-2.0
+
+# Original file (IP-Adapter) was released under Apache License 2.0, with the full license text
+# available at https://github.com/tencent-ailab/IP-Adapter/blob/main/LICENSE.
+
+# Original file (open_flamingo and flamingo-pytorch) was released under MIT License:
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import math
+
+import torch
+import torch.nn as nn
+
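+# NOTE: the optional mean-pooled-latents path below (num_latents_mean_pooled > 0) references
+# `Rearrange` and `masked_mean`, which the resampler this file was adapted from defines but
+# this file does not. A minimal sketch of the missing pieces, assuming `einops` is installed:
+from einops.layers.torch import Rearrange  # used by Resampler.to_latents_from_mean_pooled_seq
+
+
+def masked_mean(t, *, dim, mask=None):
+    # mean of `t` over `dim`, counting only positions where `mask` is True
+    if mask is None:
+        return t.mean(dim=dim)
+    denom = mask.sum(dim=dim, keepdim=True).clamp(min=1)
+    return t.masked_fill(~mask.unsqueeze(-1), 0.0).sum(dim=dim) / denom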
+
+# FFN
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+def reshape_tensor(x, heads):
+ bs, length, width = x.shape
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, length, heads, -1)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+ # keep (bs, n_heads, length, dim_per_head); the reshape only ensures a contiguous layout
+ x = x.reshape(bs, heads, length, -1)
+ return x
+
+
+class PerceiverAttention(nn.Module):
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+ latents (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
+class Resampler(nn.Module):
+ def __init__(
+ self,
+ dim=1024,
+ depth=8,
+ dim_head=64,
+ heads=16,
+ num_queries=8,
+ embedding_dim=768,
+ output_dim=1024,
+ ff_mult=4,
+ max_seq_len: int = 257, # CLIP tokens + CLS token
+ apply_pos_emb: bool = False,
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
+ **kwargs,
+ ):
+ super().__init__()
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
+
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+
+ self.proj_in = nn.Linear(embedding_dim, dim)
+
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+
+ self.to_latents_from_mean_pooled_seq = (
+ nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, dim * num_latents_mean_pooled),
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
+ )
+ if num_latents_mean_pooled > 0
+ else None
+ )
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ self.pretrained = kwargs.pop("pretrained", None)
+ self.freeze = kwargs.pop("freeze", False)
+ self.skip_last_layer = kwargs.pop("skip_last_layer", False)
+
+ if self.freeze and self.pretrained is None:
+ raise ValueError("Trying to freeze the model but no pretrained provided!")
+
+
+ def forward(self, x):
+ if self.pos_emb is not None:
+ n, device = x.shape[1], x.device
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
+ x = x + pos_emb
+
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+
+ if self.to_latents_from_mean_pooled_seq:
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
+
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
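+
+
+# Usage sketch (hypothetical dimensions): resample 257 CLIP image tokens into a small set of
+# query tokens for the IP-Adapter cross-attention branch:
+#   resampler = Resampler(dim=1024, depth=4, dim_head=64, heads=16, num_queries=16,
+#                         embedding_dim=1280, output_dim=2048, ff_mult=4)
+#   ip_tokens = resampler(clip_image_embeds)   # [B, 257, 1280] -> [B, 16, 2048]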
diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py
index 82a313cec..42af87737 100644
--- a/models/wan/modules/model.py
+++ b/models/wan/modules/model.py
@@ -355,13 +355,38 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks
class WanT2VCrossAttention(WanSelfAttention):
- def forward(self, xlist, context, grid_sizes, *args, **kwargs):
+ def forward(self, xlist, context, grid_sizes, lynx_ip_embeds = None, lynx_ip_scale = 0, *args, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
context(Tensor): Shape [B, L2, C]
"""
- x, _ = self.text_cross_attention( xlist, context)
+ x, q = self.text_cross_attention( xlist, context, return_q=lynx_ip_embeds is not None)
+ if lynx_ip_embeds is not None and self.to_k_ip is not None:
+ if self.registers is not None:
+ from ..lynx.navit_utils import vector_to_list, merge_token_lists, list_to_vector
+ ip_hidden_states_list = vector_to_list(lynx_ip_embeds, lynx_ip_embeds.shape[1], 1)
+ ip_hidden_states_list = merge_token_lists(ip_hidden_states_list, [self.registers] * len(ip_hidden_states_list), 1)
+ lynx_ip_embeds, ip_lens = list_to_vector(ip_hidden_states_list, 1)
+ ip_hidden_states_list = None
+ ip_key = self.to_k_ip(lynx_ip_embeds)
+ ip_value = self.to_v_ip(lynx_ip_embeds)
+ # if self.norm_q is not None: ip_query = self.norm_q(q)
+ if self.norm_rms_k is None:
+ ip_key = self.norm_k(ip_key)
+ else:
+ ip_key = self.norm_rms_k(ip_key)
+ ip_inner_dim = ip_key.shape[-1]
+ ip_head_dim = ip_inner_dim // self.num_heads
+ batch_size = q.shape[0]
+ # ip_query = ip_query.view(batch_size, -1, attn.heads, ip_head_dim)
+ ip_key = ip_key.view(batch_size, -1, self.num_heads, ip_head_dim)
+ ip_value = ip_value.view(batch_size, -1, self.num_heads, ip_head_dim)
+ qkv_list = [q, ip_key, ip_value]
+ del q, ip_key, ip_value
+ ip_hidden_states = pay_attention(qkv_list).reshape(*x.shape)
+ x.add_(ip_hidden_states, alpha= lynx_ip_scale)
+
x = self.o(x)
return x
@@ -382,7 +407,7 @@ def __init__(self,
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
- def forward(self, xlist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens ):
+ def forward(self, xlist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens, *args, **kwargs ):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@@ -510,6 +535,8 @@ def forward(
ref_images_count=0,
standin_phase=-1,
motion_vec = None,
+ lynx_ip_embeds = None,
+ lynx_ip_scale = 0,
):
r"""
Args:
@@ -568,7 +595,7 @@ def forward(
y = y.to(attention_dtype)
ylist= [y]
del y
- x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype)
+ x += self.cross_attn(ylist, context, grid_sizes, audio_proj = audio_proj, audio_scale = audio_scale, audio_context_lens = audio_context_lens, lynx_ip_embeds=lynx_ip_embeds, lynx_ip_scale=lynx_ip_scale).to(dtype)
if multitalk_audio != None:
# cross attn of multitalk audio
@@ -914,6 +941,7 @@ def __init__(self,
norm_output_audio=True,
standin= False,
motion_encoder_dim=0,
+ lynx=None,
):
super().__init__()
@@ -1044,6 +1072,32 @@ def __init__(self,
block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128)
block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128)
+ if lynx is not None:
+ from ..lynx.light_attention_processor import IPAWanAttnProcessor2_0
+ lynx_full = lynx=="full"
+ if lynx_full:
+ lynx_cross_dim = 5120
+ lynx_layers = len(self.blocks)
+ else:
+ lynx_cross_dim = 2048
+ lynx_layers = 20
+ for i, block in enumerate(self.blocks):
+ if i < lynx_layers:
+ attn = IPAWanAttnProcessor2_0(hidden_size=dim, cross_attention_dim = lynx_cross_dim,)
+ block.cross_attn.to_k_ip = attn.to_k_ip
+ block.cross_attn.to_v_ip = attn.to_v_ip
+ else:
+ block.cross_attn.to_k_ip = None
+ block.cross_attn.to_v_ip = None
+ if lynx_full:
+ block.cross_attn.registers = nn.Parameter(torch.randn(1, 16, lynx_cross_dim) / dim**0.5)
+ block.cross_attn.norm_rms_k = None
+ else:
+ block.cross_attn.registers = None
+ block.cross_attn.norm_rms_k = attn.norm_rms_k
+
+
+
if animate:
self.pose_patch_embedding = nn.Conv3d(
16, dim, kernel_size=patch_size, stride=patch_size
@@ -1247,6 +1301,8 @@ def forward(
standin_ref = None,
pose_latents=None,
face_pixel_values=None,
+ lynx_ip_embeds = None,
+ lynx_ip_scale = 0,
):
# patch_dtype = self.patch_embedding.weight.dtype
@@ -1286,11 +1342,11 @@ def forward(
y = None
motion_vec_list = []
- for i, x in enumerate(x_list):
+ for i, (x, one_face_pixel_values) in enumerate(zip(x_list, face_pixel_values)):
# animate embeddings
motion_vec = None
if pose_latents is not None:
- x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values)
+ x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(one_face_pixel_values) if one_face_pixel_values is None else one_face_pixel_values)
motion_vec_list.append(motion_vec)
if chipmunk:
x = x.unsqueeze(-1)
@@ -1330,6 +1386,7 @@ def forward(
audio_proj=audio_proj,
audio_context_lens=audio_context_lens,
ref_images_count=ref_images_count,
+ lynx_ip_scale= lynx_ip_scale,
)
_flag_df = t.dim() == 2
@@ -1348,6 +1405,11 @@ def forward(
standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)).to(modulation_dtype) )
standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype)
standin_e = standin_ref = None
+
+ if lynx_ip_embeds is None:
+ lynx_ip_embeds_list = [None] * len(x_list)
+ else:
+ lynx_ip_embeds_list = lynx_ip_embeds
if self.inject_sample_info and fps!=None:
fps = torch.tensor(fps, dtype=torch.long, device=device)
@@ -1498,9 +1560,9 @@ def forward(
continue
x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs)
else:
- for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list)):
+ for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec, lynx_ip_embeds) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list, lynx_ip_embeds_list)):
if should_calc:
- x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec,**kwargs)
+ x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec, lynx_ip_embeds= lynx_ip_embeds, **kwargs)
del x
context = hints = audio_embedding = None
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index 1a30fffbc..e1829f03a 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -17,6 +17,9 @@ def test_multitalk(base_model_type):
def test_standin(base_model_type):
return base_model_type in ["standin", "vace_standin_14B"]
+def test_lynx(base_model_type):
+ return base_model_type in ["lynx_lite", "vace_lynx_lite_14B", "lynx_full", "vace_lynx_full_14B"]
+
def test_wan_5B(base_model_type):
return base_model_type in ["ti2v_2_2", "lucy_edit"]
class family_handler():
@@ -82,6 +85,7 @@ def query_model_def(base_model_type, model_def):
extra_model_def["i2v_class"] = i2v
extra_model_def["multitalk_class"] = test_multitalk(base_model_type)
extra_model_def["standin_class"] = test_standin(base_model_type)
+ extra_model_def["lynx_class"] = test_lynx(base_model_type)
vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"]
extra_model_def["vace_class"] = vace_class
@@ -170,15 +174,26 @@ def query_model_def(base_model_type, model_def):
"choices":[
("Animate Person in Reference Image using Motion of Whole Control Video", "PVBKI"),
("Animate Person in Reference Image using Motion of Targeted Person in Control Video", "PVBXAKI"),
- ("Replace Person in Control Video by Person in Reference Image", "PVBAI"),
- ("Replace Person in Control Video by Person in Reference Image and Apply Relighting Process", "PVBAI1"),
+ ("Replace Person in Control Video by Person in Ref Image", "PVBAIH#"),
+ ("Replace Person in Control Video by Person in Ref Image. See Through Mask", "PVBAI#"),
],
"default": "PVBKI",
- "letters_filter": "PVBXAKI1",
+ "letters_filter": "PVBXAKIH#",
"label": "Type of Process",
+ "scale": 3,
"show_label" : False,
}
+ extra_model_def["custom_video_selection"] = {
+ "choices":[
+ ("None", ""),
+ ("Apply Relighting", "1"),
+ ],
+ "label": "Custom Process",
+ "show_label" : False,
+ "scale": 1,
+ }
+
extra_model_def["mask_preprocessing"] = {
"selection":[ "", "A", "XA"],
"visible": False
@@ -229,7 +244,7 @@ def query_model_def(base_model_type, model_def):
extra_model_def["forced_guide_mask_inputs"] = True
extra_model_def["return_image_refs_tensor"] = True
- if base_model_type in ["standin"]:
+ if base_model_type in ["standin", "lynx_lite", "lynx_full"]:
extra_model_def["fit_into_canvas_image_refs"] = 0
extra_model_def["image_ref_choices"] = {
"choices": [
@@ -298,7 +313,7 @@ def query_model_def(base_model_type, model_def):
@staticmethod
def query_supported_types():
return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B",
- "t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
+ "t2v_1.3B", "standin", "lynx_lite", "lynx_full", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
"recam_1.3B", "animate",
"i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
@@ -312,8 +327,8 @@ def query_family_maps():
}
models_comp_map = {
- "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B"],
- "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin"],
+ "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_full_14B"],
+ "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin", "lynx_lite", "lynx_full"],
"i2v" : [ "fantasy", "multitalk", "flf2v_720p" ],
"i2v_2_2" : ["i2v_2_2_multitalk"],
"fantasy": ["multitalk"],
@@ -423,7 +438,7 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
ui_defaults["video_prompt_type"] = video_prompt_type
if settings_version < 2.31:
- if base_model_type in "recam_1.3B":
+ if base_model_type in ["recam_1.3B"]:
video_prompt_type = ui_defaults.get("video_prompt_type", "")
if not "V" in video_prompt_type:
video_prompt_type += "UV"
@@ -438,6 +453,14 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
if test_class_i2v(base_model_type) and len(image_prompt_type) == 0:
ui_defaults["image_prompt_type"] = "S"
+
+ if settings_version < 2.37:
+ if base_model_type in ["animate"]:
+ video_prompt_type = ui_defaults.get("video_prompt_type", "")
+ if "1" in video_prompt_type:
+ video_prompt_type = video_prompt_type.replace("1", "#1")
+ ui_defaults["video_prompt_type"] = video_prompt_type
+
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
@@ -481,6 +504,15 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
"video_prompt_type": "I",
"remove_background_images_ref" : 1,
})
+ elif base_model_type in ["lynx_lite", "lynx_full"]:
+ ui_defaults.update({
+ "guidance_scale": 5.0,
+ "flow_shift": 7, # 11 for 720p
+ "sliding_window_overlap" : 9,
+ "video_prompt_type": "QI",
+ "remove_background_images_ref" : 1,
+ })
+
elif base_model_type in ["phantom_1.3B", "phantom_14B"]:
ui_defaults.update({
"guidance_scale": 7.5,
diff --git a/preprocessing/arc/face_encoder.py b/preprocessing/arc/face_encoder.py
new file mode 100644
index 000000000..0fb916219
--- /dev/null
+++ b/preprocessing/arc/face_encoder.py
@@ -0,0 +1,97 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+import torchvision.transforms as T
+
+import numpy as np
+
+from insightface.utils import face_align
+from insightface.app import FaceAnalysis
+from facexlib.recognition import init_recognition_model
+
+
+__all__ = [
+ "FaceEncoderArcFace",
+ "get_landmarks_from_image",
+]
+
+
+detector = None
+
+def get_landmarks_from_image(image):
+ """
+ Detect landmarks with insightface.
+
+ Args:
+ image (np.ndarray or PIL.Image):
+ The input image in RGB format.
+
+ Returns:
+ The 5 2D keypoints (5 x 2) of the largest detected face; only one face is returned.
+ """
+ global detector
+ if detector is None:
+ detector = FaceAnalysis()
+ detector.prepare(ctx_id=0, det_size=(640, 640))
+
+ in_image = np.array(image).copy()
+
+ faces = detector.get(in_image)
+ if len(faces) == 0:
+ raise ValueError("No face detected in the image")
+
+ # Get the largest face
+ face = max(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1]))
+
+ # Return the 5 keypoints directly
+ keypoints = face.kps # 5 x 2
+
+ return keypoints
+
+
+class FaceEncoderArcFace():
+ """ Official ArcFace, no_grad-only """
+
+ def __repr__(self):
+ return "ArcFace"
+
+
+ def init_encoder_model(self, device, eval_mode=True):
+ self.device = device
+ self.encoder_model = init_recognition_model('arcface', device=device)
+
+ if eval_mode:
+ self.encoder_model.eval()
+
+
+ @torch.no_grad()
+ def input_preprocessing(self, in_image, landmarks, image_size=112):
+ assert landmarks is not None, "landmarks are not provided!"
+
+ in_image = np.array(in_image)
+ landmark = np.array(landmarks)
+
+ face_aligned = face_align.norm_crop(in_image, landmark=landmark, image_size=image_size)
+
+ image_transform = T.Compose([
+ T.ToTensor(),
+ T.Normalize([0.5], [0.5]),
+ ])
+ face_aligned = image_transform(face_aligned).unsqueeze(0).to(self.device)
+
+ return face_aligned
+
+
+ @torch.no_grad()
+ def __call__(self, in_image, need_proc=False, landmarks=None, image_size=112):
+
+ if need_proc:
+ in_image = self.input_preprocessing(in_image, landmarks, image_size)
+ else:
+ assert isinstance(in_image, torch.Tensor)
+
+ in_image = in_image[:, [2, 1, 0], :, :].contiguous() # RGB -> BGR channel order for the recognition model
+ image_embeds = self.encoder_model(in_image) # [B, 512], normalized
+
+ return image_embeds
\ No newline at end of file
diff --git a/preprocessing/arc/face_utils.py b/preprocessing/arc/face_utils.py
new file mode 100644
index 000000000..41a454546
--- /dev/null
+++ b/preprocessing/arc/face_utils.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2022 Insightface Team
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+
+# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025.
+# SPDX-License-Identifier: Apache-2.0
+
+# Original file (insightface) was released under MIT License:
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import numpy as np
+import cv2
+from PIL import Image
+
+def estimate_norm(lmk, image_size=112, arcface_dst=None):
+ from skimage import transform as trans
+ assert lmk.shape == (5, 2)
+ assert image_size%112==0 or image_size%128==0
+ if image_size%112==0:
+ ratio = float(image_size)/112.0
+ diff_x = 0
+ else:
+ ratio = float(image_size)/128.0
+ diff_x = 8.0*ratio
+ dst = arcface_dst * ratio
+ dst[:,0] += diff_x
+ tform = trans.SimilarityTransform()
+ tform.estimate(lmk, dst)
+ M = tform.params[0:2, :]
+ return M
+
+def get_arcface_dst(extend_face_crop=False, extend_ratio=0.8):
+ arcface_dst = np.array(
+ [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366],
+ [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32)
+ if extend_face_crop:
+ arcface_dst[:,1] = arcface_dst[:,1] + 10
+ arcface_dst = (arcface_dst - 112/2) * extend_ratio + 112/2
+ return arcface_dst
+
+def align_face(image_pil, face_kpts, extend_face_crop=False, extend_ratio=0.8, face_size=112):
+ arcface_dst = get_arcface_dst(extend_face_crop, extend_ratio)
+ M = estimate_norm(face_kpts, face_size, arcface_dst)
+ image_cv2 = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
+ face_image_cv2 = cv2.warpAffine(image_cv2, M, (face_size, face_size), borderValue=0.0)
+ face_image = Image.fromarray(cv2.cvtColor(face_image_cv2, cv2.COLOR_BGR2RGB))
+ return face_image
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index d6f75f74b..e567f3ae4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -35,6 +35,8 @@ rembg[gpu]==2.0.65
onnxruntime-gpu
decord
timm
+# insightface==0.7.3
+# facexlib==0.3.0
# Config & orchestration
omegaconf
diff --git a/wgp.py b/wgp.py
index 6a257e814..9143689b8 100644
--- a/wgp.py
+++ b/wgp.py
@@ -63,8 +63,8 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.75"
-settings_version = 2.36
+WanGP_version = "8.76"
+settings_version = 2.37
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -3736,7 +3736,7 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr
alpha_mask[mask > 127] = 1
frame = frame * alpha_mask
frame = Image.fromarray(frame)
- face = face_processor.process(frame, resize_to=size, face_crop_scale = 1)
+ face = face_processor.process(frame, resize_to=size, face_crop_scale = 1.2,)
face_list.append(face)
face_processor = None
@@ -3863,7 +3863,7 @@ def prep_prephase(frame_idx):
mask = op_expand(mask, kernel, iterations=3)
_, mask = cv2.threshold(mask, 127.5, 255, cv2.THRESH_BINARY)
- if to_bbox and np.sum(mask == 255) > 0:
+ if to_bbox and np.sum(mask == 255) > 0 : #or True
x0, y0, x1, y1 = mask_to_xyxy_box(mask)
mask = mask * 0
mask[y0:y1, x0:x1] = 255
@@ -4766,7 +4766,7 @@ def remove_temp_filenames(temp_filenames_list):
raise Exception(f"unknown cache type {skip_steps_cache_type}")
trans.cache = skip_steps_cache
if trans2 is not None: trans2.cache = skip_steps_cache
-
+ face_arc_embeds = None
output_new_audio_data = None
output_new_audio_filepath = None
original_audio_guide = audio_guide
@@ -5039,9 +5039,9 @@ def remove_temp_filenames(temp_filenames_list):
context_scale = [control_net_weight /2, control_net_weight2 /2] if preprocess_type2 is not None else [control_net_weight]
if not (preprocess_type == "identity" and preprocess_type2 is None and video_mask is None):send_cmd("progress", [0, get_latest_status(state, status_info)])
inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask == "inpaint" else guide_inpaint_color
- video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size )
+ video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size, to_bbox = "H" in video_prompt_type )
if preprocess_type2 != None:
- video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size )
+ video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size, to_bbox = "H" in video_prompt_type )
if video_guide_processed is not None and sample_fit_canvas is not None:
image_size = video_guide_processed.shape[-2:]
@@ -5069,6 +5069,17 @@ def remove_temp_filenames(temp_filenames_list):
if len(image_refs) > nb_frames_positions:
src_ref_images = image_refs[nb_frames_positions:]
+ if "Q" in video_prompt_type:
+ from preprocessing.arc.face_encoder import FaceEncoderArcFace, get_landmarks_from_image
+ image_pil = src_ref_images[-1]
+ face_encoder = FaceEncoderArcFace()
+ face_encoder.init_encoder_model(processing_device)
+ face_arc_embeds = face_encoder(image_pil, need_proc=True, landmarks=get_landmarks_from_image(image_pil))
+ face_arc_embeds = face_arc_embeds.squeeze(0).cpu()
+ face_encoder = image_pil = None
+ gc.collect()
+ torch.cuda.empty_cache()
+
if remove_background_images_ref > 0:
send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")])
@@ -5079,7 +5090,6 @@ def remove_temp_filenames(temp_filenames_list):
outpainting_dims =outpainting_dims,
background_ref_outpainted = model_def.get("background_ref_outpainted", True),
return_tensor= model_def.get("return_image_refs_tensor", False) )
-
frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame]
if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False):
@@ -5241,6 +5251,7 @@ def set_header_text(txt):
original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [],
image_refs_relative_size = image_refs_relative_size,
outpainting_dims = outpainting_dims,
+ face_arc_embeds = face_arc_embeds,
)
except Exception as e:
if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0:
@@ -6980,7 +6991,7 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
return video_prompt_type
-all_guide_processes ="PDESLCMUVB"
+all_guide_processes ="PDESLCMUVBH"
def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
model_type = state["model_type"]
@@ -7006,7 +7017,34 @@ def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type,
else:
mask_selector_visible = True
ref_images_visible = "I" in video_prompt_type
- return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible), gr.update(visible = ref_images_visible )
+ custom_options = custom_checkbox= False
+ if "#" in video_prompt_type:
+ custom_options = True
+ custom_video_choices = model_def.get("custom_video_selection", None)
+ if custom_video_choices is not None:
+ custom_checkbox = len(custom_video_choices.get("choices", [])) <= 2
+
+ return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible), gr.update(visible = ref_images_visible ), gr.update(visible= custom_options and not custom_checkbox ), gr.update(visible= custom_options and custom_checkbox )
+
+def refresh_video_prompt_type_video_custom_dropbox(state, video_prompt_type, video_prompt_type_video_custom_dropbox):
+ model_type = state["model_type"]
+ model_def = get_model_def(model_type)
+ custom_video_choices = model_def.get("custom_video_selection", None)
+ if not "#" in video_prompt_type or custom_video_choices is None: return gr.update()
+ video_prompt_type = del_in_sequence(video_prompt_type, "0123456789")
+ video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_custom_dropbox)
+ return video_prompt_type
+
+def refresh_video_prompt_type_video_custom_checkbox(state, video_prompt_type, video_prompt_type_video_custom_checkbox):
+ model_type = state["model_type"]
+ model_def = get_model_def(model_type)
+ custom_video_choices = model_def.get("custom_video_selection", None)
+ if not "#" in video_prompt_type or custom_video_choices is None: return gr.update()
+ video_prompt_type = del_in_sequence(video_prompt_type, "0123456789")
+ if video_prompt_type_video_custom_checkbox:
+ video_prompt_type = add_to_sequence(video_prompt_type, custom_video_choices["choices"][1][1])
+ return video_prompt_type
+
# def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
@@ -7474,7 +7512,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
else:
image_prompt_type_radio = gr.Radio(choices=[("", "")], value="", visible= False)
if "E" in image_prompt_types_allowed:
- image_prompt_type_endcheckbox = gr.Checkbox( value ="E" in image_prompt_type_value, label="End Image(s)", show_label= False, visible= any_letters(image_prompt_type_radio_value, "SVL") and not image_outputs , scale= 1)
+ image_prompt_type_endcheckbox = gr.Checkbox( value ="E" in image_prompt_type_value, label="End Image(s)", show_label= False, visible= any_letters(image_prompt_type_radio_value, "SVL") and not image_outputs , scale= 1, elem_classes="cbx_centered")
any_end_image = True
else:
image_prompt_type_endcheckbox = gr.Checkbox( value =False, show_label= False, visible= False , scale= 1)
@@ -7539,14 +7577,14 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
video_prompt_type_video_guide = gr.Dropdown(
guide_preprocessing_choices,
value=filter_letters(video_prompt_type_value, all_guide_processes, guide_preprocessing.get("default", "") ),
- label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True,
+ label= video_prompt_type_video_guide_label , scale = 1, visible= guide_preprocessing.get("visible", True) , show_label= True,
)
any_control_video = True
any_control_image = image_outputs
# Alternate Control Video Preprocessing / Options
if guide_custom_choices is None:
- video_prompt_type_video_guide_alt = gr.Dropdown(choices=[("","")], value="", label="Control Video", visible= False, scale = 2 )
+ video_prompt_type_video_guide_alt = gr.Dropdown(choices=[("","")], value="", label="Control Video", visible= False, scale = 1 )
else:
video_prompt_type_video_guide_alt_label = guide_custom_choices.get("label", "Control Video Process")
if image_outputs: video_prompt_type_video_guide_alt_label = video_prompt_type_video_guide_alt_label.replace("Video", "Image")
@@ -7557,14 +7595,37 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
# value=filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"], guide_custom_choices.get("default", "") ),
value=guide_custom_choices_value,
visible = guide_custom_choices.get("visible", True),
- label= video_prompt_type_video_guide_alt_label, show_label= guide_custom_choices.get("show_label", True), scale = 2
+ label= video_prompt_type_video_guide_alt_label, show_label= guide_custom_choices.get("show_label", True), scale = guide_custom_choices.get("scale", 1),
)
any_control_video = True
any_control_image = image_outputs
+ # Custom dropdown box & checkbox
+ custom_video_selection = model_def.get("custom_video_selection", None)
+ custom_checkbox= False
+ if custom_video_selection is None:
+ video_prompt_type_video_custom_dropbox = gr.Dropdown(choices=[("","")], value="", label="Custom Dropdown", scale = 1, visible= False, show_label= True, )
+ video_prompt_type_video_custom_checkbox = gr.Checkbox(value=False, label="Custom Checkbox", scale = 1, visible= False, show_label= True, )
+
+ else:
+ custom_choices = "#" in video_prompt_type_value
+ custom_video_choices = custom_video_selection["choices"]
+ custom_checkbox = len(custom_video_choices) <= 2
+
+ video_prompt_type_video_custom_label = custom_video_selection.get("label", "Custom Choices")
+
+ video_prompt_type_video_custom_dropbox = gr.Dropdown(
+ custom_video_choices,
+ value=filter_letters(video_prompt_type_value, "0123456789", custom_video_selection.get("default", "")),
+ scale = custom_video_selection.get("scale", 1),
+ label= video_prompt_type_video_custom_label , visible= not custom_checkbox and custom_choices,
+ show_label= custom_video_selection.get("show_label", True),
+ )
+ video_prompt_type_video_custom_checkbox = gr.Checkbox(value= custom_video_choices[1][1] in video_prompt_type_value , label=custom_video_choices[1][0] , scale = custom_video_selection.get("scale", 1), visible=custom_checkbox and custom_choices, show_label= True, elem_classes="cbx_centered" )
+
# Control Mask Preprocessing
if mask_preprocessing is None:
- video_prompt_type_video_mask = gr.Dropdown(choices=[("","")], value="", label="Video Mask", scale = 2, visible= False, show_label= True, )
+ video_prompt_type_video_mask = gr.Dropdown(choices=[("","")], value="", label="Video Mask", scale = 1, visible= False, show_label= True, )
any_image_mask = image_outputs
else:
mask_preprocessing_labels_all = {
@@ -7592,7 +7653,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
video_prompt_type_video_mask = gr.Dropdown(
mask_preprocessing_choices,
value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")),
- label= video_prompt_type_video_mask_label , scale = 2, visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and mask_preprocessing.get("visible", True),
+ label= video_prompt_type_video_mask_label , scale = 1, visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and mask_preprocessing.get("visible", True),
show_label= True,
)
@@ -7604,7 +7665,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
choices=[ ("None", ""),],
value=filter_letters(video_prompt_type_value, ""),
visible = False,
- label="Start / Reference Images", scale = 2
+ label="Start / Reference Images", scale = 1
)
any_reference_image = False
else:
@@ -7613,7 +7674,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
choices= image_ref_choices["choices"],
value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]),
visible = image_ref_choices.get("visible", True),
- label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 2
+ label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 1
)
image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None))
@@ -8287,7 +8348,8 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, image_prompt_type_group, image_prompt_type_radio, image_prompt_type_endcheckbox,
prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, save_lset_prompt_drop, advanced_row, speed_tab, audio_tab, mmaudio_col, quality_tab,
sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, audio_guide_row, RIFLEx_setting_col,
- video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux, audio_prompt_type_remux_row,
+ video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, video_prompt_type_video_custom_dropbox, video_prompt_type_video_custom_checkbox,
+ apg_col, audio_prompt_type_sources, audio_prompt_type_remux, audio_prompt_type_remux_row,
video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right,
video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row,
video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row,
@@ -8316,9 +8378,12 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
image_prompt_type_radio.change(fn=refresh_image_prompt_type_radio, inputs=[state, image_prompt_type, image_prompt_type_radio], outputs=[image_prompt_type, image_start_row, image_end_row, video_source, keep_frames_video_source, image_prompt_type_endcheckbox], show_progress="hidden" )
image_prompt_type_endcheckbox.change(fn=refresh_image_prompt_type_endcheckbox, inputs=[state, image_prompt_type, image_prompt_type_radio, image_prompt_type_endcheckbox], outputs=[image_prompt_type, image_end_row] )
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs,image_mode], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col], show_progress="hidden")
- video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State(""), video_prompt_type, video_prompt_type_video_guide, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row], show_progress="hidden")
- video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State("alt"),video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row], show_progress="hidden")
+ video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State(""), video_prompt_type, video_prompt_type_video_guide, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row, video_prompt_type_video_custom_dropbox, video_prompt_type_video_custom_checkbox], show_progress="hidden")
+ video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State("alt"),video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row, video_prompt_type_video_custom_dropbox, video_prompt_type_video_custom_checkbox], show_progress="hidden")
# video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, image_refs_row, denoising_strength, video_mask, mask_expand, image_mask_guide, image_guide, image_mask, keep_frames_video_guide ], show_progress="hidden")
+ video_prompt_type_video_custom_dropbox.input(fn= refresh_video_prompt_type_video_custom_dropbox, inputs=[state, video_prompt_type, video_prompt_type_video_custom_dropbox], outputs = video_prompt_type)
+ video_prompt_type_video_custom_checkbox.input(fn= refresh_video_prompt_type_video_custom_checkbox, inputs=[state, video_prompt_type, video_prompt_type_video_custom_checkbox], outputs = video_prompt_type)
+
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_mask, image_mask_guide, image_guide, image_mask, mask_expand], show_progress="hidden")
video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type])
multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt, image_end], show_progress="hidden")
@@ -9587,6 +9652,7 @@ def create_ui():
transition: visibility 0s linear 1s, opacity 0.3s linear 1s; /* 1s delay before showing */
}
.btn_centered {margin-top:10px; text-wrap-mode: nowrap;}
+ .cbx_centered label {margin-top:8px; text-wrap-mode: nowrap;}
"""
UI_theme = server_config.get("UI_theme", "default")
UI_theme = args.theme if len(args.theme) > 0 else UI_theme
From 5ad0cd4ac65406b5b1785e06bbbcfb21f1ca579a Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sun, 28 Sep 2025 18:20:57 +0200
Subject: [PATCH 049/155] updated release notes
---
README.md | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 939cb7f04..72a7b0df8 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
-### September 25 2025: WanGP v8.73 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release
+### September 25 2025: WanGP v8.76 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release
So in ~~today's~~ this release you will find two Wannabe Vace models, each of which covers only a subset of Vace features but offers some interesting advantages:
- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion* transfers. It does that very well. You can use this model either to *Replace* a person in a Video or to *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. Under the hood, *Wan 2.2 Animate* seems to be a derived i2v model and should support the corresponding Lora Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusive, you will find support for *Outpainting*.
@@ -29,6 +29,8 @@ In order to use Wan 2.2 Animate you will need first to stop by the *Mat Anyone*
With version WanGP 8.74, there is an extra option that allows you to apply *Relighting* when Replacing a person. Also, you can now Animate a person without providing a Video Mask to target the source of the motion (with the risk that it will be less precise).
+For those of you who get a mask halo effect when Animating a character, I recommend trying *SDPA attention* and the *FusioniX i2v* lora. If the issue persists (this will depend on the control video), you now have a choice of two *Animate Mask Options* in *WanGP 8.76*. The old masking option, which was a WanGP exclusive, has been renamed *See Through Mask* because it preserves the background behind the animated character, although this sometimes creates visual artifacts. The new option, the one with the shorter name, is what you may find elsewhere online. As it uses a much larger mask internally, there is no halo; however, the immediate background behind the character is not preserved and may end up completely different.
+
- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in clothes changing) and voila! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.
Also because I wanted to spoil you:
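
A rough illustration of the two masking strategies described above (a minimal sketch, not WanGP's implementation): the *See Through Mask* composites the generated character through a tight per-pixel mask so the original background survives (and can leave a halo at the border), while the alternate option effectively works with a much larger mask, so nothing of the original frame bleeds through around the character but the nearby background is regenerated. The array names and the box-dilation radius below are illustrative assumptions.

```python
# Minimal sketch (not WanGP code) contrasting a tight "see-through" mask with a
# dilated mask when compositing an animated character back over the original frame.
import numpy as np

def dilate(mask: np.ndarray, radius: int) -> np.ndarray:
    """Crude box dilation (edges wrap around, which is fine for an illustration)."""
    out = mask.copy()
    for dy in range(-radius, radius + 1):
        for dx in range(-radius, radius + 1):
            out = np.maximum(out, np.roll(np.roll(mask, dy, axis=0), dx, axis=1))
    return out

def composite(original, generated, mask, radius=0):
    """original/generated: HxWx3 floats; mask: HxW in {0,1} marking the character."""
    if radius > 0:
        # larger mask: no halo at the border, but the nearby background is regenerated
        mask = dilate(mask, radius)
    m = mask[..., None]
    return generated * m + original * (1 - m)

h, w = 64, 64
original, generated = np.zeros((h, w, 3)), np.ones((h, w, 3))
mask = np.zeros((h, w)); mask[20:44, 20:44] = 1.0
see_through = composite(original, generated, mask)        # tight mask, background preserved
no_halo = composite(original, generated, mask, radius=8)  # dilated mask, as in the new option
```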
@@ -41,6 +43,8 @@ Also because I wanted to spoil you:
*Update 8.72*: shadow drop of Qwen Edit Plus
*Update 8.73*: Qwen Preview & InfiniteTalk Start image
*Update 8.74*: Animate Relighting / Nomask mode , t2v Masked Video to Video
+*Update 8.75*: REDACTED
+*Update 8.76*: Alternate Animate masking that fixes the mask halo effect that some users have
### September 15 2025: WanGP v8.6 - Attack of the Clones
From 4a84bdecf7db9a02f68961d7c4799bf9d77b089f Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sun, 28 Sep 2025 18:41:44 +0200
Subject: [PATCH 050/155] oops
---
models/wan/modules/model.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py
index 42af87737..5ed2e1724 100644
--- a/models/wan/modules/model.py
+++ b/models/wan/modules/model.py
@@ -1342,6 +1342,7 @@ def forward(
y = None
motion_vec_list = []
+ if face_pixel_values is None: face_pixel_values = [None] * len(x_list)
for i, (x, one_face_pixel_values) in enumerate(zip(x_list, face_pixel_values)):
# animate embeddings
motion_vec = None
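
For context, the one-line guard added in this patch exists because `zip(x_list, face_pixel_values)` raises a `TypeError` when `face_pixel_values` is `None` (code paths that supply no face frames); padding with `None`s keeps the per-sample loop uniform. A generic sketch of the pattern (the names below are placeholders, not the model's variables):

```python
# Generic illustration of the guard above: pad an optional per-sample input with
# Nones so a zip-based loop works whether or not the input was provided.
x_list = ["sample0", "sample1"]
face_pixel_values = None  # e.g. no face frames supplied for this generation

if face_pixel_values is None:
    face_pixel_values = [None] * len(x_list)

for x, face in zip(x_list, face_pixel_values):
    # downstream code can simply branch on `face is None`
    print(x, "has face input:", face is not None)
```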
From 6a913c7ecf3f904edb3bc99928fb0eb42fc460ae Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sun, 28 Sep 2025 18:49:36 +0200
Subject: [PATCH 051/155] oops
---
models/wan/modules/model.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py
index 5ed2e1724..c2419341e 100644
--- a/models/wan/modules/model.py
+++ b/models/wan/modules/model.py
@@ -1347,7 +1347,7 @@ def forward(
# animate embeddings
motion_vec = None
if pose_latents is not None:
- x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(one_face_pixel_values) if one_face_pixel_values is None else one_face_pixel_values)
+ x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(face_pixel_values[0]) if one_face_pixel_values is None else one_face_pixel_values)
motion_vec_list.append(motion_vec)
if chipmunk:
x = x.unsqueeze(-1)
From 6bbb60f9e8f25cc8fa28a649f6040968605c5adf Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sun, 28 Sep 2025 19:56:54 +0200
Subject: [PATCH 052/155] rollback cfg for animate
---
models/wan/modules/model.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py
index c2419341e..07e3e54e9 100644
--- a/models/wan/modules/model.py
+++ b/models/wan/modules/model.py
@@ -1347,7 +1347,8 @@ def forward(
# animate embeddings
motion_vec = None
if pose_latents is not None:
- x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(face_pixel_values[0]) if one_face_pixel_values is None else one_face_pixel_values)
+ # x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(face_pixel_values[0]) if one_face_pixel_values is None else one_face_pixel_values)
+ x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values[0]) # if one_face_pixel_values is None else one_face_pixel_values)
motion_vec_list.append(motion_vec)
if chipmunk:
x = x.unsqueeze(-1)
From e7433804d1e61315034fdff7ba22a6a77bf9aaa4 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sun, 28 Sep 2025 21:47:58 +0200
Subject: [PATCH 053/155] fixed infinitetalk bug
---
models/wan/any2video.py | 2 +-
models/wan/wan_handler.py | 20 ++++++++++++++------
wgp.py | 23 ++++++++++++-----------
3 files changed, 27 insertions(+), 18 deletions(-)
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index b0fc24659..9f27fb30a 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -447,7 +447,7 @@ def generate(self,
if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]:
any_end_frame = False
if infinitetalk:
- new_shot = "Q" in video_prompt_type
+ new_shot = "0" in video_prompt_type
if input_frames is not None:
image_ref = input_frames[:, 0]
else:
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index e1829f03a..d3dc31e09 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -147,15 +147,15 @@ def query_model_def(base_model_type, model_def):
extra_model_def["all_image_refs_are_background_ref"] = True
extra_model_def["guide_custom_choices"] = {
"choices":[
- ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"),
+ ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "0KI"),
("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
- ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"),
+ ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "0RUV"),
("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"),
- ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"),
+ ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "G0UV"),
("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"),
],
"default": "KI",
- "letters_filter": "RGUVQKI",
+ "letters_filter": "RGUV0KI",
"label": "Video to Video",
"show_label" : False,
}
@@ -190,6 +190,7 @@ def query_model_def(base_model_type, model_def):
("Apply Relighting", "1"),
],
"label": "Custom Process",
+ "letters_filter": "1",
"show_label" : False,
"scale": 1,
}
@@ -427,7 +428,7 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
ui_defaults["audio_guidance_scale"]= 1
video_prompt_type = ui_defaults.get("video_prompt_type", "")
if "I" in video_prompt_type:
- video_prompt_type = video_prompt_type.replace("KI", "QKI")
+ video_prompt_type = video_prompt_type.replace("KI", "0KI")
ui_defaults["video_prompt_type"] = video_prompt_type
if settings_version < 2.28:
@@ -461,6 +462,13 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
video_prompt_type = video_prompt_type.replace("1", "#1")
ui_defaults["video_prompt_type"] = video_prompt_type
+ if settings_version < 2.38:
+ if base_model_type in ["infinitetalk"]:
+ video_prompt_type = ui_defaults.get("video_prompt_type", "")
+ if "Q" in video_prompt_type:
+ video_prompt_type = video_prompt_type.replace("Q", "0")
+ ui_defaults["video_prompt_type"] = video_prompt_type
+
@staticmethod
def update_default_settings(base_model_type, model_def, ui_defaults):
ui_defaults.update({
@@ -491,7 +499,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
"sliding_window_overlap" : 9,
"sliding_window_size": 81,
"sample_solver" : "euler",
- "video_prompt_type": "QKI",
+ "video_prompt_type": "0KI",
"remove_background_images_ref" : 0,
"adaptive_switch" : 1,
})
diff --git a/wgp.py b/wgp.py
index 9143689b8..d03807b1f 100644
--- a/wgp.py
+++ b/wgp.py
@@ -63,8 +63,8 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.76"
-settings_version = 2.37
+WanGP_version = "8.761"
+settings_version = 2.38
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -7029,20 +7029,22 @@ def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type,
def refresh_video_prompt_type_video_custom_dropbox(state, video_prompt_type, video_prompt_type_video_custom_dropbox):
model_type = state["model_type"]
model_def = get_model_def(model_type)
- custom_video_choices = model_def.get("custom_video_selection", None)
- if not "#" in video_prompt_type or custom_video_choices is None: return gr.update()
- video_prompt_type = del_in_sequence(video_prompt_type, "0123456789")
+ custom_video_selection = model_def.get("custom_video_selection", None)
+ if not "#" in video_prompt_type or custom_video_selection is None: return gr.update()
+ letters_filter = custom_video_selection.get("letters_filter", "")
+ video_prompt_type = del_in_sequence(video_prompt_type, letters_filter)
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_custom_dropbox)
return video_prompt_type
def refresh_video_prompt_type_video_custom_checkbox(state, video_prompt_type, video_prompt_type_video_custom_checkbox):
model_type = state["model_type"]
model_def = get_model_def(model_type)
- custom_video_choices = model_def.get("custom_video_selection", None)
- if not "#" in video_prompt_type or custom_video_choices is None: return gr.update()
- video_prompt_type = del_in_sequence(video_prompt_type, "0123456789")
+ custom_video_selection = model_def.get("custom_video_selection", None)
+ if not "#" in video_prompt_type or custom_video_selection is None: return gr.update()
+ letters_filter = custom_video_selection.get("letters_filter", "")
+ video_prompt_type = del_in_sequence(video_prompt_type, letters_filter)
if video_prompt_type_video_custom_checkbox:
- video_prompt_type = add_to_sequence(video_prompt_type, custom_video_choices["choices"][1][1])
+ video_prompt_type = add_to_sequence(video_prompt_type, custom_video_selection["choices"][1][1])
return video_prompt_type
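
WanGP encodes the video-prompt configuration as a string of single-letter flags, and the two handlers above first strip the letters owned by the custom selector (its `letters_filter`) and then append the letter of the newly selected choice. A minimal sketch of that string manipulation, with simplified stand-ins for `del_in_sequence` and `add_to_sequence` (the real helpers live elsewhere in WanGP; the implementations below are assumptions for illustration only):

```python
# Simplified stand-ins for WanGP's letter-flag helpers (illustrative, not the real code).
def del_in_sequence(sequence: str, letters: str) -> str:
    """Remove every character of `letters` from `sequence`."""
    return "".join(c for c in sequence if c not in letters)

def add_to_sequence(sequence: str, letters: str) -> str:
    """Append the characters of `letters` that are not already in `sequence`."""
    return sequence + "".join(c for c in letters if c not in sequence)

# Example: a prompt type using "KI" plus the custom "sharp transitions" flag "0"
video_prompt_type = "0KI"
letters_filter = "0"     # letters owned by the custom selector
new_choice = ""          # user picks "Smooth Transitions" (empty letter)

video_prompt_type = del_in_sequence(video_prompt_type, letters_filter)
video_prompt_type = add_to_sequence(video_prompt_type, new_choice)
print(video_prompt_type)  # -> "KI"
```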
@@ -7613,10 +7615,9 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
custom_checkbox = len(custom_video_choices) <= 2
video_prompt_type_video_custom_label = custom_video_selection.get("label", "Custom Choices")
-
video_prompt_type_video_custom_dropbox = gr.Dropdown(
custom_video_choices,
- value=filter_letters(video_prompt_type_value, "0123456789", custom_video_selection.get("default", "")),
+ value=filter_letters(video_prompt_type_value, custom_video_selection.get("letters_filter", ""), custom_video_selection.get("default", "")),
scale = custom_video_selection.get("scale", 1),
label= video_prompt_type_video_custom_label , visible= not custom_checkbox and custom_choices,
show_label= custom_video_selection.get("show_label", True),
From f77ba0317f0163e925851e9f19e23da7ef6d783b Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sun, 28 Sep 2025 22:44:17 +0200
Subject: [PATCH 054/155] UI improvements
---
README.md | 2 +-
models/wan/any2video.py | 2 +-
models/wan/wan_handler.py | 27 ++++++++++++++++++++-------
wgp.py | 22 ++++++++++++----------
4 files changed, 34 insertions(+), 19 deletions(-)
diff --git a/README.md b/README.md
index 72a7b0df8..ee432d770 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
-### September 25 2025: WanGP v8.76 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release
+### September 28 2025: WanGP v8.76 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release
So in ~~today's~~ this release you will find two Wannabe Vace models, each of which covers only a subset of Vace features but offers some interesting advantages:
- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion* transfers. It does that very well. You can use this model either to *Replace* a person in a Video or to *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. Under the hood, *Wan 2.2 Animate* seems to be a derived i2v model and should support the corresponding Lora Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusive, you will find support for *Outpainting*.
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index 9f27fb30a..c9b1a29a4 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -1058,7 +1058,7 @@ def adapt_animate_model(self, model):
def get_loras_transformer(self, get_model_recursive_prop, base_model_type, model_type, video_prompt_type, model_mode, **kwargs):
if base_model_type == "animate":
- if "1" in video_prompt_type:
+ if "#" in video_prompt_type and "1" in video_prompt_type:
preloadURLs = get_model_recursive_prop(model_type, "preload_URLs")
if len(preloadURLs) > 0:
return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1]
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index d3dc31e09..b1f594271 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -147,19 +147,30 @@ def query_model_def(base_model_type, model_def):
extra_model_def["all_image_refs_are_background_ref"] = True
extra_model_def["guide_custom_choices"] = {
"choices":[
- ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "0KI"),
- ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"),
- ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "0RUV"),
- ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"),
- ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "G0UV"),
- ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"),
+ ("Images to Video, each Reference Image will start a new shot with a new Sliding Window", "KI"),
+ ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window", "RUV"),
+ ("Video to Video, amount of motion transferred depends on Denoising Strength", "GUV"),
],
"default": "KI",
- "letters_filter": "RGUV0KI",
+ "letters_filter": "RGUVKI",
"label": "Video to Video",
+ "scale": 3,
+ "show_label" : False,
+ }
+
+ extra_model_def["custom_video_selection"] = {
+ "choices":[
+ ("Smooth Transitions", ""),
+ ("Sharp Transitions", "0"),
+ ],
+ "trigger": "",
+ "label": "Custom Process",
+ "letters_filter": "0",
"show_label" : False,
+ "scale": 1,
}
+
# extra_model_def["at_least_one_image_ref_needed"] = True
if base_model_type in ["lucy_edit"]:
extra_model_def["keep_frames_video_guide_not_supported"] = True
@@ -189,7 +200,9 @@ def query_model_def(base_model_type, model_def):
("None", ""),
("Apply Relighting", "1"),
],
+ "trigger": "#",
"label": "Custom Process",
+ "type": "checkbox",
"letters_filter": "1",
"show_label" : False,
"scale": 1,
diff --git a/wgp.py b/wgp.py
index d03807b1f..e48070ca3 100644
--- a/wgp.py
+++ b/wgp.py
@@ -7017,12 +7017,13 @@ def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type,
else:
mask_selector_visible = True
ref_images_visible = "I" in video_prompt_type
- custom_options = custom_checkbox= False
- if "#" in video_prompt_type:
- custom_options = True
- custom_video_choices = model_def.get("custom_video_selection", None)
- if custom_video_choices is not None:
- custom_checkbox = len(custom_video_choices.get("choices", [])) <= 2
+ custom_options = custom_checkbox = False
+ custom_video_selection = model_def.get("custom_video_selection", None)
+ if custom_video_selection is not None:
+ custom_trigger = custom_video_selection.get("trigger","")
+ if len(custom_trigger) == 0 or custom_trigger in video_prompt_type:
+ custom_options = True
+ custom_checkbox = custom_video_selection.get("type","") == "checkbox"
return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible), gr.update(visible = ref_images_visible ), gr.update(visible= custom_options and not custom_checkbox ), gr.update(visible= custom_options and custom_checkbox )
@@ -7030,7 +7031,7 @@ def refresh_video_prompt_type_video_custom_dropbox(state, video_prompt_type, vid
model_type = state["model_type"]
model_def = get_model_def(model_type)
custom_video_selection = model_def.get("custom_video_selection", None)
- if not "#" in video_prompt_type or custom_video_selection is None: return gr.update()
+ if custom_video_selection is None: return gr.update()
letters_filter = custom_video_selection.get("letters_filter", "")
video_prompt_type = del_in_sequence(video_prompt_type, letters_filter)
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_custom_dropbox)
@@ -7040,7 +7041,7 @@ def refresh_video_prompt_type_video_custom_checkbox(state, video_prompt_type, vi
model_type = state["model_type"]
model_def = get_model_def(model_type)
custom_video_selection = model_def.get("custom_video_selection", None)
- if not "#" in video_prompt_type or custom_video_selection is None: return gr.update()
+ if custom_video_selection is None: return gr.update()
letters_filter = custom_video_selection.get("letters_filter", "")
video_prompt_type = del_in_sequence(video_prompt_type, letters_filter)
if video_prompt_type_video_custom_checkbox:
@@ -7610,9 +7611,10 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
video_prompt_type_video_custom_checkbox = gr.Checkbox(value=False, label="Custom Checkbbox", scale = 1, visible= False, show_label= True, )
else:
- custom_choices = "#" in video_prompt_type_value
custom_video_choices = custom_video_selection["choices"]
- custom_checkbox = len(custom_video_choices) <= 2
+ custom_video_trigger = custom_video_selection.get("trigger", "")
+ custom_choices = len(custom_video_trigger) == 0 or custom_video_trigger in video_prompt_type_value
+ custom_checkbox = custom_video_selection.get("type","") == "checkbox"
video_prompt_type_video_custom_label = custom_video_selection.get("label", "Custom Choices")
video_prompt_type_video_custom_dropbox = gr.Dropdown(
From a1aa82289bde197d377e76f96a421d2c76997009 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Sun, 28 Sep 2025 23:35:15 +0200
Subject: [PATCH 055/155] pinned onnx version
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index e567f3ae4..354d465a7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -32,7 +32,7 @@ loguru
opencv-python>=4.12.0.88
segment-anything
rembg[gpu]==2.0.65
-onnxruntime-gpu
+onnxruntime-gpu==1.22
decord
timm
# insightface==0.7.3
From 12e81fa84051cc9680676a378b3eeb062bf4d414 Mon Sep 17 00:00:00 2001
From: SandSeppel <65561725+SandSeppel@users.noreply.github.com>
Date: Mon, 29 Sep 2025 13:06:05 +0200
Subject: [PATCH 056/155] Added conda activate in README update
---
README.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/README.md b/README.md
index ee432d770..9ad477a8a 100644
--- a/README.md
+++ b/README.md
@@ -150,6 +150,7 @@ If using Pinokio use Pinokio to update otherwise:
Get in the directory where WanGP is installed and:
```bash
git pull
+conda activate wan2gp
pip install -r requirements.txt
```
From 73b22f8a04d1d4df27e398d2b9931279737b644e Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 30 Sep 2025 10:51:25 +0200
Subject: [PATCH 057/155] combinatorics
---
README.md | 9 +
configs/lynx.json | 15 +
configs/vace_lynx_14B.json | 17 ++
defaults/lynx.json | 18 ++
defaults/vace_lynx_14B.json | 9 +
models/wan/any2video.py | 98 ++++--
models/wan/lynx/attention_processor.py | 41 +++
models/wan/lynx/flash_attention.py | 60 ----
models/wan/lynx/full_attention_processor.py | 305 -------------------
models/wan/lynx/inference_utils.py | 39 ---
models/wan/lynx/light_attention_processor.py | 194 ------------
models/wan/lynx/model_utils_wan.py | 264 ----------------
models/wan/modules/model.py | 83 +++--
models/wan/wan_handler.py | 70 +++--
requirements.txt | 5 +-
shared/utils/utils.py | 2 +-
wgp.py | 56 ++--
17 files changed, 325 insertions(+), 960 deletions(-)
create mode 100644 configs/lynx.json
create mode 100644 configs/vace_lynx_14B.json
create mode 100644 defaults/lynx.json
create mode 100644 defaults/vace_lynx_14B.json
create mode 100644 models/wan/lynx/attention_processor.py
delete mode 100644 models/wan/lynx/flash_attention.py
delete mode 100644 models/wan/lynx/full_attention_processor.py
delete mode 100644 models/wan/lynx/inference_utils.py
delete mode 100644 models/wan/lynx/light_attention_processor.py
delete mode 100644 models/wan/lynx/model_utils_wan.py
diff --git a/README.md b/README.md
index ee432d770..6abdaf5c2 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,15 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
+### September 30 2025: WanGP v8.9 - Combinatorics
+
+This new version of WanGP introduces **Wan 2.1 Lynx**, the best ControlNet so far for transferring *Facial Identity*. You will be amazed to recognize your friends even with a completely different hair style. Congrats to the *ByteDance team* for this achievement. Lynx works quite well with *FusioniX t2v* at 10 steps.
+
+*WanGP 8.9* also illustrates how existing WanGP features can easily be combined with new models. For instance, with *Lynx* you get *Video to Video* and *Image/Text to Image* out of the box.
+
+Another fun combination is *Vace* + *Lynx*, which works much better than *Vace StandIn*. I have added sliders to change the weights of Vace & Lynx so you can tune their effects.
+
+
### September 28 2025: WanGP v8.76 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release
So in ~~today's~~ this release you will find two Wannabe Vace models, each of which covers only a subset of Vace features but offers some interesting advantages:
diff --git a/configs/lynx.json b/configs/lynx.json
new file mode 100644
index 000000000..a973d5e68
--- /dev/null
+++ b/configs/lynx.json
@@ -0,0 +1,15 @@
+{
+ "_class_name": "WanModel",
+ "_diffusers_version": "0.30.0",
+ "dim": 5120,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_dim": 16,
+ "model_type": "t2v",
+ "num_heads": 40,
+ "num_layers": 40,
+ "out_dim": 16,
+ "text_len": 512,
+ "lynx": "full"
+}
diff --git a/configs/vace_lynx_14B.json b/configs/vace_lynx_14B.json
new file mode 100644
index 000000000..770cbeef6
--- /dev/null
+++ b/configs/vace_lynx_14B.json
@@ -0,0 +1,17 @@
+{
+ "_class_name": "WanModel",
+ "_diffusers_version": "0.30.0",
+ "dim": 5120,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_dim": 16,
+ "model_type": "t2v",
+ "num_heads": 40,
+ "num_layers": 40,
+ "out_dim": 16,
+ "text_len": 512,
+ "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
+ "vace_in_dim": 96,
+ "lynx": "full"
+}
diff --git a/defaults/lynx.json b/defaults/lynx.json
new file mode 100644
index 000000000..4c0a17f1e
--- /dev/null
+++ b/defaults/lynx.json
@@ -0,0 +1,18 @@
+{
+ "model": {
+ "name": "Wan2.1 Lynx 14B",
+ "modules": [
+ [
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_lynx_full_module_14B_bf16.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_lynx_full_module_14B_quanto_bf16_int8.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_lynx_full_module_14B_quanto_fp16_int8.safetensors"
+ ]
+ ],
+ "architecture": "lynx",
+ "description": "The Lynx ControlNet offers State of the Art Identity Preservation. You need to provide a Reference Image which is a close up of a person face to transfer this person in the Video.",
+ "URLs": "t2v",
+ "preload_URLs": [
+ "wan2.1_lynx_full_arc_resampler.safetensors"
+ ]
+ }
+}
\ No newline at end of file
diff --git a/defaults/vace_lynx_14B.json b/defaults/vace_lynx_14B.json
new file mode 100644
index 000000000..ba8bfbd7b
--- /dev/null
+++ b/defaults/vace_lynx_14B.json
@@ -0,0 +1,9 @@
+{
+ "model": {
+ "name": "Vace Lynx 14B",
+ "architecture": "vace_lynx_14B",
+ "modules": [ "vace_14B", "lynx"],
+ "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video. The Lynx version is specialized in transferring a Person Identity.",
+ "URLs": "t2v"
+ }
+}
\ No newline at end of file
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index c9b1a29a4..5beea8e74 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -286,7 +286,32 @@ def get_i2v_mask(self, lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=No
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1,2)[0]
return msk
-
+
+ def encode_reference_images(self, ref_images, ref_prompt="image of a face", any_guidance= False, tile_size = None):
+ ref_images = [convert_image_to_tensor(img).unsqueeze(1).to(device=self.device, dtype=self.dtype) for img in ref_images]
+ shape = ref_images[0].shape
+ freqs = get_rotary_pos_embed( (len(ref_images) , shape[-2] // 8, shape[-1] // 8 ))
+ # batch_ref_image: [B, C, F, H, W]
+ vae_feat = self.vae.encode(ref_images, tile_size = tile_size)
+ vae_feat = torch.cat( vae_feat, dim=1).unsqueeze(0)
+ if any_guidance:
+ vae_feat_uncond = self.vae.encode([ref_images[0] * 0], tile_size = tile_size) * len(ref_images)
+ vae_feat_uncond = torch.cat( vae_feat_uncond, dim=1).unsqueeze(0)
+ context = self.text_encoder([ref_prompt], self.device)[0].to(self.dtype)
+ context = torch.cat([context, context.new_zeros(self.model.text_len -context.size(0), context.size(1)) ]).unsqueeze(0)
+ clear_caches()
+ get_cache("lynx_ref_buffer").update({ 0: {}, 1: {} })
+ ref_buffer = self.model(
+ pipeline =self,
+ x = [vae_feat, vae_feat_uncond] if any_guidance else [vae_feat],
+ context = [context, context] if any_guidance else [context],
+ freqs= freqs,
+ t=torch.stack([torch.tensor(0, dtype=torch.float)]).to(self.device),
+ lynx_feature_extractor = True,
+ )
+ clear_caches()
+ return ref_buffer[0], (ref_buffer[1] if any_guidance else None)
+
def generate(self,
input_prompt,
input_frames= None,
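
The new `encode_reference_images` helper follows a reference-buffer pattern: the face reference is encoded by the VAE and pushed through the transformer once in a feature-extraction pass (`t=0`, `lynx_feature_extractor=True`), and the activations cached per block are later consumed during the actual denoising run. A toy sketch of that two-pass caching idea (the classes and the additive injection below are illustrative assumptions; the real model attends to the cached features rather than adding them):

```python
# Minimal sketch of a reference-buffer pass (illustrative; not the WanGP classes).
import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, ref_buffer=None, store=False, idx=0):
        if store:
            ref_buffer[idx] = x.detach()  # cache the reference activations for this block
        elif ref_buffer is not None and idx in ref_buffer:
            # toy "injection": mix in a pooled summary of the cached reference features
            x = x + 0.1 * ref_buffer[idx].mean(dim=1, keepdim=True)
        return self.proj(x)

blocks = nn.ModuleList([TinyBlock(8) for _ in range(2)])
ref_buffer = {}

# 1) feature-extraction pass over the encoded reference image
h = torch.randn(1, 4, 8)          # reference tokens
for i, blk in enumerate(blocks):
    h = blk(h, ref_buffer, store=True, idx=i)

# 2) denoising pass over the video latents, consuming the cached buffer
h = torch.randn(1, 16, 8)         # latent tokens
for i, blk in enumerate(blocks):
    h = blk(h, ref_buffer, store=False, idx=i)
```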
@@ -355,6 +380,7 @@ def generate(self,
video_prompt_type= "",
original_input_ref_images = [],
face_arc_embeds = None,
+ control_scale_alt = 1.,
**bbargs
):
@@ -401,13 +427,15 @@ def generate(self,
# Text Encoder
if n_prompt == "":
n_prompt = self.sample_neg_prompt
- context = self.text_encoder([input_prompt], self.device)[0]
- context_null = self.text_encoder([n_prompt], self.device)[0]
- context = context.to(self.dtype)
- context_null = context_null.to(self.dtype)
text_len = self.model.text_len
- context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0)
- context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0)
+ any_guidance_at_all = guide_scale > 1 or guide2_scale > 1 and guide_phases >=2 or guide3_scale > 1 and guide_phases >=3
+ context = self.text_encoder([input_prompt], self.device)[0].to(self.dtype)
+ context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0)
+ if NAG_scale > 1 or any_guidance_at_all:
+ context_null = self.text_encoder([n_prompt], self.device)[0].to(self.dtype)
+ context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0)
+ else:
+ context_null = None
if input_video is not None: height, width = input_video.shape[-2:]
# NAG_prompt = "static, low resolution, blurry"
@@ -423,13 +451,13 @@ def generate(self,
# if NAG_scale > 1: context = torch.cat([context, context_NAG], dim=0)
if self._interrupt: return None
- vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B", "vace_standin_14B"]
+ vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B"]
phantom = model_type in ["phantom_1.3B", "phantom_14B"]
fantasy = model_type in ["fantasy"]
multitalk = model_type in ["multitalk", "infinitetalk", "vace_multitalk_14B", "i2v_2_2_multitalk"]
infinitetalk = model_type in ["infinitetalk"]
standin = model_type in ["standin", "vace_standin_14B"]
- lynx = model_type in ["lynx_lite", "lynx_full", "vace_lynx_lite_14B", "vace_lynx_full_14B"]
+ lynx = model_type in ["lynx_lite", "lynx", "vace_lynx_lite_14B", "vace_lynx_14B"]
recam = model_type in ["recam_1.3B"]
ti2v = model_type in ["ti2v_2_2", "lucy_edit"]
lucy_edit= model_type in ["lucy_edit"]
@@ -651,7 +679,9 @@ def generate(self,
if vace :
# vace context encode
input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)])
- input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)])
+ input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)])
+ if model_type in ["vace_lynx_14B"]:
+ input_ref_images = input_ref_masks = None
input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images]
input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks]
ref_images_before = True
@@ -704,21 +734,36 @@ def generate(self,
kwargs["freqs"] = freqs
# Lynx
- if lynx:
- from .lynx.resampler import Resampler
- from accelerate import init_empty_weights
- lynx_lite = model_type in ["lynx_lite", "vace_lynx_lite_14B"]
- with init_empty_weights():
- arc_resampler = Resampler( depth=4, dim=1280, dim_head=64, embedding_dim=512, ff_mult=4, heads=20, num_queries=16, output_dim=2048 if lynx_lite else 5120 )
- offload.load_model_data(arc_resampler, os.path.join("ckpts", "wan2.1_lynx_lite_arc_resampler.safetensors" if lynx_lite else "wan2.1_lynx_full_arc_resampler.safetensors"))
- arc_resampler.to(self.device)
- arcface_embed = face_arc_embeds[None,None,:].to(device=self.device, dtype=torch.float)
- ip_hidden_states = arc_resampler(arcface_embed).to(self.dtype)
- ip_hidden_states_uncond = arc_resampler(torch.zeros_like(arcface_embed)).to(self.dtype)
- arc_resampler = None
- gc.collect()
- torch.cuda.empty_cache()
- kwargs["lynx_ip_scale"] = 1.
+ if lynx :
+ if original_input_ref_images is None or len(original_input_ref_images) == 0:
+ lynx = False
+ else:
+ from .lynx.resampler import Resampler
+ from accelerate import init_empty_weights
+ lynx_lite = model_type in ["lynx_lite", "vace_lynx_lite_14B"]
+ ip_hidden_states = ip_hidden_states_uncond = None
+ if True:
+ with init_empty_weights():
+ arc_resampler = Resampler( depth=4, dim=1280, dim_head=64, embedding_dim=512, ff_mult=4, heads=20, num_queries=16, output_dim=2048 if lynx_lite else 5120 )
+ offload.load_model_data(arc_resampler, os.path.join("ckpts", "wan2.1_lynx_lite_arc_resampler.safetensors" if lynx_lite else "wan2.1_lynx_full_arc_resampler.safetensors"))
+ arc_resampler.to(self.device)
+ arcface_embed = face_arc_embeds[None,None,:].to(device=self.device, dtype=torch.float)
+ ip_hidden_states = arc_resampler(arcface_embed).to(self.dtype)
+ ip_hidden_states_uncond = arc_resampler(torch.zeros_like(arcface_embed)).to(self.dtype)
+ arc_resampler = None
+ if not lynx_lite:
+ standin_ref_pos = -1
+ image_ref = original_input_ref_images[standin_ref_pos]
+ from preprocessing.face_preprocessor import FaceProcessor
+ face_processor = FaceProcessor()
+ lynx_ref = face_processor.process(image_ref, resize_to = 256 )
+ lynx_ref_buffer, lynx_ref_buffer_uncond = self.encode_reference_images([lynx_ref], tile_size=VAE_tile_size, any_guidance= any_guidance_at_all)
+ lynx_ref = None
+ gc.collect()
+ torch.cuda.empty_cache()
+ vace_lynx = model_type in ["vace_lynx_14B"]
+ kwargs["lynx_ip_scale"] = control_scale_alt
+ kwargs["lynx_ref_scale"] = control_scale_alt
#Standin
if standin:
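
To summarize the Lynx conditioning path assembled above: a face crop from `FaceProcessor` yields an ArcFace embedding (512-dim, per the resampler's `embedding_dim=512`), which a `Resampler` turns into 16 identity tokens (`num_queries=16`, output dim 2048 for the lite module, 5120 for the full one); a conditional and an all-zeros unconditional set are then passed to the model as `lynx_ip_embeds`. A shape-level sketch with a toy stand-in for the resampler (the random vector stands in for a real ArcFace embedding):

```python
# Shape-level sketch of the Lynx identity-token path (the resampler here is a toy
# stand-in for models/wan/lynx/resampler.Resampler; shapes follow the diff's arguments).
import torch
import torch.nn as nn

class ToyResampler(nn.Module):
    def __init__(self, embedding_dim=512, num_queries=16, output_dim=5120):
        super().__init__()
        self.num_queries, self.output_dim = num_queries, output_dim
        self.proj = nn.Linear(embedding_dim, num_queries * output_dim)

    def forward(self, x):  # x: [B, 1, 512] ArcFace embedding
        return self.proj(x).view(x.shape[0], self.num_queries, self.output_dim)

arcface_embed = torch.randn(1, 1, 512)  # would come from the face recognition model
resampler = ToyResampler()
ip_hidden_states = resampler(arcface_embed)                            # [1, 16, 5120] identity tokens
ip_hidden_states_uncond = resampler(torch.zeros_like(arcface_embed))   # unconditional branch
print(ip_hidden_states.shape, ip_hidden_states_uncond.shape)
```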
@@ -881,6 +926,9 @@ def clear():
"context" : [context, context_null],
"lynx_ip_embeds": [ip_hidden_states, ip_hidden_states_uncond]
}
+ if model_type in ["lynx", "vace_lynx_14B"]:
+ gen_args["lynx_ref_buffer"] = [lynx_ref_buffer, lynx_ref_buffer_uncond]
+
elif multitalk and audio_proj != None:
if guide_scale == 1:
gen_args = {
diff --git a/models/wan/lynx/attention_processor.py b/models/wan/lynx/attention_processor.py
new file mode 100644
index 000000000..c64291d1d
--- /dev/null
+++ b/models/wan/lynx/attention_processor.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2025 The Wan Team and The HuggingFace Team.
+# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+#
+# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025.
+#
+# Original file was released under Apache License 2.0, with the full license text
+# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt.
+#
+# This modified file is released under the same license.
+
+
+from typing import Optional, List
+import torch
+import torch.nn as nn
+from diffusers.models.normalization import RMSNorm
+
+def setup_lynx_attention_layers(blocks, lynx_full, dim):
+ if lynx_full:
+ lynx_cross_dim = 5120
+ lynx_layers = len(blocks)
+ else:
+ lynx_cross_dim = 2048
+ lynx_layers = 20
+ for i, block in enumerate(blocks):
+ if i < lynx_layers:
+ block.cross_attn.to_k_ip = nn.Linear(lynx_cross_dim, dim , bias=lynx_full)
+ block.cross_attn.to_v_ip = nn.Linear(lynx_cross_dim, dim , bias=lynx_full)
+ else:
+ block.cross_attn.to_k_ip = None
+ block.cross_attn.to_v_ip = None
+ if lynx_full:
+ block.cross_attn.registers = nn.Parameter(torch.randn(1, 16, lynx_cross_dim) / dim**0.5)
+ block.cross_attn.norm_rms_k = None
+ block.self_attn.to_k_ref = nn.Linear(dim, dim, bias=True)
+ block.self_attn.to_v_ref = nn.Linear(dim, dim, bias=True)
+ else:
+ block.cross_attn.registers = None
+ block.cross_attn.norm_rms_k = RMSNorm(dim, eps=1e-5, elementwise_affine=False)
+
+
+
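
The helper above only attaches the extra projection layers; the way such layers are typically consumed (and the way the processors deleted below consumed them) is as an additive IP-adapter residual on cross-attention: the identity tokens are projected with `to_k_ip`/`to_v_ip`, attended to with the same queries, and the result is added to the text cross-attention output scaled by the IP scale. A condensed sketch using PyTorch's `scaled_dot_product_attention` (a simplification of the deleted processors, with normalization and variable-length handling omitted, not a drop-in replacement):

```python
# Condensed sketch of the IP-adapter style cross-attention residual used by Lynx
# (simplified: no query/key normalization, fixed-length batches, not a drop-in processor).
import torch
import torch.nn as nn
import torch.nn.functional as F

def cross_attn_with_ip(query, text_kv, ip_tokens, to_k, to_v, to_k_ip, to_v_ip,
                       heads, ip_scale=1.0):
    def split(t):  # [B, T, H*D] -> [B, H, T, D]
        return t.unflatten(2, (heads, -1)).transpose(1, 2)

    q = split(query)  # `query` is the already-projected attention query
    base = F.scaled_dot_product_attention(q, split(to_k(text_kv)), split(to_v(text_kv)))
    ip = F.scaled_dot_product_attention(q, split(to_k_ip(ip_tokens)), split(to_v_ip(ip_tokens)))
    out = base + ip_scale * ip               # additive identity residual
    return out.transpose(1, 2).flatten(2, 3)  # back to [B, T, H*D]

# Toy usage with made-up dimensions
dim, heads, ip_dim = 64, 4, 64
to_k, to_v = nn.Linear(dim, dim), nn.Linear(dim, dim)
to_k_ip, to_v_ip = nn.Linear(ip_dim, dim), nn.Linear(ip_dim, dim)
q = torch.randn(1, 10, dim); txt = torch.randn(1, 7, dim); ip = torch.randn(1, 16, ip_dim)
print(cross_attn_with_ip(q, txt, ip, to_k, to_v, to_k_ip, to_v_ip, heads).shape)
```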
diff --git a/models/wan/lynx/flash_attention.py b/models/wan/lynx/flash_attention.py
deleted file mode 100644
index 3bb807db8..000000000
--- a/models/wan/lynx/flash_attention.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
-# SPDX-License-Identifier: Apache-2.0
-
-import torch
-try:
- from flash_attn_interface import flash_attn_varlen_func # flash attn 3
-except:
- from flash_attn.flash_attn_interface import flash_attn_varlen_func # flash attn 2
-
-def flash_attention(query, key, value, q_lens, kv_lens, causal=False):
- """
- Args:
- query, key, value: [B, H, T, D_h]
- q_lens: list[int] per sequence query length
- kv_lens: list[int] per sequence key/value length
- causal: whether to use causal mask
-
- Returns:
- output: [B, H, T_q, D_h]
- """
- B, H, T_q, D_h = query.shape
- T_k = key.shape[2]
- device = query.device
-
- # Flatten: [B, H, T, D] -> [total_tokens, H, D]
- q = query.permute(0, 2, 1, 3).reshape(B * T_q, H, D_h)
- k = key.permute(0, 2, 1, 3).reshape(B * T_k, H, D_h)
- v = value.permute(0, 2, 1, 3).reshape(B * T_k, H, D_h)
-
- # Prepare cu_seqlens: prefix sum
- q_lens_tensor = torch.tensor(q_lens, device=device, dtype=torch.int32)
- kv_lens_tensor = torch.tensor(kv_lens, device=device, dtype=torch.int32)
-
- cu_seqlens_q = torch.zeros(len(q_lens_tensor) + 1, device=device, dtype=torch.int32)
- cu_seqlens_k = torch.zeros(len(kv_lens_tensor) + 1, device=device, dtype=torch.int32)
-
- cu_seqlens_q[1:] = torch.cumsum(q_lens_tensor, dim=0)
- cu_seqlens_k[1:] = torch.cumsum(kv_lens_tensor, dim=0)
-
- max_seqlen_q = int(q_lens_tensor.max().item())
- max_seqlen_k = int(kv_lens_tensor.max().item())
-
- # Call FlashAttention varlen kernel
- out = flash_attn_varlen_func(
- q,
- k,
- v,
- cu_seqlens_q,
- cu_seqlens_k,
- max_seqlen_q,
- max_seqlen_k,
- causal=causal
- )
- if not torch.is_tensor(out): # flash attn 3
- out = out[0]
-
- # Restore shape: [total_q, H, D_h] -> [B, H, T_q, D_h]
- out = out.view(B, T_q, H, D_h).permute(0, 2, 1, 3).contiguous()
-
- return out
\ No newline at end of file
diff --git a/models/wan/lynx/full_attention_processor.py b/models/wan/lynx/full_attention_processor.py
deleted file mode 100644
index fd0edb7a8..000000000
--- a/models/wan/lynx/full_attention_processor.py
+++ /dev/null
@@ -1,305 +0,0 @@
-# Copyright (c) 2025 The Wan Team and The HuggingFace Team.
-# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
-#
-# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025.
-#
-# Original file was released under Apache License 2.0, with the full license text
-# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt.
-#
-# This modified file is released under the same license.
-
-
-from typing import Optional, List
-import torch
-import torch.nn as nn
-from diffusers.models.attention_processor import Attention
-from modules.common.flash_attention import flash_attention
-from modules.common.navit_utils import vector_to_list, list_to_vector, merge_token_lists
-
-def new_forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- **cross_attention_kwargs,
-) -> torch.Tensor:
- return self.processor(
- self,
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- attention_mask=attention_mask,
- **cross_attention_kwargs,
- )
-Attention.forward = new_forward
-
-class WanIPAttnProcessor(nn.Module):
- def __init__(
- self,
- cross_attention_dim: int,
- dim: int,
- n_registers: int,
- bias: bool,
- ):
- super().__init__()
- self.to_k_ip = nn.Linear(cross_attention_dim, dim, bias=bias)
- self.to_v_ip = nn.Linear(cross_attention_dim, dim, bias=bias)
- if n_registers > 0:
- self.registers = nn.Parameter(torch.randn(1, n_registers, cross_attention_dim) / dim**0.5)
- else:
- self.registers = None
-
- def forward(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- ip_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- rotary_emb: Optional[torch.Tensor] = None,
- q_lens: List[int] = None,
- kv_lens: List[int] = None,
- ip_lens: List[int] = None,
- ip_scale: float = 1.0,
- ):
- encoder_hidden_states_img = None
- if attn.add_k_proj is not None:
- encoder_hidden_states_img = encoder_hidden_states[:, :257]
- encoder_hidden_states = encoder_hidden_states[:, 257:]
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
-
- query = attn.to_q(hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- if ip_hidden_states is not None:
-
- if self.registers is not None:
- ip_hidden_states_list = vector_to_list(ip_hidden_states, ip_lens, 1)
- ip_hidden_states_list = merge_token_lists(ip_hidden_states_list, [self.registers] * len(ip_hidden_states_list), 1)
- ip_hidden_states, ip_lens = list_to_vector(ip_hidden_states_list, 1)
-
- ip_query = query
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
- if attn.norm_q is not None:
- ip_query = attn.norm_q(ip_query)
- if attn.norm_k is not None:
- ip_key = attn.norm_k(ip_key)
- ip_query = ip_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- ip_key = ip_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- ip_value = ip_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- ip_hidden_states = flash_attention(
- ip_query, ip_key, ip_value,
- q_lens=q_lens,
- kv_lens=ip_lens
- )
- ip_hidden_states = ip_hidden_states.transpose(1, 2).flatten(2, 3)
- ip_hidden_states = ip_hidden_states.type_as(ip_query)
-
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
- if rotary_emb is not None:
-
- def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
- x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
- x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
- return x_out.type_as(hidden_states)
-
- query = apply_rotary_emb(query, rotary_emb)
- key = apply_rotary_emb(key, rotary_emb)
-
- hidden_states = flash_attention(
- query, key, value,
- q_lens=q_lens,
- kv_lens=kv_lens,
- )
-
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
- hidden_states = hidden_states.type_as(query)
-
- if ip_hidden_states is not None:
- hidden_states = hidden_states + ip_scale * ip_hidden_states
-
- hidden_states = attn.to_out[0](hidden_states)
- hidden_states = attn.to_out[1](hidden_states)
- return hidden_states
-
-def register_ip_adapter(
- model,
- cross_attention_dim=None,
- n_registers=0,
- init_method="zero",
- dtype=torch.float32,
-):
- attn_procs = {}
- transformer_sd = model.state_dict()
- for layer_idx, block in enumerate(model.blocks):
- name = f"blocks.{layer_idx}.attn2.processor"
- layer_name = name.split(".processor")[0]
- dim = transformer_sd[layer_name + ".to_k.weight"].shape[1]
- attn_procs[name] = WanIPAttnProcessor(
- cross_attention_dim=dim if cross_attention_dim is None else cross_attention_dim,
- dim=dim,
- n_registers=n_registers,
- bias=True,
- )
- if init_method == "zero":
- torch.nn.init.zeros_(attn_procs[name].to_k_ip.weight)
- torch.nn.init.zeros_(attn_procs[name].to_k_ip.bias)
- torch.nn.init.zeros_(attn_procs[name].to_v_ip.weight)
- torch.nn.init.zeros_(attn_procs[name].to_v_ip.bias)
- elif init_method == "clone":
- layer_name = name.split(".processor")[0]
- weights = {
- "to_k_ip.weight": transformer_sd[layer_name + ".to_k.weight"],
- "to_k_ip.bias": transformer_sd[layer_name + ".to_k.bias"],
- "to_v_ip.weight": transformer_sd[layer_name + ".to_v.weight"],
- "to_v_ip.bias": transformer_sd[layer_name + ".to_v.bias"],
- }
- attn_procs[name].load_state_dict(weights)
- elif init_method == "random":
- pass
- else:
- raise ValueError(f"{init_method} is not supported.")
- block.attn2.processor = attn_procs[name]
- ip_layers = torch.nn.ModuleList(attn_procs.values())
- ip_layers.to(device=model.device, dtype=dtype)
- return model, ip_layers
-
-
-class WanRefAttnProcessor(nn.Module):
- def __init__(
- self,
- dim: int,
- bias: bool,
- ):
- super().__init__()
- self.to_k_ref = nn.Linear(dim, dim, bias=bias)
- self.to_v_ref = nn.Linear(dim, dim, bias=bias)
-
- def forward(
- self,
- attn: Attention,
- hidden_states: torch.Tensor,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- rotary_emb: Optional[torch.Tensor] = None,
- q_lens: List[int] = None,
- kv_lens: List[int] = None,
- ref_feature: Optional[tuple] = None,
- ref_scale: float = 1.0,
- ):
- encoder_hidden_states_img = None
- if attn.add_k_proj is not None:
- encoder_hidden_states_img = encoder_hidden_states[:, :257]
- encoder_hidden_states = encoder_hidden_states[:, 257:]
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
-
- query = attn.to_q(hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- if ref_feature is None:
- ref_hidden_states = None
- else:
- ref_hidden_states, ref_lens = ref_feature
- ref_query = query
- ref_key = self.to_k_ref(ref_hidden_states)
- ref_value = self.to_v_ref(ref_hidden_states)
- if attn.norm_q is not None:
- ref_query = attn.norm_q(ref_query)
- if attn.norm_k is not None:
- ref_key = attn.norm_k(ref_key)
- ref_query = ref_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- ref_key = ref_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- ref_value = ref_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- ref_hidden_states = flash_attention(
- ref_query, ref_key, ref_value,
- q_lens=q_lens,
- kv_lens=ref_lens
- )
- ref_hidden_states = ref_hidden_states.transpose(1, 2).flatten(2, 3)
- ref_hidden_states = ref_hidden_states.type_as(ref_query)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
- if rotary_emb is not None:
-
- def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
- x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
- x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
- return x_out.type_as(hidden_states)
-
- query = apply_rotary_emb(query, rotary_emb)
- key = apply_rotary_emb(key, rotary_emb)
-
- hidden_states = flash_attention(
- query, key, value,
- q_lens=q_lens,
- kv_lens=kv_lens,
- )
-
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
- hidden_states = hidden_states.type_as(query)
-
- if ref_hidden_states is not None:
- hidden_states = hidden_states + ref_scale * ref_hidden_states
-
- hidden_states = attn.to_out[0](hidden_states)
- hidden_states = attn.to_out[1](hidden_states)
- return hidden_states
-
-
-def register_ref_adapter(
- model,
- init_method="zero",
- dtype=torch.float32,
-):
- attn_procs = {}
- transformer_sd = model.state_dict()
- for layer_idx, block in enumerate(model.blocks):
- name = f"blocks.{layer_idx}.attn1.processor"
- layer_name = name.split(".processor")[0]
- dim = transformer_sd[layer_name + ".to_k.weight"].shape[0]
- attn_procs[name] = WanRefAttnProcessor(
- dim=dim,
- bias=True,
- )
- if init_method == "zero":
- torch.nn.init.zeros_(attn_procs[name].to_k_ref.weight)
- torch.nn.init.zeros_(attn_procs[name].to_k_ref.bias)
- torch.nn.init.zeros_(attn_procs[name].to_v_ref.weight)
- torch.nn.init.zeros_(attn_procs[name].to_v_ref.bias)
- elif init_method == "clone":
- weights = {
- "to_k_ref.weight": transformer_sd[layer_name + ".to_k.weight"],
- "to_k_ref.bias": transformer_sd[layer_name + ".to_k.bias"],
- "to_v_ref.weight": transformer_sd[layer_name + ".to_v.weight"],
- "to_v_ref.bias": transformer_sd[layer_name + ".to_v.bias"],
- }
- attn_procs[name].load_state_dict(weights)
- elif init_method == "random":
- pass
- else:
- raise ValueError(f"{init_method} is not supported.")
- block.attn1.processor = attn_procs[name]
- ref_layers = torch.nn.ModuleList(attn_procs.values())
- ref_layers.to(device=model.device, dtype=dtype)
- return model, ref_layers
diff --git a/models/wan/lynx/inference_utils.py b/models/wan/lynx/inference_utils.py
deleted file mode 100644
index c1129f3bf..000000000
--- a/models/wan/lynx/inference_utils.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
-# SPDX-License-Identifier: Apache-2.0
-
-from typing import Optional, Union
-
-import torch
-import numpy as np
-
-from PIL import Image
-from dataclasses import dataclass, field
-
-
-dtype_mapping = {
- "bf16": torch.bfloat16,
- "fp16": torch.float16,
- "fp32": torch.float32
-}
-
-
-@dataclass
-class SubjectInfo:
- name: str = ""
- image_pil: Optional[Image.Image] = None
- landmarks: Optional[Union[np.ndarray, torch.Tensor]] = None
- face_embeds: Optional[Union[np.ndarray, torch.Tensor]] = None
-
-
-@dataclass
-class VideoStyleInfo: # key names should match those used in style.yaml file
- style_name: str = 'none'
- num_frames: int = 81
- seed: int = -1
- guidance_scale: float = 5.0
- guidance_scale_i: float = 2.0
- num_inference_steps: int = 50
- width: int = 832
- height: int = 480
- prompt: str = ''
- negative_prompt: str = ''
\ No newline at end of file
diff --git a/models/wan/lynx/light_attention_processor.py b/models/wan/lynx/light_attention_processor.py
deleted file mode 100644
index a3c7516b1..000000000
--- a/models/wan/lynx/light_attention_processor.py
+++ /dev/null
@@ -1,194 +0,0 @@
-# Copyright (c) 2025 The Wan Team and The HuggingFace Team.
-# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
-#
-# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025.
-#
-# Original file was released under Apache License 2.0, with the full license text
-# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt.
-#
-# This modified file is released under the same license.
-
-from typing import Tuple, Union, Optional
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from diffusers.models.normalization import RMSNorm
-
-
-def register_ip_adapter_wan(
- model,
- hidden_size=5120,
- cross_attention_dim=2048,
- dtype=torch.float32,
- init_method="zero",
- layers=None
-):
- attn_procs = {}
- transformer_sd = model.state_dict()
-
- if layers is None:
- layers = list(range(0, len(model.blocks)))
- elif isinstance(layers, int): # Only interval provided
- layers = list(range(0, len(model.blocks), layers))
-
- for i, block in enumerate(model.blocks):
- if i not in layers:
- continue
-
- name = f"blocks.{i}.attn2.processor"
- attn_procs[name] = IPAWanAttnProcessor2_0(
- hidden_size=hidden_size,
- cross_attention_dim=cross_attention_dim
- )
-
- if init_method == "zero":
- torch.nn.init.zeros_(attn_procs[name].to_k_ip.weight)
- torch.nn.init.zeros_(attn_procs[name].to_v_ip.weight)
- elif init_method == "clone":
- layer_name = name.split(".processor")[0]
- weights = {
- "to_k_ip.weight": transformer_sd[layer_name + ".to_k.weight"],
- "to_v_ip.weight": transformer_sd[layer_name + ".to_v.weight"],
- }
- attn_procs[name].load_state_dict(weights)
- else:
- raise ValueError(f"{init_method} is not supported.")
-
- block.attn2.processor = attn_procs[name]
-
- ip_layers = torch.nn.ModuleList(attn_procs.values())
- ip_layers.to(model.device, dtype=dtype)
-
- return model, ip_layers
-
-
-class IPAWanAttnProcessor2_0(torch.nn.Module):
- def __init__(
- self,
- hidden_size,
- cross_attention_dim=None,
- scale=1.0,
- bias=False,
- ):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("IPAWanAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
- self.scale = scale
-
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=bias)
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=bias)
-
- torch.nn.init.zeros_(self.to_k_ip.weight)
- torch.nn.init.zeros_(self.to_v_ip.weight)
- if bias:
- torch.nn.init.zeros_(self.to_k_ip.bias)
- torch.nn.init.zeros_(self.to_v_ip.bias)
-
- self.norm_rms_k = RMSNorm(hidden_size, eps=1e-5, elementwise_affine=False)
-
-
- def __call__(
- self,
- attn,
- hidden_states: torch.Tensor,
- image_embed: torch.Tensor = None,
- ip_scale: float = None,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.Tensor:
- encoder_hidden_states_img = None
- if attn.add_k_proj is not None:
- encoder_hidden_states_img = encoder_hidden_states[:, :257]
- encoder_hidden_states = encoder_hidden_states[:, 257:]
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
-
- query = attn.to_q(hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
-
- # =============================================================
- batch_size = image_embed.size(0)
-
- ip_hidden_states = image_embed
- ip_query = query # attn.to_q(hidden_states.clone())
- ip_key = self.to_k_ip(ip_hidden_states)
- ip_value = self.to_v_ip(ip_hidden_states)
-
- if attn.norm_q is not None:
- ip_query = attn.norm_q(ip_query)
- ip_key = self.norm_rms_k(ip_key)
-
- ip_inner_dim = ip_key.shape[-1]
- ip_head_dim = ip_inner_dim // attn.heads
-
- ip_query = ip_query.view(batch_size, -1, attn.heads, ip_head_dim).transpose(1, 2)
- ip_key = ip_key.view(batch_size, -1, attn.heads, ip_head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, ip_head_dim).transpose(1, 2)
-
- ip_hidden_states = F.scaled_dot_product_attention(
- ip_query, ip_key, ip_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * ip_head_dim)
- ip_hidden_states = ip_hidden_states.to(ip_query.dtype)
- # ===========================================================================
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
- if rotary_emb is not None:
-
- def apply_rotary_emb_inner(hidden_states: torch.Tensor, freqs: torch.Tensor):
- x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
- x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
- return x_out.type_as(hidden_states)
-
- query = apply_rotary_emb_inner(query, rotary_emb)
- key = apply_rotary_emb_inner(key, rotary_emb)
-
- # I2V task
- hidden_states_img = None
- if encoder_hidden_states_img is not None:
- key_img = attn.add_k_proj(encoder_hidden_states_img)
- key_img = attn.norm_added_k(key_img)
- value_img = attn.add_v_proj(encoder_hidden_states_img)
-
- key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-
- hidden_states_img = F.scaled_dot_product_attention(
- query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
- )
- hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
- hidden_states_img = hidden_states_img.type_as(query)
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
- hidden_states = hidden_states.type_as(query)
-
- if hidden_states_img is not None:
- hidden_states = hidden_states + hidden_states_img
-
- # Add IPA residual
- ip_scale = ip_scale or self.scale
- hidden_states = hidden_states + ip_scale * ip_hidden_states
-
- hidden_states = attn.to_out[0](hidden_states)
- hidden_states = attn.to_out[1](hidden_states)
-
- return hidden_states
\ No newline at end of file
diff --git a/models/wan/lynx/model_utils_wan.py b/models/wan/lynx/model_utils_wan.py
deleted file mode 100644
index 2fd117046..000000000
--- a/models/wan/lynx/model_utils_wan.py
+++ /dev/null
@@ -1,264 +0,0 @@
-# Copyright (c) 2025 The Wan Team and The HuggingFace Team.
-# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
-# SPDX-License-Identifier: Apache-2.0
-#
-# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025.
-#
-# Original file was released under Apache License 2.0, with the full license text
-# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt.
-#
-# This modified file is released under the same license.
-
-import os
-import math
-import torch
-
-from typing import List, Optional, Tuple, Union
-
-from diffusers.models.embeddings import get_3d_rotary_pos_embed
-from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-
-from diffusers import FlowMatchEulerDiscreteScheduler
-
-""" Utility functions for Wan2.1-T2I model """
-
-
-def cal_mean_and_std(vae, device, dtype):
- # https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py#L483
-
- # latents_mean
- latents_mean = torch.tensor(vae.config.latents_mean).view(
- 1, vae.config.z_dim, 1, 1, 1).to(device, dtype)
-
- # latents_std
- latents_std = 1.0 / torch.tensor(vae.config.latents_std).view(
- 1, vae.config.z_dim, 1, 1, 1).to(device, dtype)
-
- return latents_mean, latents_std
-
-
-def encode_video(vae, video):
- # Input video is formatted as [B, F, C, H, W]
- video = video.to(vae.device, dtype=vae.dtype)
- video = video.unsqueeze(0) if len(video.shape) == 4 else video
-
- # VAE takes [B, C, F, H, W]
- video = video.permute(0, 2, 1, 3, 4)
- latent_dist = vae.encode(video).latent_dist
-
- mean, std = cal_mean_and_std(vae, vae.device, vae.dtype)
-
- # NOTE: this implementation is weird but I still borrow it from the official
- # codes. A standard normailization should be (x - mean) / std, but the std
- # here is computed by `1 / vae.config.latents_std` from cal_mean_and_std().
-
- # Sample latent with scaling
- video_latent = (latent_dist.sample() - mean) * std
-
- # Return as [B, C, F, H, W] as required by Wan
- return video_latent
-
-
-def decode_latent(vae, latent):
- # Input latent is formatted as [B, C, F, H, W]
- latent = latent.to(vae.device, dtype=vae.dtype)
- latent = latent.unsqueeze(0) if len(latent.shape) == 4 else latent
-
- mean, std = cal_mean_and_std(vae, latent.device, latent.dtype)
- latent = latent / std + mean
-
- # VAE takes [B, C, F, H, W]
- frames = vae.decode(latent).sample
-
- # Return as [B, F, C, H, W]
- frames = frames.permute(0, 2, 1, 3, 4).float()
-
- return frames
-
-
-def encode_prompt(
- tokenizer,
- text_encoder,
- prompt: Union[str, List[str]],
- num_videos_per_prompt: int = 1,
- max_sequence_length: int = 226,
- device: Optional[torch.device] = None,
- dtype: Optional[torch.dtype] = None,
- text_input_ids=None,
- prompt_attention_mask=None,
-):
- prompt = [prompt] if isinstance(prompt, str) else prompt
- batch_size = len(prompt)
-
- if tokenizer is not None:
- text_inputs = tokenizer(
- prompt,
- padding="max_length",
- max_length=max_sequence_length,
- truncation=True,
- add_special_tokens=True,
- return_tensors="pt",
- )
- text_input_ids = text_inputs.input_ids
- prompt_attention_mask = text_inputs.attention_mask
- else:
- if text_input_ids is None:
- raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.")
-
- seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long()
- prompt_embeds = text_encoder(text_input_ids.to(device), prompt_attention_mask.to(device)).last_hidden_state
-
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
- prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
- prompt_embeds = torch.stack(
- [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
- )
-
- # duplicate text embeddings for each generation per prompt, using mps friendly method
- _, seq_len, _ = prompt_embeds.shape
- prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
-
- return prompt_embeds, prompt_attention_mask
-
-
-def compute_prompt_embeddings(
- tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False
-):
- if requires_grad:
- prompt_embeds = encode_prompt(
- tokenizer,
- text_encoder,
- prompt,
- num_videos_per_prompt=1,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
- )
- else:
- with torch.no_grad():
- prompt_embeds = encode_prompt(
- tokenizer,
- text_encoder,
- prompt,
- num_videos_per_prompt=1,
- max_sequence_length=max_sequence_length,
- device=device,
- dtype=dtype,
- )
- return prompt_embeds[0]
-
-
-def prepare_rotary_positional_embeddings(
- height: int,
- width: int,
- num_frames: int,
- vae_scale_factor_spatial: int = 8,
- patch_size: int = 2,
- attention_head_dim: int = 64,
- device: Optional[torch.device] = None,
- base_height: int = 480,
- base_width: int = 720,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- grid_height = height // (vae_scale_factor_spatial * patch_size)
- grid_width = width // (vae_scale_factor_spatial * patch_size)
- base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
- base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
-
- grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
- freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
- embed_dim=attention_head_dim,
- crops_coords=grid_crops_coords,
- grid_size=(grid_height, grid_width),
- temporal_size=num_frames,
- )
-
- freqs_cos = freqs_cos.to(device=device)
- freqs_sin = freqs_sin.to(device=device)
- return freqs_cos, freqs_sin
-
-
-def prepare_sigmas(
- scheduler: FlowMatchEulerDiscreteScheduler,
- sigmas: torch.Tensor,
- batch_size: int,
- num_train_timesteps: int,
- flow_weighting_scheme: str = "none",
- flow_logit_mean: float = 0.0,
- flow_logit_std: float = 1.0,
- flow_mode_scale: float = 1.29,
- device: torch.device = torch.device("cpu"),
- generator: Optional[torch.Generator] = None,
-) -> torch.Tensor:
- if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
- weights = compute_density_for_timestep_sampling(
- weighting_scheme=flow_weighting_scheme,
- batch_size=batch_size,
- logit_mean=flow_logit_mean,
- logit_std=flow_logit_std,
- mode_scale=flow_mode_scale,
- device=device,
- generator=generator,
- )
- indices = (weights * num_train_timesteps).long()
- else:
- raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
-
- return sigmas[indices]
-
-
-def compute_density_for_timestep_sampling(
- weighting_scheme: str,
- batch_size: int,
- logit_mean: float = None,
- logit_std: float = None,
- mode_scale: float = None,
- device: torch.device = torch.device("cpu"),
- generator: Optional[torch.Generator] = None,
-) -> torch.Tensor:
- r"""
- Compute the density for sampling the timesteps when doing SD3 training.
-
- Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
-
- SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
- """
- if weighting_scheme == "logit_normal":
- # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
- u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
- u = torch.nn.functional.sigmoid(u)
- elif weighting_scheme == "mode":
- u = torch.rand(size=(batch_size,), device=device, generator=generator)
- u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
- else:
- u = torch.rand(size=(batch_size,), device=device, generator=generator)
- return u
-
-
-def expand_tensor_dims(tensor: torch.Tensor, ndim: int) -> torch.Tensor:
- assert len(tensor.shape) <= ndim
- return tensor.reshape(tensor.shape + (1,) * (ndim - len(tensor.shape)))
-
-
-def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
- r"""Forward process of flow matching."""
- return (1.0 - t) * x0 + t * n
-
-
-def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
- r"""Loss target for flow matching."""
- return n - x0
-
-
-def prepare_loss_weights(
- scheduler: FlowMatchEulerDiscreteScheduler,
- sigmas: Optional[torch.Tensor] = None,
- flow_weighting_scheme: str = "none",
-
-) -> torch.Tensor:
-
- assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)
-
- from diffusers.training_utils import compute_loss_weighting_for_sd3
-
- return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme)
\ No newline at end of file
diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py
index 07e3e54e9..30f6172f0 100644
--- a/models/wan/modules/model.py
+++ b/models/wan/modules/model.py
@@ -131,7 +131,7 @@ def __init__(self, dim, eps=1e-5):
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
- def forward(self, x):
+ def forward(self, x, in_place= True):
r"""
Args:
x(Tensor): Shape [B, L, C]
@@ -141,7 +141,10 @@ def forward(self, x):
y = y.mean(dim=-1, keepdim=True)
y += self.eps
y.rsqrt_()
- x *= y
+ if in_place:
+ x *= y
+ else:
+ x = x * y
x *= self.weight
return x
# return self._norm(x).type_as(x) * self.weight
@@ -274,7 +277,7 @@ def text_cross_attention(self, xlist, context, return_q = False):
else:
return x, None
- def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0, standin_phase =-1):
+ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0, standin_phase =-1, lynx_ref_buffer = None, lynx_ref_scale = 0, sub_x_no=0):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@@ -287,7 +290,21 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
- q, k, v = self.q(x), self.k(x), self.v(x)
+ q = self.q(x)
+ ref_hidden_states = None
+ if not lynx_ref_buffer is None:
+ lynx_ref_features = lynx_ref_buffer[self.block_no]
+ if self.norm_q is not None: ref_query = self.norm_q(q, in_place = False)
+ ref_key = self.to_k_ref(lynx_ref_features)
+ ref_value = self.to_v_ref(lynx_ref_features)
+ if self.norm_k is not None: ref_key = self.norm_k(ref_key)
+ ref_query, ref_key, ref_value = ref_query.unflatten(2, (self.num_heads, -1)), ref_key.unflatten(2, (self.num_heads, -1)), ref_value.unflatten(2, (self.num_heads, -1))
+ qkv_list = [ref_query, ref_key, ref_value ]
+ del ref_query, ref_key, ref_value
+ ref_hidden_states = pay_attention(qkv_list)
+
+ k, v = self.k(x), self.v(x)
+
if standin_phase == 1:
q += self.q_loras(x)
k += self.k_loras(x)
@@ -347,6 +364,8 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks
)
del q,k,v
+ if ref_hidden_states is not None:
+ x.add_(ref_hidden_states, alpha= lynx_ref_scale)
x = x.flatten(2)
x = self.o(x)
@@ -537,6 +556,10 @@ def forward(
motion_vec = None,
lynx_ip_embeds = None,
lynx_ip_scale = 0,
+ lynx_ref_scale = 0,
+ lynx_feature_extractor = False,
+ lynx_ref_buffer = None,
+ sub_x_no =0,
):
r"""
Args:
@@ -571,6 +594,7 @@ def forward(
x_mod *= 1 + e[1]
x_mod += e[0]
x_mod = restore_latent_shape(x_mod)
+
if cam_emb != None:
cam_emb = self.cam_encoder(cam_emb)
cam_emb = cam_emb.repeat(1, 2, 1)
@@ -579,8 +603,9 @@ def forward(
x_mod += cam_emb
xlist = [x_mod.to(attention_dtype)]
+ if lynx_feature_extractor: get_cache("lynx_ref_buffer")[sub_x_no][self.block_no] = xlist[0]
del x_mod
- y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count, standin_phase= standin_phase, )
+ y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count, standin_phase= standin_phase, lynx_ref_buffer = lynx_ref_buffer, lynx_ref_scale = lynx_ref_scale, sub_x_no = sub_x_no)
y = y.to(dtype)
if cam_emb != None: y = self.projector(y)
@@ -1073,30 +1098,9 @@ def __init__(self,
block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128)
if lynx is not None:
- from ..lynx.light_attention_processor import IPAWanAttnProcessor2_0
+ from ..lynx.attention_processor import setup_lynx_attention_layers
lynx_full = lynx=="full"
- if lynx_full:
- lynx_cross_dim = 5120
- lynx_layers = len(self.blocks)
- else:
- lynx_cross_dim = 2048
- lynx_layers = 20
- for i, block in enumerate(self.blocks):
- if i < lynx_layers:
- attn = IPAWanAttnProcessor2_0(hidden_size=dim, cross_attention_dim = lynx_cross_dim,)
- block.cross_attn.to_k_ip = attn.to_k_ip
- block.cross_attn.to_v_ip = attn.to_v_ip
- else:
- block.cross_attn.to_k_ip = None
- block.cross_attn.to_v_ip = None
- if lynx_full:
- block.cross_attn.registers = nn.Parameter(torch.randn(1, 16, lynx_cross_dim) / dim**0.5)
- block.cross_attn.norm_rms_k = None
- else:
- block.cross_attn.registers = None
- block.cross_attn.norm_rms_k = attn.norm_rms_k
-
-
+ setup_lynx_attention_layers(self.blocks, lynx_full, dim)
if animate:
self.pose_patch_embedding = nn.Conv3d(
@@ -1302,7 +1306,10 @@ def forward(
pose_latents=None,
face_pixel_values=None,
lynx_ip_embeds = None,
- lynx_ip_scale = 0,
+ lynx_ip_scale = 0,
+ lynx_ref_scale = 0,
+ lynx_feature_extractor = False,
+ lynx_ref_buffer = None,
):
# patch_dtype = self.patch_embedding.weight.dtype
@@ -1389,6 +1396,8 @@ def forward(
audio_context_lens=audio_context_lens,
ref_images_count=ref_images_count,
lynx_ip_scale= lynx_ip_scale,
+ lynx_ref_scale = lynx_ref_scale,
+ lynx_feature_extractor = lynx_feature_extractor,
)
_flag_df = t.dim() == 2
@@ -1413,6 +1422,12 @@ def forward(
else:
lynx_ip_embeds_list = lynx_ip_embeds
+ if lynx_ref_buffer is None:
+ lynx_ref_buffer_list = [None] * len(x_list)
+ else:
+ lynx_ref_buffer_list = lynx_ref_buffer
+
+
if self.inject_sample_info and fps!=None:
fps = torch.tensor(fps, dtype=torch.long, device=device)
@@ -1562,11 +1577,11 @@ def forward(
continue
x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs)
else:
- for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec, lynx_ip_embeds) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list, lynx_ip_embeds_list)):
+ for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec, lynx_ip_embeds,lynx_ref_buffer) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list, lynx_ip_embeds_list,lynx_ref_buffer_list)):
if should_calc:
- x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec, lynx_ip_embeds= lynx_ip_embeds, **kwargs)
+ x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec, lynx_ip_embeds= lynx_ip_embeds, lynx_ref_buffer = lynx_ref_buffer, sub_x_no =i, **kwargs)
del x
- context = hints = audio_embedding = None
+ context = hints = None
if skips_steps_cache != None:
if joint_pass:
@@ -1590,7 +1605,9 @@ def forward(
torch.sub(x_list[0], ori_hidden_states[0], out=residual)
skips_steps_cache.previous_residual[x_id] = residual
residual, ori_hidden_states = None, None
-
+ if lynx_feature_extractor:
+ return get_cache("lynx_ref_buffer")
+
for i, x in enumerate(x_list):
if chipmunk:
x = reverse_voxel_chunk_no_padding(x.transpose(1, 2).unsqueeze(-1), x_og_shape, voxel_shape).squeeze(-1)
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index b1f594271..95a43ee18 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -18,7 +18,7 @@ def test_standin(base_model_type):
return base_model_type in ["standin", "vace_standin_14B"]
def test_lynx(base_model_type):
- return base_model_type in ["lynx_lite", "vace_lynx_lite_14B", "lynx_full", "vace_lynx_full_14B"]
+ return base_model_type in ["lynx_lite", "vace_lynx_lite_14B", "lynx", "vace_lynx_14B"]
def test_wan_5B(base_model_type):
return base_model_type in ["ti2v_2_2", "lucy_edit"]
@@ -86,7 +86,7 @@ def query_model_def(base_model_type, model_def):
extra_model_def["multitalk_class"] = test_multitalk(base_model_type)
extra_model_def["standin_class"] = test_standin(base_model_type)
extra_model_def["lynx_class"] = test_lynx(base_model_type)
- vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"]
+ vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B"]
extra_model_def["vace_class"] = vace_class
if base_model_type in ["animate"]:
@@ -140,6 +140,7 @@ def query_model_def(base_model_type, model_def):
"selection":[ "", "A"],
"visible": False
}
+ extra_model_def["v2i_switch_supported"] = True
if base_model_type in ["infinitetalk"]:
@@ -213,15 +214,6 @@ def query_model_def(base_model_type, model_def):
"visible": False
}
- # extra_model_def["image_ref_choices"] = {
- # "choices": [("None", ""),
- # ("People / Objects", "I"),
- # ("Landscape followed by People / Objects (if any)", "KI"),
- # ],
- # "visible": False,
- # "letters_filter": "KFI",
- # }
-
extra_model_def["video_guide_outpainting"] = [0,1]
extra_model_def["keep_frames_video_guide_not_supported"] = True
extra_model_def["extract_guide_from_window_start"] = True
@@ -257,17 +249,51 @@ def query_model_def(base_model_type, model_def):
extra_model_def["guide_inpaint_color"] = 127.5
extra_model_def["forced_guide_mask_inputs"] = True
extra_model_def["return_image_refs_tensor"] = True
+ extra_model_def["v2i_switch_supported"] = True
+ if base_model_type in ["vace_lynx_14B"]:
+ extra_model_def["set_video_prompt_type"]="Q"
+ extra_model_def["control_net_weight_alt_name"] = "Lynx"
+ extra_model_def["image_ref_choices"] = { "choices": [("None", ""),("Person Face", "I")], "letters_filter": "I"}
+ extra_model_def["no_background_removal"] = True
+ extra_model_def["fit_into_canvas_image_refs"] = 0
+
- if base_model_type in ["standin", "lynx_lite", "lynx_full"]:
+
+ if base_model_type in ["standin"]:
+ extra_model_def["v2i_switch_supported"] = True
+
+ if base_model_type in ["lynx_lite", "lynx"]:
extra_model_def["fit_into_canvas_image_refs"] = 0
+ extra_model_def["guide_custom_choices"] = {
+ "choices":[("Use Reference Image which is a Person Face", ""),
+ ("Video to Video guided by Text Prompt & Reference Image", "GUV"),
+ ("Video to Video on the Area of the Video Mask", "GVA")],
+ "default": "",
+ "letters_filter": "GUVA",
+ "label": "Video to Video",
+ "show_label" : False,
+ }
+
+ extra_model_def["mask_preprocessing"] = {
+ "selection":[ "", "A"],
+ "visible": False
+ }
+
extra_model_def["image_ref_choices"] = {
"choices": [
("No Reference Image", ""),
("Reference Image is a Person Face", "I"),
],
+ "visible": False,
"letters_filter":"I",
}
+ extra_model_def["set_video_prompt_type"]= "Q"
+ extra_model_def["no_background_removal"] = True
+ extra_model_def["v2i_switch_supported"] = True
+ extra_model_def["control_net_weight_alt_name"] = "Lynx"
+
+
if base_model_type in ["phantom_1.3B", "phantom_14B"]:
extra_model_def["image_ref_choices"] = {
"choices": [("Reference Image", "I")],
@@ -326,8 +352,8 @@ def query_model_def(base_model_type, model_def):
@staticmethod
def query_supported_types():
- return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B",
- "t2v_1.3B", "standin", "lynx_lite", "lynx_full", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
+ return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B",
+ "t2v_1.3B", "standin", "lynx_lite", "lynx", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B",
"recam_1.3B", "animate",
"i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"]
@@ -338,11 +364,13 @@ def query_family_maps():
models_eqv_map = {
"flf2v_720p" : "i2v",
"t2v_1.3B" : "t2v",
+ "vace_standin_14B" : "vace_14B",
+ "vace_lynx_14B" : "vace_14B",
}
models_comp_map = {
- "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_full_14B"],
- "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin", "lynx_lite", "lynx_full"],
+ "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_14B"],
+ "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin", "lynx_lite", "lynx"],
"i2v" : [ "fantasy", "multitalk", "flf2v_720p" ],
"i2v_2_2" : ["i2v_2_2_multitalk"],
"fantasy": ["multitalk"],
@@ -523,15 +551,17 @@ def update_default_settings(base_model_type, model_def, ui_defaults):
"flow_shift": 7, # 11 for 720p
"sliding_window_overlap" : 9,
"video_prompt_type": "I",
- "remove_background_images_ref" : 1,
+ "remove_background_images_ref" : 1 ,
})
- elif base_model_type in ["lynx_lite", "lynx_full"]:
+
+ elif base_model_type in ["lynx_lite", "lynx"]:
ui_defaults.update({
"guidance_scale": 5.0,
"flow_shift": 7, # 11 for 720p
"sliding_window_overlap" : 9,
- "video_prompt_type": "QI",
- "remove_background_images_ref" : 1,
+ "video_prompt_type": "I",
+ "denoising_strength": 0.8,
+ "remove_background_images_ref" : 0,
})
elif base_model_type in ["phantom_1.3B", "phantom_14B"]:
diff --git a/requirements.txt b/requirements.txt
index 354d465a7..9b2c6eeb6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -35,8 +35,9 @@ rembg[gpu]==2.0.65
onnxruntime-gpu==1.22
decord
timm
-# insightface==0.7.3
-# facexlib==0.3.0
+insightface @ https://github.com/deepbeepmeep/insightface/raw/refs/heads/master/wheels/insightface-0.7.3-cp310-cp310-win_amd64.whl ; sys_platform == "win32" and python_version == "3.10"
+insightface==0.7.3 ; sys_platform == "linux"
+facexlib==0.3.0
# Config & orchestration
omegaconf
diff --git a/shared/utils/utils.py b/shared/utils/utils.py
index 00178aad1..a03927e62 100644
--- a/shared/utils/utils.py
+++ b/shared/utils/utils.py
@@ -205,7 +205,7 @@ def get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_d
frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100)
return frame_height, frame_width
-def rgb_bw_to_rgba_mask(img, thresh=127):
+def rgb_bw_to_rgba_mask(img, thresh=127):
a = img.convert('L').point(lambda p: 255 if p > thresh else 0) # alpha
out = Image.new('RGBA', img.size, (255, 255, 255, 0)) # white, transparent
out.putalpha(a) # white where alpha=255
diff --git a/wgp.py b/wgp.py
index e48070ca3..74223a6c9 100644
--- a/wgp.py
+++ b/wgp.py
@@ -63,7 +63,7 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.761"
+WanGP_version = "8.9"
settings_version = 2.38
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -3492,7 +3492,23 @@ def check(src, cond):
if video_apg_switch is not None and video_apg_switch != 0:
values += ["on"]
labels += ["APG"]
-
+
+ control_net_weight = ""
+ if test_vace_module(video_model_type):
+ video_control_net_weight = configs.get("control_net_weight", 1)
+ if len(filter_letters(video_video_prompt_type, video_guide_processes))> 1:
+ video_control_net_weight2 = configs.get("control_net_weight2", 1)
+ control_net_weight = f"Vace #1={video_control_net_weight}, Vace #2={video_control_net_weight2}"
+ else:
+ control_net_weight = f"Vace={video_control_net_weight}"
+ control_net_weight_alt_name = model_def.get("control_net_weight_alt_name", "")
+ if len(control_net_weight_alt_name) >0:
+ if len(control_net_weight): control_net_weight += ", "
+ control_net_weight += control_net_weight_alt_name + "=" + str(configs.get("control_net_alt", 1))
+ if len(control_net_weight) > 0:
+ values += [control_net_weight]
+ labels += ["Control Net Weights"]
+
video_skip_steps_cache_type = configs.get("skip_steps_cache_type", "")
video_skip_steps_multiplier = configs.get("skip_steps_multiplier", 0)
video_skip_steps_cache_start_step_perc = configs.get("skip_steps_start_step_perc", 0)
@@ -3736,7 +3752,7 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr
alpha_mask[mask > 127] = 1
frame = frame * alpha_mask
frame = Image.fromarray(frame)
- face = face_processor.process(frame, resize_to=size, face_crop_scale = 1.2,)
+ face = face_processor.process(frame, resize_to=size)
face_list.append(face)
face_processor = None
@@ -4513,6 +4529,7 @@ def generate_video(
image_mask,
control_net_weight,
control_net_weight2,
+ control_net_weight_alt,
mask_expand,
audio_guide,
audio_guide2,
@@ -4577,6 +4594,9 @@ def remove_temp_filenames(temp_filenames_list):
model_def = get_model_def(model_type)
is_image = image_mode > 0
+ set_video_prompt_type = model_def.get("set_video_prompt_type", None)
+ if set_video_prompt_type is not None:
+ video_prompt_type = add_to_sequence(video_prompt_type, set_video_prompt_type)
if is_image:
if min_frames_if_references >= 1000:
video_length = min_frames_if_references - 1000
@@ -4728,7 +4748,7 @@ def remove_temp_filenames(temp_filenames_list):
_, _, latent_size = get_model_min_frames_and_step(model_type)
original_image_refs = image_refs
- image_refs = None if image_refs is None else [] + image_refs # work on a copy as it is going to be modified
+ image_refs = None if image_refs is None else ([] + image_refs) # work on a copy as it is going to be modified
# image_refs = None
# nb_frames_positions= 0
# Output Video Ratio Priorities:
@@ -5025,7 +5045,7 @@ def remove_temp_filenames(temp_filenames_list):
# Generic Video Preprocessing
process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None)
preprocess_type, preprocess_type2 = "raw", None
- for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")):
+ for process_num, process_letter in enumerate( filter_letters(video_prompt_type, video_guide_processes)):
if process_num == 0:
preprocess_type = process_map_video_guide.get(process_letter, "raw")
else:
@@ -5224,6 +5244,7 @@ def set_header_text(txt):
audio_scale= audio_scale,
audio_context_lens= audio_context_lens,
context_scale = context_scale,
+ control_scale_alt = control_net_weight_alt,
model_mode = model_mode,
causal_block_size = 5,
causal_attention = True,
@@ -6231,7 +6252,10 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
if not vace:
pop += ["frames_positions", "control_net_weight", "control_net_weight2"]
-
+
+ if not len(model_def.get("control_net_weight_alt_name", "")) >0:
+ pop += ["control_net_alt"]
+
if model_def.get("video_guide_outpainting", None) is None:
pop += ["video_guide_outpainting"]
@@ -6679,6 +6703,7 @@ def save_inputs(
image_mask,
control_net_weight,
control_net_weight2,
+ control_net_weight_alt,
mask_expand,
audio_guide,
audio_guide2,
@@ -6992,6 +7017,7 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t
return video_prompt_type
all_guide_processes ="PDESLCMUVBH"
+video_guide_processes = "PEDSLCMU"
def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ):
model_type = state["model_type"]
@@ -7440,13 +7466,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
phantom = base_model_type in ["phantom_1.3B", "phantom_14B"]
fantasy = base_model_type in ["fantasy"]
multitalk = model_def.get("multitalk_class", False)
- standin = model_def.get("standin_class", False)
infinitetalk = base_model_type in ["infinitetalk"]
hunyuan_t2v = "hunyuan_video_720" in model_filename
hunyuan_i2v = "hunyuan_video_i2v" in model_filename
- hunyuan_video_custom = "hunyuan_video_custom" in model_filename
- hunyuan_video_custom = base_model_type in ["hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit"]
- hunyuan_video_custom_audio = base_model_type in ["hunyuan_custom_audio"]
hunyuan_video_custom_edit = base_model_type in ["hunyuan_custom_edit"]
hunyuan_video_avatar = "hunyuan_video_avatar" in model_filename
flux = base_model_family in ["flux"]
@@ -7460,7 +7482,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
image_prompt_type_value = ""
video_prompt_type_value = ""
any_start_image = any_end_image = any_reference_image = any_image_mask = False
- v2i_switch_supported = (vace or t2v or standin) and not image_outputs
+ v2i_switch_supported = model_def.get("v2i_switch_supported", False) and not image_outputs
ti2v_2_2 = base_model_type in ["ti2v_2_2"]
gallery_height = 350
def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ):
@@ -7539,7 +7561,6 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
guide_custom_choices = model_def.get("guide_custom_choices", None)
image_ref_choices = model_def.get("image_ref_choices", None)
- # with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or recammaster or (flux or qwen ) and model_reference_image and image_mode_value >=1) as video_prompt_column:
with gr.Column(visible= guide_preprocessing is not None or mask_preprocessing is not None or guide_custom_choices is not None or image_ref_choices is not None) as video_prompt_column:
video_prompt_type_value= ui_defaults.get("video_prompt_type","")
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
@@ -7923,10 +7944,11 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl
choices= sample_solver_choices, visible= True, label= "Sampler Solver / Scheduler"
)
flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale", visible = not image_outputs)
-
- with gr.Row(visible = vace) as control_net_weights_row:
- control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace)
- control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace)
+ control_net_weight_alt_name = model_def.get("control_net_weight_alt_name", "")
+ with gr.Row(visible = vace or len(control_net_weight_alt_name) >0 ) as control_net_weights_row:
+ control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Vace Weight #1", visible=vace)
+ control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Vace Weight #2", visible=vace)
+ control_net_weight_alt = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label=control_net_weight_alt_name + " Weight", visible=len(control_net_weight_alt_name) >0)
negative_prompt = gr.Textbox(label="Negative Prompt (ignored if no Guidance that is if CFG = 1)", value=ui_defaults.get("negative_prompt", ""), visible = not (hunyuan_t2v or hunyuan_i2v or no_negative_prompt) )
with gr.Column(visible = vace or t2v or test_class_i2v(model_type)) as NAG_col:
gr.Markdown("NAG enforces Negative Prompt even if no Guidance is set (CFG = 1), set NAG Scale to > 1 to enable it")
@@ -8113,7 +8135,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai
with gr.Row():
cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)", visible = any_cfg_zero)
- with gr.Column(visible = (vace or t2v or standin) and image_outputs) as min_frames_if_references_col:
+ with gr.Column(visible = v2i_switch_supported and image_outputs) as min_frames_if_references_col:
gr.Markdown("Generating a single Frame alone may not be sufficient to preserve Reference Image Identity / Control Image Information or simply to get a good Image Quality. A workaround is to generate a short Video and keep the First Frame.")
min_frames_if_references = gr.Dropdown(
choices=[
From 69ce2860b8fb137411a04496d26e4879d30592c3 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Tue, 30 Sep 2025 10:59:04 +0200
Subject: [PATCH 058/155] updated lynx defaults
---
defaults/lynx.json | 2 +-
defaults/vace_lynx_14B.json | 3 ++-
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/defaults/lynx.json b/defaults/lynx.json
index 4c0a17f1e..528f5ef68 100644
--- a/defaults/lynx.json
+++ b/defaults/lynx.json
@@ -12,7 +12,7 @@
"description": "The Lynx ControlNet offers State of the Art Identity Preservation. You need to provide a Reference Image which is a close up of a person face to transfer this person in the Video.",
"URLs": "t2v",
"preload_URLs": [
- "wan2.1_lynx_full_arc_resampler.safetensors"
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_lynx_full_arc_resampler.safetensors"
]
}
}
\ No newline at end of file
diff --git a/defaults/vace_lynx_14B.json b/defaults/vace_lynx_14B.json
index ba8bfbd7b..523d5b3fc 100644
--- a/defaults/vace_lynx_14B.json
+++ b/defaults/vace_lynx_14B.json
@@ -4,6 +4,7 @@
"architecture": "vace_lynx_14B",
"modules": [ "vace_14B", "lynx"],
"description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video. The Lynx version is specialized in transferring a Person Identity.",
- "URLs": "t2v"
+ "URLs": "t2v",
+ "preload_URLs": "lynx"
}
}
\ No newline at end of file
From 1c27d496ae6246f716b04e4e53df55590cfe7969 Mon Sep 17 00:00:00 2001
From: Gunther Schulz
Date: Tue, 30 Sep 2025 17:21:00 +0200
Subject: [PATCH 059/155] Add LCM + LTX sampler for Lightning LoRA
compatibility
- Implement LCMScheduler with RectifiedFlow (LTX) dynamics
- Combine Latent Consistency Model with rectified flow scheduling
- Optimize for 2-8 step ultra-fast inference with Lightning LoRAs
- Add proper flow matching dynamics with shift parameter support
- Update UI to show 'lcm + ltx' option in sampler dropdown
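
For reference, a minimal sketch (illustration only, not part of the patch; assumes PyTorch is
installed and the helper name `sketch_sigmas` is hypothetical) of the shifted rectified-flow
sigma schedule and the Euler-style update that the new LCMScheduler performs:

```python
import torch

def sketch_sigmas(num_steps: int = 4, shift: float = 5.0) -> torch.Tensor:
    # Linear ramp from sigma_max down to sigma_min over num_steps + 1 points
    t = torch.linspace(0, 1, num_steps + 1)
    sigma_max, sigma_min = 1.0, 0.003 / 1.002
    sigmas = sigma_min + (sigma_max - sigma_min) * (1 - t)
    # Flow-matching shift: spends more of the small step budget at high-noise timesteps
    return shift * sigmas / (1 + (shift - 1) * sigmas)

sigmas = sketch_sigmas()
# One denoising step of the rectified-flow ODE, v being the model's velocity prediction:
#   x_next = x + v * (sigma_next - sigma)
```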
---
.gitignore | 4 ++
models/wan/any2video.py | 13 ++++-
models/wan/wan_handler.py | 3 +-
shared/utils/lcm_scheduler.py | 99 +++++++++++++++++++++++++++++++++++
4 files changed, 117 insertions(+), 2 deletions(-)
create mode 100644 shared/utils/lcm_scheduler.py
diff --git a/.gitignore b/.gitignore
index a84122c61..8049651d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,3 +41,7 @@ gradio_outputs/
ckpts/
loras/
loras_i2v/
+
+settings/
+
+wgp_config.json
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index 5beea8e74..5cf1e712a 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -32,11 +32,11 @@
from .modules.posemb_layers import get_rotary_pos_embed, get_nd_rotary_pos_embed
from shared.utils.vace_preprocessor import VaceVideoProcessor
from shared.utils.basic_flowmatch import FlowMatchScheduler
+from shared.utils.lcm_scheduler import LCMScheduler
from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor, fit_image_into_canvas
from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask
from shared.utils.audio_video import save_video
from mmgp import safetensors2
-from shared.utils.audio_video import save_video
def optimized_scale(positive_flat, negative_flat):
@@ -413,6 +413,17 @@ def generate(self,
sample_scheduler,
device=self.device,
sigmas=sampling_sigmas)
+ elif sample_solver == 'lcm':
+ # LCM + LTX scheduler: Latent Consistency Model with RectifiedFlow
+ # Optimized for Lightning LoRAs with ultra-fast 2-8 step inference
+ effective_steps = min(sampling_steps, 8) # LCM works best with few steps
+ sample_scheduler = LCMScheduler(
+ num_train_timesteps=self.num_train_timesteps,
+ num_inference_steps=effective_steps,
+ shift=shift
+ )
+ sample_scheduler.set_timesteps(effective_steps, device=self.device, shift=shift)
+ timesteps = sample_scheduler.timesteps
else:
raise NotImplementedError(f"Unsupported Scheduler {sample_solver}")
original_timesteps = timesteps
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index 95a43ee18..de1212913 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -122,7 +122,8 @@ def query_model_def(base_model_type, model_def):
("unipc", "unipc"),
("euler", "euler"),
("dpm++", "dpm++"),
- ("flowmatch causvid", "causvid"), ]
+ ("flowmatch causvid", "causvid"),
+ ("lcm + ltx", "lcm"), ]
})
diff --git a/shared/utils/lcm_scheduler.py b/shared/utils/lcm_scheduler.py
new file mode 100644
index 000000000..0fa1d22a3
--- /dev/null
+++ b/shared/utils/lcm_scheduler.py
@@ -0,0 +1,99 @@
+"""
+LCM + LTX scheduler combining Latent Consistency Model with RectifiedFlow (LTX).
+Optimized for Lightning LoRA compatibility and ultra-fast inference.
+"""
+
+import torch
+import math
+from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+
+
+class LCMScheduler(SchedulerMixin):
+ """
+ LCM + LTX scheduler combining Latent Consistency Model with RectifiedFlow.
+ - LCM: Enables 2-8 step inference with consistency models
+ - LTX: Uses RectifiedFlow for better flow matching dynamics
+ Optimized for Lightning LoRAs and ultra-fast, high-quality generation.
+ """
+
+ def __init__(self, num_train_timesteps: int = 1000, num_inference_steps: int = 4, shift: float = 1.0):
+ self.num_train_timesteps = num_train_timesteps
+ self.num_inference_steps = num_inference_steps
+ self.shift = shift
+ self._step_index = None
+
+ def set_timesteps(self, num_inference_steps: int, device=None, shift: float = None, **kwargs):
+ """Set timesteps for LCM+LTX inference using RectifiedFlow approach"""
+ self.num_inference_steps = min(num_inference_steps, 8) # LCM works best with 2-8 steps
+
+ if shift is None:
+ shift = self.shift
+
+ # RectifiedFlow (LTX) approach: Use rectified flow dynamics for better sampling
+ # This creates a more optimal path through the probability flow ODE
+ t = torch.linspace(0, 1, self.num_inference_steps + 1, dtype=torch.float32)
+
+ # Apply rectified flow transformation for better dynamics
+ # This is the key LTX component - rectified flow scheduling
+ sigma_max = 1.0
+ sigma_min = 0.003 / 1.002
+
+ # Rectified flow uses a more sophisticated sigma schedule
+ # that accounts for the flow matching dynamics
+ sigmas = sigma_min + (sigma_max - sigma_min) * (1 - t)
+
+ # Apply shift for flow matching (similar to other flow-based schedulers)
+ sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+
+ self.sigmas = sigmas
+ self.timesteps = self.sigmas[:-1] * self.num_train_timesteps
+
+ if device is not None:
+ self.timesteps = self.timesteps.to(device)
+ self.sigmas = self.sigmas.to(device)
+ self._step_index = None
+
+ def step(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor, **kwargs) -> SchedulerOutput:
+ """
+ Perform LCM + LTX step combining consistency model with rectified flow.
+ - LCM: Direct consistency model prediction for fast inference
+ - LTX: RectifiedFlow dynamics for optimal probability flow path
+ """
+ if self._step_index is None:
+ self._init_step_index(timestep)
+
+ # Get current and next sigma values from RectifiedFlow schedule
+ sigma = self.sigmas[self._step_index]
+ if self._step_index + 1 < len(self.sigmas):
+ sigma_next = self.sigmas[self._step_index + 1]
+ else:
+ sigma_next = torch.zeros_like(sigma)
+
+ # LCM + LTX: Combine consistency model approach with rectified flow dynamics
+ # The model_output represents the velocity field in the rectified flow ODE
+ # LCM allows us to take larger steps while maintaining consistency
+
+ # RectifiedFlow step: x_{t+1} = x_t + v_θ(x_t, t) * (σ_next - σ)
+ # This is the core flow matching equation with LTX rectified dynamics
+ sigma_diff = (sigma_next - sigma)
+ while len(sigma_diff.shape) < len(sample.shape):
+ sigma_diff = sigma_diff.unsqueeze(-1)
+
+ # LCM consistency: The model is trained to be consistent across timesteps
+ # allowing for fewer steps while maintaining quality
+ prev_sample = sample + model_output * sigma_diff
+ self._step_index += 1
+
+ return SchedulerOutput(prev_sample=prev_sample)
+
+ def _init_step_index(self, timestep):
+ """Initialize step index based on current timestep"""
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+ indices = (self.timesteps == timestep).nonzero()
+ if len(indices) > 0:
+ self._step_index = indices[0].item()
+ else:
+ # Find closest timestep if exact match not found
+ diffs = torch.abs(self.timesteps - timestep)
+ self._step_index = torch.argmin(diffs).item()
From 7f0abc94d2d5f94019ce28d4a8582647c942c13f Mon Sep 17 00:00:00 2001
From: Gunther Schulz
Date: Wed, 1 Oct 2025 20:36:50 +0200
Subject: [PATCH 060/155] Add human-readable time format to generation time
display
- Added format_generation_time() function to display generation time with human-readable format
- Shows raw seconds with minutes/hours in parentheses when over 60 seconds
- Format: '335s (5m 35s)' for times over 1 minute, '7384s (2h 3m 4s)' for times over 1 hour
- Improves UX by making long generation times easier to understand at a glance
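
A minimal standalone sketch of the formatting rule (illustration only; `sketch_format_generation_time`
is a hypothetical name, the real helper is the format_generation_time() added in the diff below):

```python
def sketch_format_generation_time(seconds: float) -> str:
    # Raw seconds below one minute, human-readable breakdown in parentheses above it
    raw = f"{int(seconds)}s"
    if seconds < 60:
        return raw
    hours, rem = divmod(int(seconds), 3600)
    minutes, secs = divmod(rem, 60)
    human = f"{hours}h {minutes}m {secs}s" if hours else f"{minutes}m {secs}s"
    return f"{raw} ({human})"

assert sketch_format_generation_time(42.7) == "42s"
assert sketch_format_generation_time(335) == "335s (5m 35s)"
assert sketch_format_generation_time(7384) == "7384s (2h 3m 4s)"
```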
---
wgp.py | 20 +++++++++++++++++++-
1 file changed, 19 insertions(+), 1 deletion(-)
diff --git a/wgp.py b/wgp.py
index 74223a6c9..5ec0583a5 100644
--- a/wgp.py
+++ b/wgp.py
@@ -148,6 +148,24 @@ def format_time(seconds):
else:
return f"{seconds:.1f}s"
+def format_generation_time(seconds):
+ """Format generation time showing raw seconds with human-readable time in parentheses when over 60s"""
+ raw_seconds = f"{int(seconds)}s"
+
+ if seconds < 60:
+ return raw_seconds
+
+ hours = int(seconds // 3600)
+ minutes = int((seconds % 3600) // 60)
+ secs = int(seconds % 60)
+
+ if hours > 0:
+ human_readable = f"{hours}h {minutes}m {secs}s"
+ else:
+ human_readable = f"{minutes}m {secs}s"
+
+ return f"{raw_seconds} ({human_readable})"
+
def pil_to_base64_uri(pil_image, format="png", quality=75):
if pil_image is None:
return None
@@ -3459,7 +3477,7 @@ def check(src, cond):
video_num_inference_steps = configs.get("num_inference_steps", 0)
video_creation_date = str(get_file_creation_date(file_name))
if "." in video_creation_date: video_creation_date = video_creation_date[:video_creation_date.rfind(".")]
- video_generation_time = str(configs.get("generation_time", "0")) + "s"
+ video_generation_time = format_generation_time(float(configs.get("generation_time", "0")))
video_activated_loras = configs.get("activated_loras", [])
video_loras_multipliers = configs.get("loras_multipliers", "")
video_loras_multipliers = preparse_loras_multipliers(video_loras_multipliers)
From fbf99df32eefed58d66fd810dd944f23738290a8 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 2 Oct 2025 19:29:45 +0200
Subject: [PATCH 061/155] Lora Reloaded
---
README.md | 15 +
defaults/animate.json | 1 +
defaults/hunyuan_t2v_fast.json | 1 +
defaults/i2v_fusionix.json | 1 +
defaults/lucy_edit.json | 1 +
defaults/lucy_edit_fastwan.json | 1 +
defaults/ti2v_2_2.json | 1 +
defaults/ti2v_2_2_fastwan.json | 1 +
defaults/vace_14B_cocktail.json | 1 +
defaults/vace_14B_cocktail_2_2.json | 1 +
defaults/vace_14B_fusionix.json | 1 +
docs/LORAS.md | 64 ++-
loras/wan/Text2Video Cocktail - 10 Steps.json | 13 +
loras/wan/Text2Video FusioniX - 10 Steps.json | 11 +
.../Lightning v1.0 2 Phases - 4 Steps.json | 14 +
.../Lightning v1.0 3 Phases - 8 Steps.json | 18 +
.../Lightning v2025-09-28 - 4 Steps.json | 14 +
.../Text2Video FusioniX - 10 Steps.json | 11 +
loras/wan2_2_5B/FastWan 3 Steps.json | 7 +
.../wan/Image2Video FusioniX - 10 Steps.json | 10 +
loras_qwen/qwen/Lightning v1.0 - 4 Steps.json | 8 +
loras_qwen/qwen/Lightning v1.1 - 8 Steps.json | 8 +
models/wan/wan_handler.py | 2 +-
shared/attention.py | 1 +
wgp.py | 385 ++++++++++++------
25 files changed, 458 insertions(+), 133 deletions(-)
create mode 100644 loras/wan/Text2Video Cocktail - 10 Steps.json
create mode 100644 loras/wan/Text2Video FusioniX - 10 Steps.json
create mode 100644 loras/wan2_2/Lightning v1.0 2 Phases - 4 Steps.json
create mode 100644 loras/wan2_2/Lightning v1.0 3 Phases - 8 Steps.json
create mode 100644 loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json
create mode 100644 loras/wan2_2/Text2Video FusioniX - 10 Steps.json
create mode 100644 loras/wan2_2_5B/FastWan 3 Steps.json
create mode 100644 loras_i2v/wan/Image2Video FusioniX - 10 Steps.json
create mode 100644 loras_qwen/qwen/Lightning v1.0 - 4 Steps.json
create mode 100644 loras_qwen/qwen/Lightning v1.1 - 8 Steps.json
diff --git a/README.md b/README.md
index 6abdaf5c2..637b1640d 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,21 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
## 🔥 Latest Updates :
+### October 2 2025: WanGP v8.99 - One Last Thing before the Big Unknown ...
+
+This new version doesn't add any new model...
+
+...but the temptation to upgrade will be high, as it contains a few Lora-related features that may change your life:
+- **Ready to use Loras Accelerators Profiles** per type of model that you can apply on your current *Generation Settings*. The next time I recommend a *Lora Accelerator*, it will be only one click away. And best of all, the required Loras will be downloaded automatically. When you apply an *Accelerator Profile*, input fields like the *Number of Denoising Steps*, *Activated Loras* and *Loras Multipliers* (such as "1;0 0;1" ...) will be filled automatically. However, your video-specific fields will be preserved, so it will be easy to switch between Profiles to experiment.
+
+- **Embedded Loras URLs**: WanGP will now try to remember every Lora URL it sees. For instance, if someone sends you some settings that contain Lora URLs, or you extract the settings of a video generated by a friend with Lora URLs, these URLs will be automatically added to the *WanGP URL Cache*. Conversely, everything you share (Videos, Settings, Lset files) will contain the download URLs if they are known. You can also download a Lora directly in WanGP by using the *Download Lora* button at the bottom. The Lora will be immediately available and added to the WanGP Lora URL cache. This works with *Hugging Face* as a repository. Support for CivitAi will come as soon as someone is nice enough to post a GitHub PR ...
+
+- **.lset files** support embedded Loras URLs. It has never been easier to share a Lora with a friend. As a reminder, a .lset file can be created directly from the *WanGP Web Interface* and it contains a list of Loras and their multipliers, a Prompt, and instructions on how to use these Loras (like the Lora's *Trigger*). So with embedded Loras URLs, you can send an .lset file by email or share it on Discord: it is just a tiny 1 KB text file, but with it other people will be able to use gigabytes of Loras, as these will be downloaded automatically.
+
+I have created the new Discord channel **share-your-settings** where you can post your *Settings* or *Lset files*. I will be pleased to add new Loras Accelerators to the list of WanGP *Accelerators Profiles* if you post some good ones there.
+
+Last but not least, the Lora documentation has been updated.
+
### September 30 2025: WanGP v8.9 - Combinatorics
This new version of WanGP introduces **Wan 2.1 Lynx** the best Control Net so far to transfer *Facial Identity*. You will be amazed to recognize your friends even with a completely different hair style. Congrats to the *Byte Dance team* for this achievement. Lynx works quite with well *Fusionix t2v* 10 steps.
diff --git a/defaults/animate.json b/defaults/animate.json
index bf45f4a92..d4676721a 100644
--- a/defaults/animate.json
+++ b/defaults/animate.json
@@ -12,6 +12,7 @@
[
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_relighting_lora.safetensors"
],
+ "settings_dir": [ "wan" ],
"group": "wan2_2"
}
}
\ No newline at end of file
diff --git a/defaults/hunyuan_t2v_fast.json b/defaults/hunyuan_t2v_fast.json
index acba28e20..30adb17a0 100644
--- a/defaults/hunyuan_t2v_fast.json
+++ b/defaults/hunyuan_t2v_fast.json
@@ -3,6 +3,7 @@
"name": "Hunyuan Video FastHunyuan 720p 13B",
"architecture": "hunyuan",
"description": "Fast Hunyuan is an accelerated HunyuanVideo model. It can sample high quality videos with 6 diffusion steps.",
+ "settings_dir": [ "" ],
"URLs": [
"https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/fast_hunyuan_video_720_quanto_int8.safetensors"
],
diff --git a/defaults/i2v_fusionix.json b/defaults/i2v_fusionix.json
index 851d6cc7a..8b0a8af54 100644
--- a/defaults/i2v_fusionix.json
+++ b/defaults/i2v_fusionix.json
@@ -5,6 +5,7 @@
"architecture" : "i2v",
"description": "A powerful merged image-to-video model based on the original WAN 2.1 I2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
"URLs": "i2v",
+ "settings_dir": [ "" ],
"loras": ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"]
}
}
\ No newline at end of file
diff --git a/defaults/lucy_edit.json b/defaults/lucy_edit.json
index a8f67ad64..57d3d958a 100644
--- a/defaults/lucy_edit.json
+++ b/defaults/lucy_edit.json
@@ -8,6 +8,7 @@
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors"
],
+ "settings_dir": "ti2v_2_2",
"group": "wan2_2"
},
"prompt": "change the clothes to red",
diff --git a/defaults/lucy_edit_fastwan.json b/defaults/lucy_edit_fastwan.json
index d5d47c814..2a52166fc 100644
--- a/defaults/lucy_edit_fastwan.json
+++ b/defaults/lucy_edit_fastwan.json
@@ -5,6 +5,7 @@
"description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.",
"URLs": "lucy_edit",
"group": "wan2_2",
+ "settings_dir": [ "" ],
"loras": "ti2v_2_2_fastwan"
},
"prompt": "change the clothes to red",
diff --git a/defaults/ti2v_2_2.json b/defaults/ti2v_2_2.json
index ac329fa2d..91e906347 100644
--- a/defaults/ti2v_2_2.json
+++ b/defaults/ti2v_2_2.json
@@ -7,6 +7,7 @@
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_5B_mbf16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_5B_quanto_mbf16_int8.safetensors"
],
+ "settings_dir": [ "wan2_2_5B" ],
"group": "wan2_2"
},
"video_length": 121,
diff --git a/defaults/ti2v_2_2_fastwan.json b/defaults/ti2v_2_2_fastwan.json
index fa69f82ec..c142dcbd2 100644
--- a/defaults/ti2v_2_2_fastwan.json
+++ b/defaults/ti2v_2_2_fastwan.json
@@ -4,6 +4,7 @@
"architecture": "ti2v_2_2",
"description": "FastWan2.2-TI2V-5B-Full-Diffusers is built upon Wan-AI/Wan2.2-TI2V-5B-Diffusers. It supports efficient 3-step inference and produces high-quality videos at 121×704×1280 resolution",
"URLs": "ti2v_2_2",
+ "settings_dir": [ "" ],
"loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"],
"group": "wan2_2"
},
diff --git a/defaults/vace_14B_cocktail.json b/defaults/vace_14B_cocktail.json
index 87f2b78ad..0ab7a625e 100644
--- a/defaults/vace_14B_cocktail.json
+++ b/defaults/vace_14B_cocktail.json
@@ -7,6 +7,7 @@
],
"description": "This model has been created on the fly using the Wan text 2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. Copy the model def in the finetune folder to change the Cocktail composition.",
"URLs": "t2v",
+ "settings_dir": [ "" ],
"loras": [
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors",
diff --git a/defaults/vace_14B_cocktail_2_2.json b/defaults/vace_14B_cocktail_2_2.json
index f821a9e50..6a831d214 100644
--- a/defaults/vace_14B_cocktail_2_2.json
+++ b/defaults/vace_14B_cocktail_2_2.json
@@ -14,6 +14,7 @@
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors",
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors"
],
+ "settings_dir": [ "" ],
"loras_multipliers": [1, 0.2, 0.5, 0.5],
"group": "wan2_2"
},
diff --git a/defaults/vace_14B_fusionix.json b/defaults/vace_14B_fusionix.json
index 44c048c90..1c6fa7288 100644
--- a/defaults/vace_14B_fusionix.json
+++ b/defaults/vace_14B_fusionix.json
@@ -6,6 +6,7 @@
"vace_14B"
],
"description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.",
+ "settings_dir": [ "" ],
"URLs": "t2v_fusionix"
},
"negative_prompt": "",
diff --git a/docs/LORAS.md b/docs/LORAS.md
index 89b2e5932..dc0caaa56 100644
--- a/docs/LORAS.md
+++ b/docs/LORAS.md
@@ -47,6 +47,20 @@ python wgp.py --lora-dir-hunyuan /path/to/hunyuan/loras --lora-dir-ltxv /path/to
If you store loras in the loras folder once WanGP has been launched, click the *Refresh* button at the top so that it can become selectable.
+### Autodownload of Loras
+WanGP will try to remember where a Lora was obtained and will store the corresponding Download URL in the Generation settings that are embedded in the Generated Video. This is useful to share this information or to easily recover lost loras after a reinstall.
+
+This works very well if the Loras are stored in repositories such as *Hugging Face*, but it won't work for the moment for Loras that require a login (like *CivitAi*) to be downloaded.
+
+WanGP will update its internal Lora URL Cache whenever one of these events occurs:
+- when applying or importing an *Accelerator Profile*, *Settings* or *Lset* file that contains Loras with full URLs (not just local paths)
+- when extracting the settings of a Video that was generated with Loras and that contains the full Lora URLs
+- when manually downloading a Lora using the *Download Lora* button at the bottom
+
+So the more you use WanGP, the more the URL cache file gets updated. The file is *loras_url_cache.json* and is located in the root folder of WanGP.
+
+You can delete this file without any risk if needed, or share it with friends to save them time locating the Loras. You will need to restart WanGP if you manually modify or delete this file.
+
### Lora Multipliers
Multipliers control the strength of each lora's effect:
@@ -117,38 +131,49 @@ Loras Multipliers: 0;1;0 0;0;1
Here during the first phase with guidance 3.5, the High model will be used but there won't be any lora at all. Then during phase 2 only the High lora will be used (which requires setting the guidance to 1). Finally, in phase 3 WanGP will switch to the Low model and only the Low lora will be used.
*Note that the syntax for multipliers can also be used in a Finetune model definition file (except that each multiplier definition is a string in a json list)*
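+
+For illustration only (this is not the WanGP parser, just a sketch of the convention described above: spaces separate Loras and `;` separates phases), the multipliers string can be expanded like this:
+
+```python
+def expand_phase_multipliers(multipliers: str, num_phases: int):
+    """Turn "0;1;0 0;0;1" into [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]: one list of per-phase strengths per Lora."""
+    per_lora = []
+    for entry in multipliers.split():                  # one space-separated entry per Lora
+        phases = [float(v) for v in entry.split(";")]  # one ';'-separated value per phase
+        assert len(phases) == num_phases, "each Lora needs one value per phase"
+        per_lora.append(phases)
+    return per_lora
+
+print(expand_phase_multipliers("0;1;0 0;0;1", 3))  # [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
+```
+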
-## Lora Presets
+## Lora Presets (.lset file)
+
+Lora Presets contain all the information needed to use a Lora or a combination of Loras:
+- The full download URLs of the Loras
+- Default Loras Multipliers
+- A sample Prompt to use the Loras with their corresponding *Trigger Words* (usually as comments)
+Optionally they may contain advanced prompts with macros to automatically generate a Prompt from keywords.
-Lora Presets are combinations of loras with predefined multipliers and prompts.
+A Lora Preset is a text file of only a few kilobytes and can be easily shared between users. Don't hesitate to use this format if you have created a Lora.
### Creating Presets
1. Configure your loras and multipliers
-2. Write a prompt with comments (lines starting with #)
-3. Save as a preset with `.lset` extension
+2. Write a prompt with comment lines (starting with #) that contain usage instructions
+3. Save as a preset with the `.lset` extension by clicking the *Save* button at the top, selecting *Save Only Loras & Full Prompt* and finally clicking *Go Ahead Save it!*
-### Example Preset
+### Example Lora Preset Prompt
```
# Use the keyword "ohnvx" to trigger the lora
A ohnvx character is driving a car through the city
```
-### Using Presets
-```bash
-# Load preset on startup
-python wgp.py --lora-preset mypreset.lset
+Using a macro (see the documentation below), the user just has to enter two words and the Prompt will be generated for them:
```
+! {Person}="man" : {Object}="car"
+This {Person} is cleaning his {Object}.
+```
+
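+For reference, a *.lset* file is just a small JSON file. Below is a minimal sketch of reading one; the field names (`loras`, `loras_mult`, `prompt`) are taken from the WanGP source, but treat this as illustrative rather than the exact loader:
+
+```python
+import json
+
+def read_lset(path):
+    with open(path, "r", encoding="utf-8") as f:
+        lset = json.load(f)
+    loras = lset["loras"]                     # Lora file names or full download URLs
+    multipliers = lset.get("loras_mult", "")  # e.g. "1 0.5"
+    prompt = lset.get("prompt", "")           # sample prompt, may contain # comment lines
+    return loras, multipliers, prompt
+```
+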
-### Managing Presets
+### Managing Lora Presets (.lset files)
- Edit, save, or delete presets directly from the web interface
- Presets include comments with usage instructions
-- Share `.lset` files with other users
+- Share `.lset` files with other users (make sure the full Lora URLs are included)
+
+A *.lset* file may contain only local paths to the Loras if WanGP doesn't know where you got them. You can edit the .lset file with a text editor and replace the local paths with their URLs. If you store your Lora on Hugging Face, you can easily obtain its URL by selecting the file and clicking *Copy Download Link*.
+
+To share a *.lset* file, you will need (for the moment) to get it directly from the Lora folder where it is stored.
## Supported Formats
-WanGP supports multiple lora formats:
+WanGP supports most lora formats:
- **Safetensors** (.safetensors)
- **Replicate** format
-- **Standard PyTorch** (.pt, .pth)
+- ...
## Loras Accelerators
@@ -166,8 +191,17 @@ https://huggingface.co/DeepBeepMeep/Qwen_image/tree/main/loras_accelerators
### Setup Instructions
-1. Download the Lora
-2. Place it in your `loras/` directory if it is a t2v lora or in the `loras_i2v/` directory if it isa i2v lora
+There are three ways to set up Loras Accelerators:
+1) **Finetune with Embedded Loras Accelerators**
+Some model Finetunes such as *Vace FusioniX* or *Vace Cocktail* have the Loras Accelerators already set up in their definition, so you won't have to do anything: they will be downloaded together with the Finetune.
+
+2) **Accelerator Profiles**
+Predefined *Accelerator Profiles* can be selected using the *Settings* dropdown box at the top. The choice of Accelerators depends on the model; no accelerator will be offered if the finetune / model is already accelerated. Just click *Apply* and the Accelerator Loras will be set up in the Loras tab at the bottom. Any missing Lora will be downloaded automatically the first time you try to generate a Video. Be aware that when applying an *Accelerator Profile*, inputs such as *Activated Loras*, *Number of Inference Steps*, ... will be overwritten whereas other inputs will be preserved, so that you can easily switch between profiles (see the sketch after this list).
+
+3) **Manual Install**
+- Download the Lora
+- Place it in the Lora directory of the corresponding model
+- Configure the Loras Multipliers and CFG as described in the later sections
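+
+As a rough illustration of the partial overwrite described in 2) above (a sketch, not the actual WanGP code), applying a profile amounts to updating only the keys present in the profile file while keeping everything else:
+
+```python
+def apply_accelerator_profile(ui_settings: dict, profile: dict) -> dict:
+    """Keys found in the profile (Loras, steps, guidance...) overwrite the current
+    settings; all other inputs are preserved so profiles can be switched freely."""
+    merged = dict(ui_settings)   # keep the user's other inputs untouched
+    merged.update(profile)       # overwrite only what the profile defines
+    return merged
+```
+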
## FusioniX (or FusionX) Lora for Wan 2.1 / Wan 2.2
If you need just one Lora accelerator, use this one. It is a combination of multiple Lora accelerators (including CausVid below) and style loras. It will not only accelerate the video generation but also improve the quality. There are two versions of this lora depending on whether you use it for t2v or i2v.
diff --git a/loras/wan/Text2Video Cocktail - 10 Steps.json b/loras/wan/Text2Video Cocktail - 10 Steps.json
new file mode 100644
index 000000000..087e49c92
--- /dev/null
+++ b/loras/wan/Text2Video Cocktail - 10 Steps.json
@@ -0,0 +1,13 @@
+{
+ "activated_loras": [
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors"
+ ],
+ "loras_multipliers": "1 0.5 0.5 0.5",
+ "num_inference_steps": 10,
+ "guidance_phases": 1,
+ "guidance_scale": 1,
+ "flow_shift": 2
+}
diff --git a/loras/wan/Text2Video FusioniX - 10 Steps.json b/loras/wan/Text2Video FusioniX - 10 Steps.json
new file mode 100644
index 000000000..9fabf09eb
--- /dev/null
+++ b/loras/wan/Text2Video FusioniX - 10 Steps.json
@@ -0,0 +1,11 @@
+{
+
+ "activated_loras": [
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_T2V_14B_FusionX_LoRA.safetensors"
+ ],
+ "loras_multipliers": "1",
+ "num_inference_steps": 10,
+ "guidance_phases": 1,
+ "guidance_scale": 1,
+ "flow_shift": 2
+}
diff --git a/loras/wan2_2/Lightning v1.0 2 Phases - 4 Steps.json b/loras/wan2_2/Lightning v1.0 2 Phases - 4 Steps.json
new file mode 100644
index 000000000..e7aeedad9
--- /dev/null
+++ b/loras/wan2_2/Lightning v1.0 2 Phases - 4 Steps.json
@@ -0,0 +1,14 @@
+{
+ "num_inference_steps": 4,
+ "guidance_scale": 1,
+ "guidance2_scale": 1,
+ "model_switch_phase": 1,
+ "switch_threshold": 876,
+ "guidance_phases": 2,
+ "flow_shift": 3,
+ "loras": [
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_HIGH_fp16.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors"
+ ],
+ "loras_multipliers": "1;0 0;1"
+}
\ No newline at end of file
diff --git a/loras/wan2_2/Lightning v1.0 3 Phases - 8 Steps.json b/loras/wan2_2/Lightning v1.0 3 Phases - 8 Steps.json
new file mode 100644
index 000000000..681cf3077
--- /dev/null
+++ b/loras/wan2_2/Lightning v1.0 3 Phases - 8 Steps.json
@@ -0,0 +1,18 @@
+{
+ "num_inference_steps": 8,
+ "guidance_scale": 3.5,
+ "guidance2_scale": 1,
+ "guidance3_scale": 1,
+ "switch_threshold": 965,
+ "switch_threshold2": 800,
+ "guidance_phases": 3,
+ "model_switch_phase": 2,
+ "flow_shift": 3,
+ "sample_solver": "euler",
+ "loras": [
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_HIGH_fp16.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors"
+ ],
+ "loras_multipliers": "0;1;0 0;0;1",
+    "help": "This finetune uses the Lightning 4 steps Loras Accelerators for Wan 2.2 but extends them to 8 steps in order to insert a CFG phase before the 2 accelerated phases with no Guidance. The ultimate goal is to reduce the slow motion effect of these Loras Accelerators."
+}
\ No newline at end of file
diff --git a/loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json b/loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json
new file mode 100644
index 000000000..56790c89f
--- /dev/null
+++ b/loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json
@@ -0,0 +1,14 @@
+{
+ "num_inference_steps": 4,
+ "guidance_scale": 1,
+ "guidance2_scale": 1,
+ "switch_threshold": 876,
+ "model_switch_phase": 1,
+ "guidance_phases": 2,
+ "flow_shift": 3,
+ "loras_multipliers": "1;0 0;1",
+ "activated_loras": [
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors",
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors"
+ ]
+}
\ No newline at end of file
diff --git a/loras/wan2_2/Text2Video FusioniX - 10 Steps.json b/loras/wan2_2/Text2Video FusioniX - 10 Steps.json
new file mode 100644
index 000000000..3695bec63
--- /dev/null
+++ b/loras/wan2_2/Text2Video FusioniX - 10 Steps.json
@@ -0,0 +1,11 @@
+{
+
+ "activated_loras": [
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_T2V_14B_FusionX_LoRA.safetensors"
+ ],
+ "loras_multipliers": "1",
+ "num_inference_steps": 10,
+ "guidance_scale": 1,
+ "guidance2_scale": 1,
+ "flow_shift": 2
+}
diff --git a/loras/wan2_2_5B/FastWan 3 Steps.json b/loras/wan2_2_5B/FastWan 3 Steps.json
new file mode 100644
index 000000000..1e0383c61
--- /dev/null
+++ b/loras/wan2_2_5B/FastWan 3 Steps.json
@@ -0,0 +1,7 @@
+{
+ "activated_loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"],
+ "guidance_scale": 1,
+ "flow_shift": 3,
+ "num_inference_steps": 3,
+ "loras_multipliers": "1"
+}
\ No newline at end of file
diff --git a/loras_i2v/wan/Image2Video FusioniX - 10 Steps.json b/loras_i2v/wan/Image2Video FusioniX - 10 Steps.json
new file mode 100644
index 000000000..1fe4c14e5
--- /dev/null
+++ b/loras_i2v/wan/Image2Video FusioniX - 10 Steps.json
@@ -0,0 +1,10 @@
+{
+
+ "activated_loras": [
+ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"
+ ],
+ "loras_multipliers": "1",
+ "num_inference_steps": 10,
+ "guidance_phases": 1,
+ "guidance_scale": 1
+}
diff --git a/loras_qwen/qwen/Lightning v1.0 - 4 Steps.json b/loras_qwen/qwen/Lightning v1.0 - 4 Steps.json
new file mode 100644
index 000000000..3dd322006
--- /dev/null
+++ b/loras_qwen/qwen/Lightning v1.0 - 4 Steps.json
@@ -0,0 +1,8 @@
+{
+ "activated_loras": [
+ "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/loras_accelerators/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors"
+ ],
+ "loras_multipliers": "1",
+ "num_inference_steps": 4,
+ "guidance_scale": 1
+}
diff --git a/loras_qwen/qwen/Lightning v1.1 - 8 Steps.json b/loras_qwen/qwen/Lightning v1.1 - 8 Steps.json
new file mode 100644
index 000000000..42ae20d50
--- /dev/null
+++ b/loras_qwen/qwen/Lightning v1.1 - 8 Steps.json
@@ -0,0 +1,8 @@
+{
+ "activated_loras": [
+ "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/loras_accelerators/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors"
+ ],
+ "loras_multipliers": "1",
+    "num_inference_steps": 8,
+ "guidance_scale": 1
+}
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index de1212913..c8f845d07 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -493,7 +493,7 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults):
if settings_version < 2.32:
image_prompt_type = ui_defaults.get("image_prompt_type", "")
- if test_class_i2v(base_model_type) and len(image_prompt_type) == 0:
+ if test_class_i2v(base_model_type) and len(image_prompt_type) == 0 and "S" in model_def.get("image_prompt_types_allowed",""):
ui_defaults["image_prompt_type"] = "S"
diff --git a/shared/attention.py b/shared/attention.py
index c31087c24..da860a9b6 100644
--- a/shared/attention.py
+++ b/shared/attention.py
@@ -68,6 +68,7 @@ def sageattn2_wrapper(
return o
try:
+ # from sageattn3 import sageattn3_blackwell as sageattn3 #word0 windows version
from sageattn import sageattn_blackwell as sageattn3
except ImportError:
sageattn3 = None
diff --git a/wgp.py b/wgp.py
index 74223a6c9..9c8309be0 100644
--- a/wgp.py
+++ b/wgp.py
@@ -63,7 +63,7 @@
PROMPT_VARS_MAX = 10
target_mmgp_version = "3.6.0"
-WanGP_version = "8.9"
+WanGP_version = "8.99"
settings_version = 2.38
max_source_video_frames = 3000
prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None
@@ -360,6 +360,12 @@ def ret():
if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"):
gr.Info("Output Resolution Cropping will be not used for this Generation as it is not compatible with Video Outpainting")
+ if len(activated_loras) > 0:
+ error = check_loras_exist(model_type, activated_loras)
+ if len(error) > 0:
+ gr.Info(error)
+ return ret()
+
if len(loras_multipliers) > 0:
_, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases= guidance_phases)
if len(errors) > 0:
@@ -2089,10 +2095,10 @@ def get_transformer_dtype(model_family, transformer_dtype_policy):
def get_settings_file_name(model_type):
return os.path.join(args.settings, model_type + "_settings.json")
-def fix_settings(model_type, ui_defaults):
+def fix_settings(model_type, ui_defaults, min_settings_version = 0):
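+    # callers may pass min_settings_version to treat the stored settings as being at least that version when applying the version-dependent fixes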
if model_type is None: return
- settings_version = ui_defaults.get("settings_version", 0)
+ settings_version = max(min_settings_version, ui_defaults.get("settings_version", 0))
model_def = get_model_def(model_type)
base_model_type = get_base_model_type(model_type)
@@ -2184,7 +2190,7 @@ def fix_settings(model_type, ui_defaults):
model_handler = get_model_handler(base_model_type)
if hasattr(model_handler, "fix_settings"):
- model_handler.fix_settings(base_model_type, settings_version, model_def, ui_defaults)
+ model_handler.fix_settings(base_model_type, settings_version, model_def, ui_defaults)
def get_default_settings(model_type):
@@ -2525,6 +2531,27 @@ def download_mmaudio():
}
process_files_def(**enhancer_def)
+
+def download_file(url,filename):
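+    # Hugging Face "resolve/main" URLs are fetched with hf_hub_download (via a temporary folder when the file lives in a subfolder); any other URL falls back to urlretrieve with a progress hook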
+ if url.startswith("https://huggingface.co/") and "/resolve/main/" in url:
+ base_dir = os.path.dirname(filename)
+ url = url[len("https://huggingface.co/"):]
+ url_parts = url.split("/resolve/main/")
+ repoId = url_parts[0]
+ onefile = os.path.basename(url_parts[-1])
+ sourceFolder = os.path.dirname(url_parts[-1])
+ if len(sourceFolder) == 0:
+ hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/" if len(base_dir)==0 else base_dir)
+ else:
+ target_path = "ckpts/temp/" + sourceFolder
+ if not os.path.exists(target_path):
+ os.makedirs(target_path)
+ hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/temp/", subfolder=sourceFolder)
+ shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/" if len(base_dir)==0 else base_dir)
+ shutil.rmtree("ckpts/temp")
+ else:
+ urlretrieve(url,filename, create_progress_hook(filename))
+
download_shared_done = False
def download_models(model_filename = None, model_type= None, module_type = False, submodel_no = 1):
def computeList(filename):
@@ -2574,26 +2601,6 @@ def computeList(filename):
if model_filename is None: return
- def download_file(url,filename):
- if url.startswith("https://huggingface.co/") and "/resolve/main/" in url:
- base_dir = os.path.dirname(filename)
- url = url[len("https://huggingface.co/"):]
- url_parts = url.split("/resolve/main/")
- repoId = url_parts[0]
- onefile = os.path.basename(url_parts[-1])
- sourceFolder = os.path.dirname(url_parts[-1])
- if len(sourceFolder) == 0:
- hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/" if len(base_dir)==0 else base_dir)
- else:
- target_path = "ckpts/temp/" + sourceFolder
- if not os.path.exists(target_path):
- os.makedirs(target_path)
- hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/temp/", subfolder=sourceFolder)
- shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/" if len(base_dir)==0 else base_dir)
- shutil.rmtree("ckpts/temp")
- else:
- urlretrieve(url,filename, create_progress_hook(filename))
-
base_model_type = get_base_model_type(model_type)
model_def = get_model_def(model_type)
@@ -2653,6 +2660,33 @@ def download_file(url,filename):
def sanitize_file_name(file_name, rep =""):
return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep).replace("\n",rep).replace("\r",rep)
+def check_loras_exist(model_type, loras_choices_files, download = False, send_cmd = None):
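+    # Build an error message listing Loras that are neither on disk nor recoverable (no cached URL, or the download failed when download=True); an empty string means everything is available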
+ lora_dir = get_lora_dir(model_type)
+ missing_local_loras = []
+ missing_remote_loras = []
+ for lora_file in loras_choices_files:
+ local_path = os.path.join(lora_dir, os.path.basename(lora_file))
+ if not os.path.isfile(local_path):
+ url = loras_url_cache.get(local_path, None)
+ if url is not None:
+ if download:
+ if send_cmd is not None:
+ send_cmd("status", f'Downloading Lora {os.path.basename(lora_file)}...')
+ try:
+ download_file(url, local_path)
+ except:
+ missing_remote_loras.append(lora_file)
+ else:
+ missing_local_loras.append(lora_file)
+
+ error = ""
+ if len(missing_local_loras) > 0:
+ error += f"The following Loras files are missing or invalid: {missing_local_loras}."
+ if len(missing_remote_loras) > 0:
+ error += f"The following Loras files could not be downloaded: {missing_remote_loras}."
+
+ return error
+
def extract_preset(model_type, lset_name, loras):
loras_choices = []
loras_choices_files = []
@@ -2669,24 +2703,12 @@ def extract_preset(model_type, lset_name, loras):
if not os.path.isfile(lset_name_filename):
error = f"Preset '{lset_name}' not found "
else:
- missing_loras = []
with open(lset_name_filename, "r", encoding="utf-8") as reader:
text = reader.read()
lset = json.loads(text)
- loras_choices_files = lset["loras"]
- for lora_file in loras_choices_files:
- choice = os.path.join(lora_dir, lora_file)
- if choice not in loras:
- missing_loras.append(lora_file)
- else:
- loras_choice_no = loras.index(choice)
- loras_choices.append(str(loras_choice_no))
-
- if len(missing_loras) > 0:
- error = f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing or invalid: {missing_loras}"
-
+ loras_choices = lset["loras"]
loras_mult_choices = lset["loras_mult"]
prompt = lset.get("prompt", "")
full_prompt = lset.get("full_prompt", False)
@@ -2695,7 +2717,6 @@ def extract_preset(model_type, lset_name, loras):
def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None):
loras =[]
- loras_names = []
default_loras_choices = []
default_loras_multis_str = ""
loras_presets = []
@@ -2726,7 +2747,7 @@ def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, spl
loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, preprocess_sd=get_loras_preprocessor(transformer, model_type), split_linear_modules_map = split_linear_modules_map) #lora_multiplier,
if len(loras) > 0:
- loras_names = [ Path(lora).stem for lora in loras ]
+ loras = [ os.path.basename(lora) for lora in loras ]
if len(lora_preselected_preset) > 0:
if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")):
@@ -2735,7 +2756,7 @@ def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, spl
default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(model_type, default_lora_preset, loras)
if len(error) > 0:
print(error[:200])
- return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
+ return loras, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset
def get_transformer_model(model, submodel_no = 1):
if submodel_no > 1:
@@ -2996,7 +3017,7 @@ def apply_changes( state,
return "
Config Locked
",*[gr.update()]*4
if gen_in_progress:
return "
Unable to change config when a generation is in progress
From 4033106e458836f6215882b83fcbb8da3c1bb9a8 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Fri, 10 Oct 2025 04:46:18 +1100
Subject: [PATCH 089/155] Update README.md
---
README.md | 9 +++++----
1 file changed, 5 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index 7a5a7f788..6464a18f0 100644
--- a/README.md
+++ b/README.md
@@ -37,10 +37,11 @@ python main.py
```
## Screenshots
From d239ceb2272517779628aa8eacc075a7e52accbc Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Thu, 9 Oct 2025 22:01:00 +0200
Subject: [PATCH 090/155] fixed extract settings crash
---
shared/utils/loras_mutipliers.py | 2 +-
wgp.py | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/shared/utils/loras_mutipliers.py b/shared/utils/loras_mutipliers.py
index 5e024a976..dda805627 100644
--- a/shared/utils/loras_mutipliers.py
+++ b/shared/utils/loras_mutipliers.py
@@ -9,7 +9,7 @@ def preparse_loras_multipliers(loras_multipliers):
loras_mult_choices_list = loras_multipliers.replace("\r", "").split("\n")
loras_mult_choices_list = [multi.strip() for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")]
loras_multipliers = " ".join(loras_mult_choices_list)
- return loras_multipliers.replace("|"," ").split(" ")
+ return loras_multipliers.replace("|"," ").strip().split(" ")
def expand_slist(slists_dict, mult_no, num_inference_steps, model_switch_step, model_switch_step2 ):
def expand_one(slist, num_inference_steps):
diff --git a/wgp.py b/wgp.py
index c0b498260..3bc3e54d2 100644
--- a/wgp.py
+++ b/wgp.py
@@ -6636,13 +6636,13 @@ def use_video_settings(state, input_file_list, choice):
else:
gr.Info(f"Settings Loaded from Video with prompt '{prompt[:100]}'")
if models_compatible:
- return gr.update(), gr.update(), str(time.time())
+ return gr.update(), gr.update(), gr.update(), str(time.time())
else:
return *generate_dropdown_model_list(model_type), gr.update()
else:
gr.Info(f"No Video is Selected")
- return gr.update(), gr.update()
+ return gr.update(), gr.update(), gr.update(), gr.update()
loras_url_cache = None
def update_loras_url_cache(lora_dir, loras_selected):
if loras_selected is None: return None
From e3ad47e7a39674ae5f77c2be88cbd609d3949767 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Fri, 10 Oct 2025 07:59:47 +1100
Subject: [PATCH 091/155] add undo/redo, delete clips, project media window,
recent projects. fixed ai plugin stalling when no server present
---
videoeditor/main.py | 842 +++++++++++++++-----
videoeditor/plugins/ai_frame_joiner/main.py | 5 +-
videoeditor/undo.py | 114 +++
3 files changed, 779 insertions(+), 182 deletions(-)
create mode 100644 videoeditor/undo.py
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 717628c42..a3f155e09 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -5,45 +5,18 @@
import re
import json
import ffmpeg
+import copy
from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QHBoxLayout, QPushButton, QFileDialog, QLabel,
QScrollArea, QFrame, QProgressBar, QDialog,
- QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget)
+ QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget, QListWidget, QListWidgetItem, QMessageBox)
from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction,
- QPixmap, QImage)
+ QPixmap, QImage, QDrag)
from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread,
- pyqtSignal, QTimer, QByteArray)
+ pyqtSignal, QTimer, QByteArray, QMimeData)
from plugins import PluginManager, ManagePluginsDialog
-import requests
-import zipfile
-from tqdm import tqdm
-
-def download_ffmpeg():
- if os.name != 'nt': return
- exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe']
- if all(os.path.exists(e) for e in exes): return
- api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest'
- r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'})
- assets = r.json().get('assets', [])
- zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None)
- if not zip_asset: return
- zip_url = zip_asset['browser_download_url']
- zip_name = zip_asset['name']
- with requests.get(zip_url, stream=True) as resp:
- total = int(resp.headers.get('Content-Length', 0))
- with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar:
- for chunk in resp.iter_content(chunk_size=8192):
- f.write(chunk)
- pbar.update(len(chunk))
- with zipfile.ZipFile(zip_name) as z:
- for f in z.namelist():
- if f.endswith(tuple(exes)) and '/bin/' in f:
- z.extract(f)
- os.rename(f, os.path.basename(f))
- os.remove(zip_name)
-
-download_ffmpeg()
+from undo import UndoStack, TimelineStateChangeCommand, MoveClipsCommand
class TimelineClip:
def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec, track_index, track_type, group_id):
@@ -111,15 +84,19 @@ class TimelineWidget(QWidget):
AUDIO_TRACKS_SEPARATOR_Y = 15
split_requested = pyqtSignal(object)
+ delete_clip_requested = pyqtSignal(object)
playhead_moved = pyqtSignal(float)
split_region_requested = pyqtSignal(list)
split_all_regions_requested = pyqtSignal(list)
join_region_requested = pyqtSignal(list)
join_all_regions_requested = pyqtSignal(list)
+ delete_region_requested = pyqtSignal(list)
+ delete_all_regions_requested = pyqtSignal(list)
add_track = pyqtSignal(str)
remove_track = pyqtSignal(str)
operation_finished = pyqtSignal()
- context_menu_requested = pyqtSignal(QMenu, 'QContextMenuEvent')
+ context_menu_requested = pyqtSignal(QMenu, 'QContextMenuEvent')
+ clips_moved = pyqtSignal(list)
def __init__(self, timeline_model, settings, parent=None):
super().__init__(parent)
@@ -128,6 +105,7 @@ def __init__(self, timeline_model, settings, parent=None):
self.playhead_pos_sec = 0.0
self.setMinimumHeight(300)
self.setMouseTracking(True)
+ self.setAcceptDrops(True)
self.selection_regions = []
self.dragging_clip = None
self.dragging_linked_clip = None
@@ -135,7 +113,7 @@ def __init__(self, timeline_model, settings, parent=None):
self.creating_selection_region = False
self.dragging_selection_region = None
self.drag_start_pos = QPoint()
- self.drag_original_timeline_start = 0
+ self.drag_original_clip_states = {} # Store {'clip_id': (start_sec, track_index)}
self.selection_drag_start_sec = 0.0
self.drag_selection_start_values = None
@@ -147,6 +125,11 @@ def __init__(self, timeline_model, settings, parent=None):
self.video_tracks_y_start = 0
self.audio_tracks_y_start = 0
+
+ self.drag_over_active = False
+ self.drag_over_rect = QRectF()
+ self.drag_over_audio_rect = QRectF()
+ self.drag_has_audio = False
def sec_to_x(self, sec): return self.HEADER_WIDTH + int(sec * self.PIXELS_PER_SECOND)
def x_to_sec(self, x): return float(x - self.HEADER_WIDTH) / self.PIXELS_PER_SECOND if x > self.HEADER_WIDTH else 0.0
@@ -160,6 +143,16 @@ def paintEvent(self, event):
self.draw_timescale(painter)
self.draw_tracks_and_clips(painter)
self.draw_selections(painter)
+
+ if self.drag_over_active:
+ painter.fillRect(self.drag_over_rect, QColor(0, 255, 0, 80))
+ if self.drag_has_audio:
+ painter.fillRect(self.drag_over_audio_rect, QColor(0, 255, 0, 80))
+ painter.setPen(QColor(0, 255, 0, 150))
+ painter.drawRect(self.drag_over_rect)
+ if self.drag_has_audio:
+ painter.drawRect(self.drag_over_audio_rect)
+
self.draw_playhead(painter)
total_width = self.sec_to_x(self.timeline.get_total_duration()) + 100
@@ -368,14 +361,20 @@ def mousePressEvent(self, event: QMouseEvent):
self.dragging_playhead = False
self.dragging_selection_region = None
self.creating_selection_region = False
+ self.drag_original_clip_states.clear()
for clip in reversed(self.timeline.clips):
clip_rect = self.get_clip_rect(clip)
if clip_rect.contains(QPointF(event.pos())):
self.dragging_clip = clip
+ self.drag_original_clip_states[clip.id] = (clip.timeline_start_sec, clip.track_index)
+
self.dragging_linked_clip = next((c for c in self.timeline.clips if c.group_id == clip.group_id and c.id != clip.id), None)
+ if self.dragging_linked_clip:
+ self.drag_original_clip_states[self.dragging_linked_clip.id] = \
+ (self.dragging_linked_clip.timeline_start_sec, self.dragging_linked_clip.track_index)
+
self.drag_start_pos = event.pos()
- self.drag_original_timeline_start = clip.timeline_start_sec
break
if not self.dragging_clip:
@@ -429,25 +428,23 @@ def mouseMoveEvent(self, event: QMouseEvent):
elif self.dragging_clip:
self.highlighted_track_info = None
new_track_info = self.y_to_track_info(event.pos().y())
+
+ original_start_sec, _ = self.drag_original_clip_states[self.dragging_clip.id]
+
if new_track_info:
new_track_type, new_track_index = new_track_info
if new_track_type == self.dragging_clip.track_type:
self.dragging_clip.track_index = new_track_index
if self.dragging_linked_clip:
+ # For now, linked clips move to same-number track index
self.dragging_linked_clip.track_index = new_track_index
- if self.dragging_clip.track_type == 'video':
- if new_track_index > self.timeline.num_audio_tracks:
- self.timeline.num_audio_tracks = new_track_index
- self.highlighted_track_info = ('audio', new_track_index)
- else:
- if new_track_index > self.timeline.num_video_tracks:
- self.timeline.num_video_tracks = new_track_index
- self.highlighted_track_info = ('video', new_track_index)
+
delta_x = event.pos().x() - self.drag_start_pos.x()
time_delta = delta_x / self.PIXELS_PER_SECOND
- new_start_time = self.drag_original_timeline_start + time_delta
+ new_start_time = original_start_sec + time_delta
+ # Basic collision detection (can be improved)
for other_clip in self.timeline.clips:
if other_clip.id == self.dragging_clip.id: continue
if self.dragging_linked_clip and other_clip.id == self.dragging_linked_clip.id: continue
@@ -485,14 +482,130 @@ def mouseReleaseEvent(self, event: QMouseEvent):
self.dragging_playhead = False
if self.dragging_clip:
+ move_data = []
+ # Check main dragged clip
+ orig_start, orig_track = self.drag_original_clip_states[self.dragging_clip.id]
+ if orig_start != self.dragging_clip.timeline_start_sec or orig_track != self.dragging_clip.track_index:
+ move_data.append({
+ 'clip_id': self.dragging_clip.id,
+ 'old_start': orig_start, 'new_start': self.dragging_clip.timeline_start_sec,
+ 'old_track': orig_track, 'new_track': self.dragging_clip.track_index
+ })
+ # Check linked clip
+ if self.dragging_linked_clip:
+ orig_start_link, orig_track_link = self.drag_original_clip_states[self.dragging_linked_clip.id]
+ if orig_start_link != self.dragging_linked_clip.timeline_start_sec or orig_track_link != self.dragging_linked_clip.track_index:
+ move_data.append({
+ 'clip_id': self.dragging_linked_clip.id,
+ 'old_start': orig_start_link, 'new_start': self.dragging_linked_clip.timeline_start_sec,
+ 'old_track': orig_track_link, 'new_track': self.dragging_linked_clip.track_index
+ })
+
+ if move_data:
+ self.clips_moved.emit(move_data)
+
self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
self.highlighted_track_info = None
self.operation_finished.emit()
+
self.dragging_clip = None
self.dragging_linked_clip = None
+ self.drag_original_clip_states.clear()
self.update()
+ def dragEnterEvent(self, event):
+ if event.mimeData().hasFormat('application/x-vnd.video.filepath'):
+ event.acceptProposedAction()
+ else:
+ event.ignore()
+
+ def dragLeaveEvent(self, event):
+ self.drag_over_active = False
+ self.update()
+
+ def dragMoveEvent(self, event):
+ mime_data = event.mimeData()
+ if not mime_data.hasFormat('application/x-vnd.video.filepath'):
+ event.ignore()
+ return
+
+ event.acceptProposedAction()
+
+ json_data = json.loads(mime_data.data('application/x-vnd.video.filepath').data().decode('utf-8'))
+ duration = json_data['duration']
+ self.drag_has_audio = json_data['has_audio']
+
+ pos = event.position()
+ track_info = self.y_to_track_info(pos.y())
+ start_sec = self.x_to_sec(pos.x())
+
+ if track_info:
+ self.drag_over_active = True
+ track_type, track_index = track_info
+
+ width = int(duration * self.PIXELS_PER_SECOND)
+ x = self.sec_to_x(start_sec)
+
+ if track_type == 'video':
+ visual_index = self.timeline.num_video_tracks - track_index
+ y = self.video_tracks_y_start + visual_index * self.TRACK_HEIGHT
+ self.drag_over_rect = QRectF(x, y, width, self.TRACK_HEIGHT)
+ if self.drag_has_audio:
+ audio_y = self.audio_tracks_y_start # Default to track 1
+ self.drag_over_audio_rect = QRectF(x, audio_y, width, self.TRACK_HEIGHT)
+
+ elif track_type == 'audio':
+ visual_index = track_index - 1
+ y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT
+ self.drag_over_audio_rect = QRectF(x, y, width, self.TRACK_HEIGHT)
+ # For now, just show video on track 1 if dragging onto audio
+ video_y = self.video_tracks_y_start + (self.timeline.num_video_tracks - 1) * self.TRACK_HEIGHT
+ self.drag_over_rect = QRectF(x, video_y, width, self.TRACK_HEIGHT)
+ else:
+ self.drag_over_active = False
+
+ self.update()
+
+ def dropEvent(self, event):
+ self.drag_over_active = False
+ self.update()
+
+ mime_data = event.mimeData()
+ if not mime_data.hasFormat('application/x-vnd.video.filepath'):
+ return
+
+ json_data = json.loads(mime_data.data('application/x-vnd.video.filepath').data().decode('utf-8'))
+ file_path = json_data['path']
+ duration = json_data['duration']
+ has_audio = json_data['has_audio']
+
+ pos = event.position()
+ start_sec = self.x_to_sec(pos.x())
+ track_info = self.y_to_track_info(pos.y())
+
+ if not track_info:
+ return
+
+ drop_track_type, drop_track_index = track_info
+ video_track_idx = 1
+ audio_track_idx = 1 if has_audio else None
+
+ if drop_track_type == 'video':
+ video_track_idx = drop_track_index
+ elif drop_track_type == 'audio':
+ audio_track_idx = drop_track_index
+
+ main_window = self.window()
+ main_window._add_clip_to_timeline(
+ source_path=file_path,
+ timeline_start_sec=start_sec,
+ duration_sec=duration,
+ clip_start_sec=0,
+ video_track_index=video_track_idx,
+ audio_track_index=audio_track_idx
+ )
+
def contextMenuEvent(self, event: 'QContextMenuEvent'):
menu = QMenu(self)
@@ -502,6 +615,8 @@ def contextMenuEvent(self, event: 'QContextMenuEvent'):
split_all_action = menu.addAction("Split All Regions")
join_this_action = menu.addAction("Join This Region")
join_all_action = menu.addAction("Join All Regions")
+ delete_this_action = menu.addAction("Delete This Region")
+ delete_all_action = menu.addAction("Delete All Regions")
menu.addSeparator()
clear_this_action = menu.addAction("Clear This Region")
clear_all_action = menu.addAction("Clear All Regions")
@@ -509,6 +624,8 @@ def contextMenuEvent(self, event: 'QContextMenuEvent'):
split_all_action.triggered.connect(lambda: self.split_all_regions_requested.emit(self.selection_regions))
join_this_action.triggered.connect(lambda: self.join_region_requested.emit(region_at_pos))
join_all_action.triggered.connect(lambda: self.join_all_regions_requested.emit(self.selection_regions))
+ delete_this_action.triggered.connect(lambda: self.delete_region_requested.emit(region_at_pos))
+ delete_all_action.triggered.connect(lambda: self.delete_all_regions_requested.emit(self.selection_regions))
clear_this_action.triggered.connect(lambda: self.clear_region(region_at_pos))
clear_all_action.triggered.connect(self.clear_all_regions)
@@ -521,10 +638,12 @@ def contextMenuEvent(self, event: 'QContextMenuEvent'):
if clip_at_pos:
if not menu.isEmpty(): menu.addSeparator()
split_action = menu.addAction("Split Clip")
+ delete_action = menu.addAction("Delete Clip")
playhead_time = self.playhead_pos_sec
is_playhead_over_clip = (clip_at_pos.timeline_start_sec < playhead_time < clip_at_pos.timeline_end_sec)
split_action.setEnabled(is_playhead_over_clip)
split_action.triggered.connect(lambda: self.split_requested.emit(clip_at_pos))
+ delete_action.triggered.connect(lambda: self.delete_clip_requested.emit(clip_at_pos))
self.context_menu_requested.emit(menu, event)
@@ -540,22 +659,111 @@ def clear_all_regions(self):
self.selection_regions.clear()
self.update()
+
class SettingsDialog(QDialog):
def __init__(self, parent_settings, parent=None):
super().__init__(parent)
self.setWindowTitle("Settings")
self.setMinimumWidth(350)
layout = QVBoxLayout(self)
- self.seek_anywhere_checkbox = QCheckBox("Allow seeking by clicking anywhere on the timeline")
- self.seek_anywhere_checkbox.setChecked(parent_settings.get("allow_seek_anywhere", False))
- layout.addWidget(self.seek_anywhere_checkbox)
+ self.confirm_on_exit_checkbox = QCheckBox("Confirm before exiting")
+ self.confirm_on_exit_checkbox.setChecked(parent_settings.get("confirm_on_exit", True))
+ layout.addWidget(self.confirm_on_exit_checkbox)
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
button_box.accepted.connect(self.accept)
button_box.rejected.connect(self.reject)
layout.addWidget(button_box)
def get_settings(self):
- return {"allow_seek_anywhere": self.seek_anywhere_checkbox.isChecked()}
+ return {
+ "confirm_on_exit": self.confirm_on_exit_checkbox.isChecked()
+ }
+
+class MediaListWidget(QListWidget):
+ def __init__(self, main_window, parent=None):
+ super().__init__(parent)
+ self.main_window = main_window
+ self.setDragEnabled(True)
+ self.setAcceptDrops(False)
+ self.setSelectionMode(QListWidget.SelectionMode.ExtendedSelection)
+
+ def startDrag(self, supportedActions):
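+        # Package the selected item's path, duration and audio flag as JSON under the custom MIME type 'application/x-vnd.video.filepath' so the timeline widget can position the drop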
+ drag = QDrag(self)
+ mime_data = QMimeData()
+
+ item = self.currentItem()
+ if not item: return
+
+ path = item.data(Qt.ItemDataRole.UserRole)
+ media_info = self.main_window.media_properties.get(path)
+ if not media_info: return
+
+ payload = {
+ "path": path,
+ "duration": media_info['duration'],
+ "has_audio": media_info['has_audio']
+ }
+
+ mime_data.setData('application/x-vnd.video.filepath', QByteArray(json.dumps(payload).encode('utf-8')))
+ drag.setMimeData(mime_data)
+ drag.exec(Qt.DropAction.CopyAction)
+
+class ProjectMediaWidget(QWidget):
+ media_removed = pyqtSignal(str)
+ add_media_requested = pyqtSignal()
+ add_to_timeline_requested = pyqtSignal(str)
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.main_window = parent
+ layout = QVBoxLayout(self)
+ layout.setContentsMargins(2, 2, 2, 2)
+
+ self.media_list = MediaListWidget(self.main_window, self)
+ self.media_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu)
+ self.media_list.customContextMenuRequested.connect(self.show_context_menu)
+ layout.addWidget(self.media_list)
+
+ button_layout = QHBoxLayout()
+ add_button = QPushButton("Add")
+ remove_button = QPushButton("Remove")
+ button_layout.addWidget(add_button)
+ button_layout.addWidget(remove_button)
+ layout.addLayout(button_layout)
+
+ add_button.clicked.connect(self.add_media_requested.emit)
+ remove_button.clicked.connect(self.remove_selected_media)
+ def show_context_menu(self, pos):
+ item = self.media_list.itemAt(pos)
+ if not item:
+ return
+
+ menu = QMenu()
+ add_action = menu.addAction("Add to timeline at playhead")
+
+ action = menu.exec(self.media_list.mapToGlobal(pos))
+
+ if action == add_action:
+ file_path = item.data(Qt.ItemDataRole.UserRole)
+ self.add_to_timeline_requested.emit(file_path)
+
+ def add_media_item(self, file_path):
+ if not any(self.media_list.item(i).data(Qt.ItemDataRole.UserRole) == file_path for i in range(self.media_list.count())):
+ item = QListWidgetItem(os.path.basename(file_path))
+ item.setData(Qt.ItemDataRole.UserRole, file_path)
+ self.media_list.addItem(item)
+
+ def remove_selected_media(self):
+ selected_items = self.media_list.selectedItems()
+ if not selected_items: return
+
+ for item in selected_items:
+ file_path = item.data(Qt.ItemDataRole.UserRole)
+ self.media_removed.emit(file_path)
+ self.media_list.takeItem(self.media_list.row(item))
+
+ def clear_list(self):
+ self.media_list.clear()
class MainWindow(QMainWindow):
def __init__(self, project_to_load=None):
@@ -565,6 +773,9 @@ def __init__(self, project_to_load=None):
self.setDockOptions(QMainWindow.DockOption.AnimatedDocks | QMainWindow.DockOption.AllowNestedDocks)
self.timeline = Timeline()
+ self.undo_stack = UndoStack()
+ self.media_pool = []
+ self.media_properties = {} # To store duration, has_audio etc.
self.export_thread = None
self.export_worker = None
self.current_project_path = None
@@ -583,6 +794,30 @@ def __init__(self, project_to_load=None):
self.playback_process = None
self.playback_clip = None
+ self._setup_ui()
+ self._connect_signals()
+
+ self.plugin_manager.load_enabled_plugins_from_settings(self.settings.get("enabled_plugins", []))
+ self._apply_loaded_settings()
+ self.seek_preview(0)
+
+ if not self.settings_file_was_loaded: self._save_settings()
+ if project_to_load: QTimer.singleShot(100, lambda: self._load_project_from_path(project_to_load))
+
+ def _get_current_timeline_state(self):
+ """Helper to capture a snapshot of the timeline for undo/redo."""
+ return (
+ copy.deepcopy(self.timeline.clips),
+ self.timeline.num_video_tracks,
+ self.timeline.num_audio_tracks
+ )
+
+ def _setup_ui(self):
+ self.media_dock = QDockWidget("Project Media", self)
+ self.project_media_widget = ProjectMediaWidget(self)
+ self.media_dock.setWidget(self.project_media_widget)
+ self.addDockWidget(Qt.DockWidgetArea.LeftDockWidgetArea, self.media_dock)
+
self.splitter = QSplitter(Qt.Orientation.Vertical)
self.preview_widget = QLabel()
@@ -592,14 +827,14 @@ def __init__(self, project_to_load=None):
self.preview_widget.setStyleSheet("background-color: black; color: white;")
self.splitter.addWidget(self.preview_widget)
- self.timeline_widget = TimelineWidget(self.timeline, self.settings)
+ self.timeline_widget = TimelineWidget(self.timeline, self.settings, self)
self.timeline_scroll_area = QScrollArea()
self.timeline_scroll_area.setWidgetResizable(True)
self.timeline_scroll_area.setWidget(self.timeline_widget)
self.timeline_scroll_area.setFrameShape(QFrame.Shape.NoFrame)
self.timeline_scroll_area.setMinimumHeight(250)
self.splitter.addWidget(self.timeline_scroll_area)
-
+
container_widget = QWidget()
main_layout = QVBoxLayout(container_widget)
main_layout.setContentsMargins(0,0,0,0)
@@ -634,7 +869,8 @@ def __init__(self, project_to_load=None):
self.managed_widgets = {
'preview': {'widget': self.preview_widget, 'name': 'Video Preview', 'action': None},
- 'timeline': {'widget': self.timeline_scroll_area, 'name': 'Timeline', 'action': None}
+ 'timeline': {'widget': self.timeline_scroll_area, 'name': 'Timeline', 'action': None},
+ 'project_media': {'widget': self.media_dock, 'name': 'Project Media', 'action': None}
}
self.plugin_menu_actions = {}
self.windows_menu = None
@@ -643,14 +879,20 @@ def __init__(self, project_to_load=None):
self.splitter_save_timer = QTimer(self)
self.splitter_save_timer.setSingleShot(True)
self.splitter_save_timer.timeout.connect(self._save_settings)
+
+ def _connect_signals(self):
self.splitter.splitterMoved.connect(self.on_splitter_moved)
self.timeline_widget.split_requested.connect(self.split_clip_at_playhead)
+ self.timeline_widget.delete_clip_requested.connect(self.delete_clip)
self.timeline_widget.playhead_moved.connect(self.seek_preview)
+ self.timeline_widget.clips_moved.connect(self.on_clips_moved)
self.timeline_widget.split_region_requested.connect(self.on_split_region)
self.timeline_widget.split_all_regions_requested.connect(self.on_split_all_regions)
self.timeline_widget.join_region_requested.connect(self.on_join_region)
self.timeline_widget.join_all_regions_requested.connect(self.on_join_all_regions)
+ self.timeline_widget.delete_region_requested.connect(self.on_delete_region)
+ self.timeline_widget.delete_all_regions_requested.connect(self.on_delete_all_regions)
self.timeline_widget.add_track.connect(self.add_track)
self.timeline_widget.remove_track.connect(self.remove_track)
self.timeline_widget.operation_finished.connect(self.prune_empty_tracks)
@@ -661,16 +903,57 @@ def __init__(self, project_to_load=None):
self.frame_forward_button.clicked.connect(lambda: self.step_frame(1))
self.playback_timer.timeout.connect(self.advance_playback_frame)
- self.plugin_manager.load_enabled_plugins_from_settings(self.settings.get("enabled_plugins", []))
- self._apply_loaded_settings()
- self.seek_preview(0)
+ self.project_media_widget.add_media_requested.connect(self.add_video_clip)
+ self.project_media_widget.media_removed.connect(self.on_media_removed_from_pool)
+ self.project_media_widget.add_to_timeline_requested.connect(self.on_add_to_timeline_at_playhead)
- if not self.settings_file_was_loaded: self._save_settings()
- if project_to_load: QTimer.singleShot(100, lambda: self._load_project_from_path(project_to_load))
+ self.undo_stack.history_changed.connect(self.update_undo_redo_actions)
+ self.undo_stack.timeline_changed.connect(self.on_timeline_changed_by_undo)
+
+ def on_timeline_changed_by_undo(self):
+ """Called when undo/redo modifies the timeline."""
+ self.prune_empty_tracks()
+ self.timeline_widget.update()
+ # You may want to update other UI elements here as well
+ self.status_label.setText("Operation undone/redone.")
+
+ def update_undo_redo_actions(self):
+ self.undo_action.setEnabled(self.undo_stack.can_undo())
+ self.undo_action.setText(f"Undo {self.undo_stack.undo_text()}" if self.undo_stack.can_undo() else "Undo")
+
+ self.redo_action.setEnabled(self.undo_stack.can_redo())
+ self.redo_action.setText(f"Redo {self.undo_stack.redo_text()}" if self.undo_stack.can_redo() else "Redo")
+
+ def on_clips_moved(self, move_data):
+ command = MoveClipsCommand("Move Clip", self.timeline, move_data)
+ # The change is already applied visually, so we push without redoing
+ self.undo_stack.undo_stack.append(command)
+ self.undo_stack.redo_stack.clear()
+ self.undo_stack.history_changed.emit()
+ self.status_label.setText("Clip moved.")
+
+ def on_add_to_timeline_at_playhead(self, file_path):
+ media_info = self.media_properties.get(file_path)
+ if not media_info:
+ self.status_label.setText(f"Error: Could not find properties for {os.path.basename(file_path)}")
+ return
+
+ playhead_pos = self.timeline_widget.playhead_pos_sec
+ duration = media_info['duration']
+ has_audio = media_info['has_audio']
+
+ self._add_clip_to_timeline(
+ source_path=file_path,
+ timeline_start_sec=playhead_pos,
+ duration_sec=duration,
+ clip_start_sec=0.0,
+ video_track_index=1,
+ audio_track_index=1 if has_audio else None
+ )
def prune_empty_tracks(self):
pruned_something = False
-
+ # Prune Video Tracks
while self.timeline.num_video_tracks > 1:
highest_track_index = self.timeline.num_video_tracks
is_track_occupied = any(c for c in self.timeline.clips
@@ -680,7 +963,8 @@ def prune_empty_tracks(self):
else:
self.timeline.num_video_tracks -= 1
pruned_something = True
-
+
+ # Prune Audio Tracks
while self.timeline.num_audio_tracks > 1:
highest_track_index = self.timeline.num_audio_tracks
is_track_occupied = any(c for c in self.timeline.clips
@@ -690,42 +974,64 @@ def prune_empty_tracks(self):
else:
self.timeline.num_audio_tracks -= 1
pruned_something = True
-
+
if pruned_something:
self.timeline_widget.update()
-
def add_track(self, track_type):
+ old_state = self._get_current_timeline_state()
if track_type == 'video':
self.timeline.num_video_tracks += 1
elif track_type == 'audio':
self.timeline.num_audio_tracks += 1
- self.timeline_widget.update()
+ new_state = self._get_current_timeline_state()
+
+ command = TimelineStateChangeCommand(f"Add {track_type.capitalize()} Track", self.timeline, *old_state, *new_state)
+        # Manually undo the direct change, because push() will redo() it.
+        command.undo()
+        self.undo_stack.push(command)
def remove_track(self, track_type):
+ old_state = self._get_current_timeline_state()
if track_type == 'video' and self.timeline.num_video_tracks > 1:
self.timeline.num_video_tracks -= 1
elif track_type == 'audio' and self.timeline.num_audio_tracks > 1:
self.timeline.num_audio_tracks -= 1
- self.timeline_widget.update()
+ else:
+ return # No change made
+ new_state = self._get_current_timeline_state()
+
+ command = TimelineStateChangeCommand(f"Remove {track_type.capitalize()} Track", self.timeline, *old_state, *new_state)
+ command.undo() # Invert state because we already made the change
+ self.undo_stack.push(command)
+
def _create_menu_bar(self):
menu_bar = self.menuBar()
file_menu = menu_bar.addMenu("&File")
new_action = QAction("&New Project", self); new_action.triggered.connect(self.new_project)
open_action = QAction("&Open Project...", self); open_action.triggered.connect(self.open_project)
+ self.recent_menu = file_menu.addMenu("Recent")
save_action = QAction("&Save Project As...", self); save_action.triggered.connect(self.save_project_as)
- add_video_action = QAction("&Add Video...", self); add_video_action.triggered.connect(self.add_video_clip)
+ add_video_action = QAction("&Add Media to Project...", self); add_video_action.triggered.connect(self.add_video_clip)
export_action = QAction("&Export Video...", self); export_action.triggered.connect(self.export_video)
settings_action = QAction("Se&ttings...", self); settings_action.triggered.connect(self.open_settings_dialog)
exit_action = QAction("E&xit", self); exit_action.triggered.connect(self.close)
- file_menu.addAction(new_action); file_menu.addAction(open_action); file_menu.addAction(save_action)
+ file_menu.addAction(new_action); file_menu.addAction(open_action); file_menu.addSeparator()
+ file_menu.addAction(save_action)
file_menu.addSeparator(); file_menu.addAction(add_video_action); file_menu.addAction(export_action)
file_menu.addSeparator(); file_menu.addAction(settings_action); file_menu.addSeparator(); file_menu.addAction(exit_action)
-
+ self._update_recent_files_menu()
+
edit_menu = menu_bar.addMenu("&Edit")
+ self.undo_action = QAction("Undo", self); self.undo_action.setShortcut("Ctrl+Z"); self.undo_action.triggered.connect(self.undo_stack.undo)
+ self.redo_action = QAction("Redo", self); self.redo_action.setShortcut("Ctrl+Y"); self.redo_action.triggered.connect(self.undo_stack.redo)
+ edit_menu.addAction(self.undo_action); edit_menu.addAction(self.redo_action); edit_menu.addSeparator()
+
split_action = QAction("Split Clip at Playhead", self); split_action.triggered.connect(self.split_clip_at_playhead)
edit_menu.addAction(split_action)
+ self.update_undo_redo_actions() # Set initial state
plugins_menu = menu_bar.addMenu("&Plugins")
for name, data in self.plugin_manager.plugins.items():
@@ -742,8 +1048,14 @@ def _create_menu_bar(self):
self.windows_menu = menu_bar.addMenu("&Windows")
for key, data in self.managed_widgets.items():
+ if data['widget'] is self.preview_widget: continue # Preview is part of splitter, not a dock
action = QAction(data['name'], self, checkable=True)
- action.toggled.connect(lambda checked, k=key: self.toggle_widget_visibility(k, checked))
+ if hasattr(data['widget'], 'visibilityChanged'):
+ action.toggled.connect(data['widget'].setVisible)
+ data['widget'].visibilityChanged.connect(action.setChecked)
+ else:
+ action.toggled.connect(lambda checked, k=key: self.toggle_widget_visibility(k, checked))
+
data['action'] = action
self.windows_menu.addAction(action)
@@ -783,6 +1095,10 @@ def _set_project_properties_from_clip(self, source_path):
return False
def get_frame_data_at_time(self, time_sec):
+ """
+ Plugin API: Extracts raw frame data at a specific time.
+ Returns a tuple of (bytes, width, height) or (None, 0, 0) on failure.
+ """
clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None)
if not clip_at_time:
return (None, 0, 0)
@@ -858,7 +1174,7 @@ def advance_playback_frame(self):
def _load_settings(self):
self.settings_file_was_loaded = False
- defaults = {"allow_seek_anywhere": False, "window_visibility": {"preview": True, "timeline": True}, "splitter_state": None, "enabled_plugins": []}
+ defaults = {"window_visibility": {"project_media": True}, "splitter_state": None, "enabled_plugins": [], "recent_files": [], "confirm_on_exit": True}
if os.path.exists(self.settings_file):
try:
with open(self.settings_file, "r") as f: self.settings = json.load(f)
@@ -874,33 +1190,27 @@ def _save_settings(self):
key: data['action'].isChecked()
for key, data in self.managed_widgets.items() if data.get('action')
}
-
self.settings["window_visibility"] = visibility_to_save
self.settings['enabled_plugins'] = self.plugin_manager.get_enabled_plugin_names()
try:
- with open(self.settings_file, "w") as f:
- json.dump(self.settings, f, indent=4)
- except IOError as e:
- print(f"Error saving settings: {e}")
+ with open(self.settings_file, "w") as f: json.dump(self.settings, f, indent=4)
+ except IOError as e: print(f"Error saving settings: {e}")
def _apply_loaded_settings(self):
visibility_settings = self.settings.get("window_visibility", {})
for key, data in self.managed_widgets.items():
- if data.get('plugin'):
- continue
-
+ if data.get('plugin'): continue
is_visible = visibility_settings.get(key, True)
- data['widget'].setVisible(is_visible)
+ if data['widget'] is not self.preview_widget: # preview handled by splitter state
+ data['widget'].setVisible(is_visible)
if data['action']: data['action'].setChecked(is_visible)
-
splitter_state = self.settings.get("splitter_state")
if splitter_state: self.splitter.restoreState(QByteArray.fromHex(splitter_state.encode('ascii')))
def on_splitter_moved(self, pos, index): self.splitter_save_timer.start(500)
def toggle_widget_visibility(self, key, checked):
- if self.is_shutting_down:
- return
+ if self.is_shutting_down: return
if key in self.managed_widgets:
self.managed_widgets[key]['widget'].setVisible(checked)
self._save_settings()
@@ -912,19 +1222,25 @@ def open_settings_dialog(self):
def new_project(self):
self.timeline.clips.clear(); self.timeline.num_video_tracks = 1; self.timeline.num_audio_tracks = 1
+ self.media_pool.clear(); self.media_properties.clear(); self.project_media_widget.clear_list()
self.current_project_path = None; self.stop_playback(); self.timeline_widget.update()
- self.status_label.setText("New project created. Add video clips to begin.")
+ self.undo_stack = UndoStack()
+ self.undo_stack.history_changed.connect(self.update_undo_redo_actions)
+ self.update_undo_redo_actions()
+ self.status_label.setText("New project created. Add media to begin.")
def save_project_as(self):
path, _ = QFileDialog.getSaveFileName(self, "Save Project", "", "JSON Project Files (*.json)")
if not path: return
project_data = {
+ "media_pool": self.media_pool,
"clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "group_id": c.group_id} for c in self.timeline.clips],
"settings": {"num_video_tracks": self.timeline.num_video_tracks, "num_audio_tracks": self.timeline.num_audio_tracks}
}
try:
with open(path, "w") as f: json.dump(project_data, f, indent=4)
self.current_project_path = path; self.status_label.setText(f"Project saved to {os.path.basename(path)}")
+ self._add_to_recent_files(path)
except Exception as e: self.status_label.setText(f"Error saving project: {e}")
def open_project(self):
@@ -934,10 +1250,14 @@ def open_project(self):
def _load_project_from_path(self, path):
try:
with open(path, "r") as f: project_data = json.load(f)
- self.timeline.clips.clear()
+ self.new_project() # Clear existing state
+
+ self.media_pool = project_data.get("media_pool", [])
+ for p in self.media_pool: self._add_media_to_pool(p)
+
for clip_data in project_data["clips"]:
if not os.path.exists(clip_data["source_path"]):
- self.status_label.setText(f"Error: Missing media file {clip_data['source_path']}"); self.timeline.clips.clear(); self.timeline_widget.update(); return
+ self.status_label.setText(f"Error: Missing media file {clip_data['source_path']}"); self.new_project(); return
self.timeline.add_clip(TimelineClip(**clip_data))
project_settings = project_data.get("settings", {})
@@ -949,35 +1269,77 @@ def _load_project_from_path(self, path):
self.prune_empty_tracks()
self.timeline_widget.update(); self.stop_playback()
self.status_label.setText(f"Project '{os.path.basename(path)}' loaded.")
+ self._add_to_recent_files(path)
except Exception as e: self.status_label.setText(f"Error opening project: {e}")
- def add_video_clip(self):
- file_path, _ = QFileDialog.getOpenFileName(self, "Open Video File", "", "Video Files (*.mp4 *.mov *.avi)")
- if not file_path: return
- if not self.timeline.clips:
- if not self._set_project_properties_from_clip(file_path): self.status_label.setText("Error: Could not determine video properties from file."); return
+ def _add_to_recent_files(self, path):
+ recent = self.settings.get("recent_files", [])
+ if path in recent: recent.remove(path)
+ recent.insert(0, path)
+ self.settings["recent_files"] = recent[:10]
+ self._update_recent_files_menu()
+ self._save_settings()
+
+ def _update_recent_files_menu(self):
+ self.recent_menu.clear()
+ recent_files = self.settings.get("recent_files", [])
+ for path in recent_files:
+ if os.path.exists(path):
+ action = QAction(os.path.basename(path), self)
+ action.triggered.connect(lambda checked, p=path: self._load_project_from_path(p))
+ self.recent_menu.addAction(action)
+
+ def _add_media_to_pool(self, file_path):
+ if file_path in self.media_pool: return True
try:
self.status_label.setText(f"Probing {os.path.basename(file_path)}..."); QApplication.processEvents()
probe = ffmpeg.probe(file_path)
video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
if not video_stream: raise ValueError("No video stream found.")
+
duration = float(video_stream.get('duration', probe['format'].get('duration', 0)))
has_audio = any(s['codec_type'] == 'audio' for s in probe['streams'])
- timeline_start = self.timeline.get_total_duration()
-
- self._add_clip_to_timeline(
- source_path=file_path,
- timeline_start_sec=timeline_start,
- duration_sec=duration,
- clip_start_sec=0,
- video_track_index=1,
- audio_track_index=1 if has_audio else None
- )
+
+ self.media_properties[file_path] = {'duration': duration, 'has_audio': has_audio}
+ self.media_pool.append(file_path)
+ self.project_media_widget.add_media_item(file_path)
+
+ self.status_label.setText(f"Added {os.path.basename(file_path)} to project.")
+ return True
+ except Exception as e:
+ self.status_label.setText(f"Error probing file: {e}")
+ return False
+
+ def on_media_removed_from_pool(self, file_path):
+ old_state = self._get_current_timeline_state()
+
+ if file_path in self.media_pool: self.media_pool.remove(file_path)
+ if file_path in self.media_properties: del self.media_properties[file_path]
+
+ clips_to_remove = [c for c in self.timeline.clips if c.source_path == file_path]
+ for clip in clips_to_remove: self.timeline.clips.remove(clip)
+
+ new_state = self._get_current_timeline_state()
+ command = TimelineStateChangeCommand("Remove Media From Project", self.timeline, *old_state, *new_state)
+ command.undo()
+ self.undo_stack.push(command)
+
- self.timeline_widget.update(); self.seek_preview(self.timeline_widget.playhead_pos_sec)
- except Exception as e: self.status_label.setText(f"Error adding file: {e}")
+ def add_video_clip(self):
+ file_path, _ = QFileDialog.getOpenFileName(self, "Open Video File", "", "Video Files (*.mp4 *.mov *.avi)")
+ if not file_path: return
+
+ if not self.timeline.clips and not self.media_pool:
+ if not self._set_project_properties_from_clip(file_path):
+ self.status_label.setText("Error: Could not determine video properties from file."); return
+
+ if self._add_media_to_pool(file_path):
+ self.status_label.setText(f"Added {os.path.basename(file_path)} to project media.")
+ else:
+ self.status_label.setText(f"Failed to add {os.path.basename(file_path)}.")
def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, clip_start_sec=0.0, video_track_index=None, audio_track_index=None):
+ old_state = self._get_current_timeline_state()
group_id = str(uuid.uuid4())
if video_track_index is not None:
@@ -988,14 +1350,16 @@ def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, c
audio_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, audio_track_index, 'audio', group_id)
self.timeline.add_clip(audio_clip)
- self.timeline_widget.update()
- self.status_label.setText(f"Added {os.path.basename(source_path)}.")
+ new_state = self._get_current_timeline_state()
+ command = TimelineStateChangeCommand("Add Clip", self.timeline, *old_state, *new_state)
+ # We already performed the action, so we need to undo it before letting the stack redo it.
+ command.undo()
+ self.undo_stack.push(command)
def _split_at_time(self, clip_to_split, time_sec, new_group_id=None):
if not (clip_to_split.timeline_start_sec < time_sec < clip_to_split.timeline_end_sec): return False
split_point = time_sec - clip_to_split.timeline_start_sec
orig_dur = clip_to_split.duration_sec
-
group_id_for_new_clip = new_group_id if new_group_id is not None else clip_to_split.group_id
new_clip = TimelineClip(clip_to_split.source_path, time_sec, clip_to_split.clip_start_sec + split_point, orig_dur - split_point, clip_to_split.track_index, clip_to_split.track_type, group_id_for_new_clip)
@@ -1006,88 +1370,199 @@ def _split_at_time(self, clip_to_split, time_sec, new_group_id=None):
def split_clip_at_playhead(self, clip_to_split=None):
playhead_time = self.timeline_widget.playhead_pos_sec
if not clip_to_split:
- clip_to_split = next((c for c in self.timeline.clips if c.timeline_start_sec < playhead_time < c.timeline_end_sec), None)
-
- if not clip_to_split:
- self.status_label.setText("Playhead is not over a clip to split.")
- return
-
+ clips_at_playhead = [c for c in self.timeline.clips if c.timeline_start_sec < playhead_time < c.timeline_end_sec]
+ if not clips_at_playhead:
+ self.status_label.setText("Playhead is not over a clip to split.")
+ return
+ clip_to_split = clips_at_playhead[0]
+
+ old_state = self._get_current_timeline_state()
+
linked_clip = next((c for c in self.timeline.clips if c.group_id == clip_to_split.group_id and c.id != clip_to_split.id), None)
-
new_right_side_group_id = str(uuid.uuid4())
split1 = self._split_at_time(clip_to_split, playhead_time, new_group_id=new_right_side_group_id)
- split2 = True
if linked_clip:
- split2 = self._split_at_time(linked_clip, playhead_time, new_group_id=new_right_side_group_id)
+ self._split_at_time(linked_clip, playhead_time, new_group_id=new_right_side_group_id)
- if split1 and split2:
- self.timeline_widget.update()
- self.status_label.setText("Clip split.")
+ if split1:
+ new_state = self._get_current_timeline_state()
+ command = TimelineStateChangeCommand("Split Clip", self.timeline, *old_state, *new_state)
+ command.undo()
+ self.undo_stack.push(command)
else:
self.status_label.setText("Failed to split clip.")
- def on_split_region(self, region):
- start_sec, end_sec = region;
- clips = list(self.timeline.clips)
- for clip in clips: self._split_at_time(clip, end_sec)
- for clip in clips: self._split_at_time(clip, start_sec)
- self.timeline_widget.clear_region(region); self.timeline_widget.update(); self.status_label.setText("Region split.")
- def on_split_all_regions(self, regions):
- split_points = set()
- for start, end in regions: split_points.add(start); split_points.add(end)
- for point in sorted(list(split_points)):
- group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
- new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
-
- for clip in list(self.timeline.clips):
- if clip.group_id in new_group_ids:
- self._split_at_time(clip, point, new_group_ids[clip.group_id])
+ def delete_clip(self, clip_to_delete):
+ old_state = self._get_current_timeline_state()
- self.timeline_widget.clear_all_regions(); self.timeline_widget.update(); self.status_label.setText("All regions split.")
-
- def on_join_region(self, region):
- start_sec, end_sec = region; duration_to_remove = end_sec - start_sec
- if duration_to_remove <= 0.01: return
+ linked_clip = next((c for c in self.timeline.clips if c.group_id == clip_to_delete.group_id and c.id != clip_to_delete.id), None)
- for point in [start_sec, end_sec]:
- group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
- new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
- for clip in list(self.timeline.clips):
- if clip.group_id in new_group_ids:
- self._split_at_time(clip, point, new_group_ids[clip.group_id])
-
- clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
- for clip in clips_to_remove: self.timeline.clips.remove(clip)
- for clip in self.timeline.clips:
- if clip.timeline_start_sec >= end_sec: clip.timeline_start_sec -= duration_to_remove
- self.timeline.clips.sort(key=lambda c: c.timeline_start_sec); self.timeline_widget.clear_region(region)
- self.timeline_widget.update(); self.status_label.setText("Region joined (content removed).")
+ if clip_to_delete in self.timeline.clips: self.timeline.clips.remove(clip_to_delete)
+ if linked_clip and linked_clip in self.timeline.clips: self.timeline.clips.remove(linked_clip)
+
+ new_state = self._get_current_timeline_state()
+ command = TimelineStateChangeCommand("Delete Clip", self.timeline, *old_state, *new_state)
+ command.undo()
+ self.undo_stack.push(command)
self.prune_empty_tracks()
- def on_join_all_regions(self, regions):
- for region in sorted(regions, key=lambda r: r[0], reverse=True):
- start_sec, end_sec = region; duration_to_remove = end_sec - start_sec
- if duration_to_remove <= 0.01: continue
+ def _perform_complex_timeline_change(self, description, change_function):
+ """Wrapper for complex operations to make them undoable."""
+ old_state = self._get_current_timeline_state()
+
+ # Execute the provided function which will modify the timeline
+ change_function()
+
+ new_state = self._get_current_timeline_state()
+
+ # If no change occurred, don't add to undo stack
+ if old_state[0] == new_state[0] and old_state[1] == new_state[1] and old_state[2] == new_state[2]:
+ return
+
+ command = TimelineStateChangeCommand(description, self.timeline, *old_state, *new_state)
+ # The change was already made, so we need to "undo" it before letting the stack "redo" it.
+ command.undo()
+ self.undo_stack.push(command)
- for point in [start_sec, end_sec]:
+ def on_split_region(self, region):
+ def action():
+ start_sec, end_sec = region
+ clips = list(self.timeline.clips) # Work on a copy
+ for clip in clips: self._split_at_time(clip, end_sec)
+ for clip in clips: self._split_at_time(clip, start_sec)
+ self.timeline_widget.clear_region(region)
+ self._perform_complex_timeline_change("Split Region", action)
+
+ def on_split_all_regions(self, regions):
+ def action():
+ split_points = set()
+ for start, end in regions:
+ split_points.add(start)
+ split_points.add(end)
+
+ for point in sorted(list(split_points)):
group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
for clip in list(self.timeline.clips):
if clip.group_id in new_group_ids:
self._split_at_time(clip, point, new_group_ids[clip.group_id])
+ self.timeline_widget.clear_all_regions()
+ self._perform_complex_timeline_change("Split All Regions", action)
+
+ def on_join_region(self, region):
+ def action():
+ start_sec, end_sec = region
+ duration_to_remove = end_sec - start_sec
+ if duration_to_remove <= 0.01: return
+
+ # Split any clips that cross the region boundaries
+ for point in [start_sec, end_sec]:
+ group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
+ new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
+ for clip in list(self.timeline.clips):
+ if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id])
+
+ # Remove clips fully inside the region
+ clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
+ for clip in clips_to_remove: self.timeline.clips.remove(clip)
+
+ # Ripple delete: move subsequent clips to the left
+ for clip in self.timeline.clips:
+ if clip.timeline_start_sec >= end_sec:
+ clip.timeline_start_sec -= duration_to_remove
+ self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
+ self.timeline_widget.clear_region(region)
+ self._perform_complex_timeline_change("Join Region", action)
+
+ def on_join_all_regions(self, regions):
+ def action():
+ # Process regions from end to start to avoid index issues
+ for region in sorted(regions, key=lambda r: r[0], reverse=True):
+ start_sec, end_sec = region
+ duration_to_remove = end_sec - start_sec
+ if duration_to_remove <= 0.01: continue
+
+ # Split clips at boundaries
+ for point in [start_sec, end_sec]:
+ group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
+ new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
+ for clip in list(self.timeline.clips):
+ if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id])
+
+ # Remove clips inside region
+ clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
+ for clip in clips_to_remove:
+ try: self.timeline.clips.remove(clip)
+ except ValueError: pass # Already removed
+
+ # Ripple
+ for clip in self.timeline.clips:
+ if clip.timeline_start_sec >= end_sec:
+ clip.timeline_start_sec -= duration_to_remove
+
+ self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
+ self.timeline_widget.clear_all_regions()
+ self._perform_complex_timeline_change("Join All Regions", action)
+
+ def on_delete_region(self, region):
+ def action():
+ start_sec, end_sec = region
+ duration_to_remove = end_sec - start_sec
+ if duration_to_remove <= 0.01: return
+
+ # Split any clips that cross the region boundaries
+ for point in [start_sec, end_sec]:
+ group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
+ new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
+ for clip in list(self.timeline.clips):
+ if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id])
+
+ # Remove clips fully inside the region
clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
- for clip in clips_to_remove:
- try: self.timeline.clips.remove(clip)
- except ValueError: pass
+ for clip in clips_to_remove: self.timeline.clips.remove(clip)
+
+ # Ripple delete: move subsequent clips to the left
for clip in self.timeline.clips:
- if clip.timeline_start_sec >= end_sec: clip.timeline_start_sec -= duration_to_remove
+ if clip.timeline_start_sec >= end_sec:
+ clip.timeline_start_sec -= duration_to_remove
+
+ self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
+ self.timeline_widget.clear_region(region)
+ self._perform_complex_timeline_change("Delete Region", action)
+
+ def on_delete_all_regions(self, regions):
+ def action():
+ # Process regions from end to start to avoid index issues
+ for region in sorted(regions, key=lambda r: r[0], reverse=True):
+ start_sec, end_sec = region
+ duration_to_remove = end_sec - start_sec
+ if duration_to_remove <= 0.01: continue
+
+ # Split clips at boundaries
+ for point in [start_sec, end_sec]:
+ group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
+ new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
+ for clip in list(self.timeline.clips):
+ if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id])
+
+ # Remove clips inside region
+ clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
+ for clip in clips_to_remove:
+ try: self.timeline.clips.remove(clip)
+ except ValueError: pass # Already removed
+
+ # Ripple
+ for clip in self.timeline.clips:
+ if clip.timeline_start_sec >= end_sec:
+ clip.timeline_start_sec -= duration_to_remove
+
+ self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
+ self.timeline_widget.clear_all_regions()
+ self._perform_complex_timeline_change("Delete All Regions", action)
- self.timeline.clips.sort(key=lambda c: c.timeline_start_sec); self.timeline_widget.clear_all_regions()
- self.timeline_widget.update(); self.status_label.setText("All regions joined.")
- self.prune_empty_tracks()
def export_video(self):
if not self.timeline.clips: self.status_label.setText("Timeline is empty."); return
@@ -1174,45 +1649,29 @@ def on_thread_finished_cleanup(self):
def add_dock_widget(self, plugin_instance, widget, title, area=Qt.DockWidgetArea.RightDockWidgetArea, show_on_creation=True):
widget_key = f"plugin_{plugin_instance.name}_{title}".replace(' ', '_').lower()
-
dock = QDockWidget(title, self)
dock.setWidget(widget)
self.addDockWidget(area, dock)
-
visibility_settings = self.settings.get("window_visibility", {})
initial_visibility = visibility_settings.get(widget_key, show_on_creation)
-
dock.setVisible(initial_visibility)
-
action = QAction(title, self, checkable=True)
- action.toggled.connect(lambda checked, k=widget_key: self.toggle_widget_visibility(k, checked))
+ action.toggled.connect(dock.setVisible)
dock.visibilityChanged.connect(action.setChecked)
-
action.setChecked(dock.isVisible())
-
self.windows_menu.addAction(action)
-
- self.managed_widgets[widget_key] = {
- 'widget': dock,
- 'name': title,
- 'action': action,
- 'plugin': plugin_instance.name
- }
-
+ self.managed_widgets[widget_key] = {'widget': dock, 'name': title, 'action': action, 'plugin': plugin_instance.name}
return dock
def update_plugin_ui_visibility(self, plugin_name, is_enabled):
for key, data in self.managed_widgets.items():
if data.get('plugin') == plugin_name:
data['action'].setVisible(is_enabled)
- if not is_enabled:
- data['widget'].hide()
+ if not is_enabled: data['widget'].hide()
def toggle_plugin(self, name, checked):
- if checked:
- self.plugin_manager.enable_plugin(name)
- else:
- self.plugin_manager.disable_plugin(name)
+ if checked: self.plugin_manager.enable_plugin(name)
+ else: self.plugin_manager.disable_plugin(name)
self._save_settings()
def toggle_plugin_action(self, name, checked):
@@ -1228,6 +1687,27 @@ def open_manage_plugins_dialog(self):
dialog.exec()
def closeEvent(self, event):
+ if self.settings.get("confirm_on_exit", True):
+ msg_box = QMessageBox(self)
+ msg_box.setWindowTitle("Confirm Exit")
+ msg_box.setText("Are you sure you want to exit?")
+ msg_box.setInformativeText("Any unsaved changes will be lost.")
+ msg_box.setIcon(QMessageBox.Icon.Question)
+ msg_box.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No)
+ msg_box.setDefaultButton(QMessageBox.StandardButton.No)
+
+ dont_ask_cb = QCheckBox("Don't ask again")
+ msg_box.setCheckBox(dont_ask_cb)
+
+ reply = msg_box.exec()
+
+ if dont_ask_cb.isChecked():
+ self.settings['confirm_on_exit'] = False
+
+ if reply == QMessageBox.StandardButton.No:
+ event.ignore()
+ return
+
self.is_shutting_down = True
self._save_settings()
self._stop_playback_stream()
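A note on the `_perform_complex_timeline_change` wrapper added in the main.py hunks above: any further timeline operation becomes undoable by handing it a closure that mutates the timeline in place; the wrapper snapshots the state before and after, skips the push when nothing changed, and otherwise inverts the change once so the stack's `redo()` re-applies it. A minimal sketch under that contract (the `nudge_all_clips` method and its 1.0 s offset are hypothetical, not part of this patch):

```python
# Hypothetical extra MainWindow method, shown only to illustrate the wrapper's
# contract: do all mutation of self.timeline inside action(); the wrapper
# handles snapshotting and the undo-before-push dance.
def nudge_all_clips(self, offset_sec=1.0):
    def action():
        for clip in self.timeline.clips:
            clip.timeline_start_sec += offset_sec
        self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
    self._perform_complex_timeline_change("Nudge All Clips", action)
```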
diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py
index 873cd5e56..56b94efe6 100644
--- a/videoeditor/plugins/ai_frame_joiner/main.py
+++ b/videoeditor/plugins/ai_frame_joiner/main.py
@@ -139,6 +139,7 @@ def generate(self):
if response.status_code == 200:
if payload['start_generation']:
self.status_label.setText("Status: Parameters set. Generation sent. Polling...")
+ self.start_polling()
else:
self.status_label.setText("Status: Parameters set. Waiting for manual start.")
else:
@@ -154,6 +155,7 @@ def poll_for_output(self):
latest_path = data.get("latest_output_path")
if latest_path and latest_path != self.last_known_output:
+ self.stop_polling()
self.last_known_output = latest_path
self.output_label.setText(f"Latest Output File:\n{latest_path}")
self.status_label.setText("Status: New output received! Inserting clip...")
@@ -182,7 +184,6 @@ def enable(self):
self.dock_widget.hide()
self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu)
- self.client_widget.start_polling()
def disable(self):
try:
@@ -216,6 +217,8 @@ def setup_generator_for_region(self, region):
self.active_region = region
start_sec, end_sec = region
+ self.client_widget.check_server_status()
+
start_data, w, h = self.app.get_frame_data_at_time(start_sec)
end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
diff --git a/videoeditor/undo.py b/videoeditor/undo.py
new file mode 100644
index 000000000..b43f4fb46
--- /dev/null
+++ b/videoeditor/undo.py
@@ -0,0 +1,114 @@
+import copy
+from PyQt6.QtCore import QObject, pyqtSignal
+
+class UndoCommand:
+ def __init__(self, description=""):
+ self.description = description
+
+ def undo(self):
+ raise NotImplementedError
+
+ def redo(self):
+ raise NotImplementedError
+
+class CompositeCommand(UndoCommand):
+ def __init__(self, description, commands):
+ super().__init__(description)
+ self.commands = commands
+
+ def undo(self):
+ for cmd in reversed(self.commands):
+ cmd.undo()
+
+ def redo(self):
+ for cmd in self.commands:
+ cmd.redo()
+
+class UndoStack(QObject):
+ history_changed = pyqtSignal()
+ timeline_changed = pyqtSignal()
+
+ def __init__(self):
+ super().__init__()
+ self.undo_stack = []
+ self.redo_stack = []
+
+ def push(self, command):
+ self.undo_stack.append(command)
+ self.redo_stack.clear()
+ command.redo()
+ self.history_changed.emit()
+ self.timeline_changed.emit()
+
+ def undo(self):
+ if not self.can_undo():
+ return
+ command = self.undo_stack.pop()
+ self.redo_stack.append(command)
+ command.undo()
+ self.history_changed.emit()
+ self.timeline_changed.emit()
+
+ def redo(self):
+ if not self.can_redo():
+ return
+ command = self.redo_stack.pop()
+ self.undo_stack.append(command)
+ command.redo()
+ self.history_changed.emit()
+ self.timeline_changed.emit()
+
+ def can_undo(self):
+ return bool(self.undo_stack)
+
+ def can_redo(self):
+ return bool(self.redo_stack)
+
+ def undo_text(self):
+ return self.undo_stack[-1].description if self.can_undo() else ""
+
+ def redo_text(self):
+ return self.redo_stack[-1].description if self.can_redo() else ""
+
+class TimelineStateChangeCommand(UndoCommand):
+ def __init__(self, description, timeline_model, old_clips, old_v_tracks, old_a_tracks, new_clips, new_v_tracks, new_a_tracks):
+ super().__init__(description)
+ self.timeline = timeline_model
+ self.old_clips_state = old_clips
+ self.old_v_tracks = old_v_tracks
+ self.old_a_tracks = old_a_tracks
+ self.new_clips_state = new_clips
+ self.new_v_tracks = new_v_tracks
+ self.new_a_tracks = new_a_tracks
+
+ def undo(self):
+ self.timeline.clips = self.old_clips_state
+ self.timeline.num_video_tracks = self.old_v_tracks
+ self.timeline.num_audio_tracks = self.old_a_tracks
+
+ def redo(self):
+ self.timeline.clips = self.new_clips_state
+ self.timeline.num_video_tracks = self.new_v_tracks
+ self.timeline.num_audio_tracks = self.new_a_tracks
+
+
+class MoveClipsCommand(UndoCommand):
+ def __init__(self, description, timeline_model, move_data):
+ super().__init__(description)
+ self.timeline = timeline_model
+ self.move_data = move_data
+
+ def _apply_state(self, state_key_prefix):
+ for data in self.move_data:
+ clip_id = data['clip_id']
+ clip = next((c for c in self.timeline.clips if c.id == clip_id), None)
+ if clip:
+ clip.timeline_start_sec = data[f'{state_key_prefix}_start']
+ clip.track_index = data[f'{state_key_prefix}_track']
+ self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
+
+ def undo(self):
+ self._apply_state('old')
+
+ def redo(self):
+ self._apply_state('new')
\ No newline at end of file
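For reference, main.py drives this new module with a fixed idiom: snapshot the timeline, mutate it directly, snapshot again, build a `TimelineStateChangeCommand`, call `command.undo()` once, then push it so `UndoStack.push()` (which calls `redo()`) applies the change exactly once. A minimal round-trip sketch, assuming PyQt6 is installed and the script runs from `videoeditor/` so `undo.py` imports; `FakeTimeline` and `snapshot()` are stand-ins for the editor's `Timeline` model and `MainWindow._get_current_timeline_state()`:

```python
import copy
from undo import UndoStack, TimelineStateChangeCommand

class FakeTimeline:
    # Stand-in for the editor's Timeline model (clips + track counts).
    def __init__(self):
        self.clips = []
        self.num_video_tracks = 1
        self.num_audio_tracks = 1

def snapshot(tl):
    # Same tuple shape as MainWindow._get_current_timeline_state().
    return (copy.deepcopy(tl.clips), tl.num_video_tracks, tl.num_audio_tracks)

timeline = FakeTimeline()
stack = UndoStack()

old_state = snapshot(timeline)
timeline.clips.append("clip-a")     # mutate the model directly...
new_state = snapshot(timeline)

cmd = TimelineStateChangeCommand("Add Clip", timeline, *old_state, *new_state)
cmd.undo()                          # ...roll it back so push() can redo() it once
stack.push(cmd)

assert timeline.clips == ["clip-a"]
stack.undo(); assert timeline.clips == []
stack.redo(); assert timeline.clips == ["clip-a"]
```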
From a21356f4f5aa6fe2a68adf19b33f1929a6bdd24a Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Fri, 10 Oct 2025 08:27:09 +1100
Subject: [PATCH 092/155] add ability to drag clips onto ghost tracks

---
videoeditor/main.py | 177 +++++++++++++++++++++++++-------------------
1 file changed, 99 insertions(+), 78 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index a3f155e09..561244779 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -96,7 +96,6 @@ class TimelineWidget(QWidget):
remove_track = pyqtSignal(str)
operation_finished = pyqtSignal()
context_menu_requested = pyqtSignal(QMenu, 'QContextMenuEvent')
- clips_moved = pyqtSignal(list)
def __init__(self, timeline_model, settings, parent=None):
super().__init__(parent)
@@ -116,8 +115,10 @@ def __init__(self, timeline_model, settings, parent=None):
self.drag_original_clip_states = {} # Store {'clip_id': (start_sec, track_index)}
self.selection_drag_start_sec = 0.0
self.drag_selection_start_values = None
+ self.drag_start_state = None
self.highlighted_track_info = None
+ self.highlighted_ghost_track_info = None
self.add_video_track_btn_rect = QRect()
self.remove_video_track_btn_rect = QRect()
self.add_audio_track_btn_rect = QRect()
@@ -292,6 +293,18 @@ def draw_tracks_and_clips(self, painter):
highlight_rect = QRect(self.HEADER_WIDTH, y, self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT)
painter.fillRect(highlight_rect, QColor(255, 255, 0, 40))
+ if self.highlighted_ghost_track_info:
+ track_type, track_index = self.highlighted_ghost_track_info
+ y = -1
+ if track_type == 'video':
+ y = self.TIMESCALE_HEIGHT
+ elif track_type == 'audio':
+ y = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT
+
+ if y != -1:
+ highlight_rect = QRect(self.HEADER_WIDTH, y, self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT)
+ painter.fillRect(highlight_rect, QColor(255, 255, 0, 40))
+
# Draw clips
for clip in self.timeline.clips:
clip_rect = self.get_clip_rect(clip)
@@ -323,17 +336,30 @@ def draw_playhead(self, painter):
painter.drawLine(playhead_x, 0, playhead_x, self.height())
def y_to_track_info(self, y):
+ # Check for ghost video track
+ if self.TIMESCALE_HEIGHT <= y < self.video_tracks_y_start:
+ return ('video', self.timeline.num_video_tracks + 1)
+
+ # Check for existing video tracks
video_tracks_end_y = self.video_tracks_y_start + self.timeline.num_video_tracks * self.TRACK_HEIGHT
if self.video_tracks_y_start <= y < video_tracks_end_y:
visual_index = (y - self.video_tracks_y_start) // self.TRACK_HEIGHT
track_index = self.timeline.num_video_tracks - visual_index
return ('video', track_index)
+ # Check for existing audio tracks
audio_tracks_end_y = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT
if self.audio_tracks_y_start <= y < audio_tracks_end_y:
visual_index = (y - self.audio_tracks_y_start) // self.TRACK_HEIGHT
track_index = visual_index + 1
return ('audio', track_index)
+
+ # Check for ghost audio track
+ add_audio_btn_y_start = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT
+ add_audio_btn_y_end = add_audio_btn_y_start + self.TRACK_HEIGHT
+ if add_audio_btn_y_start <= y < add_audio_btn_y_end:
+ return ('audio', self.timeline.num_audio_tracks + 1)
+
return None
def get_region_at_pos(self, pos: QPoint):
@@ -367,6 +393,7 @@ def mousePressEvent(self, event: QMouseEvent):
clip_rect = self.get_clip_rect(clip)
if clip_rect.contains(QPointF(event.pos())):
self.dragging_clip = clip
+ self.drag_start_state = self.window()._get_current_timeline_state()
self.drag_original_clip_states[clip.id] = (clip.timeline_start_sec, clip.track_index)
self.dragging_linked_clip = next((c for c in self.timeline.clips if c.group_id == clip.group_id and c.id != clip.id), None)
@@ -427,24 +454,29 @@ def mouseMoveEvent(self, event: QMouseEvent):
self.update()
elif self.dragging_clip:
self.highlighted_track_info = None
+ self.highlighted_ghost_track_info = None
new_track_info = self.y_to_track_info(event.pos().y())
original_start_sec, _ = self.drag_original_clip_states[self.dragging_clip.id]
if new_track_info:
new_track_type, new_track_index = new_track_info
+
+ is_ghost_track = (new_track_type == 'video' and new_track_index > self.timeline.num_video_tracks) or \
+ (new_track_type == 'audio' and new_track_index > self.timeline.num_audio_tracks)
+
+ if is_ghost_track:
+ self.highlighted_ghost_track_info = new_track_info
+ else:
+ self.highlighted_track_info = new_track_info
+
if new_track_type == self.dragging_clip.track_type:
self.dragging_clip.track_index = new_track_index
- if self.dragging_linked_clip:
- # For now, linked clips move to same-number track index
- self.dragging_linked_clip.track_index = new_track_index
-
delta_x = event.pos().x() - self.drag_start_pos.x()
time_delta = delta_x / self.PIXELS_PER_SECOND
new_start_time = original_start_sec + time_delta
- # Basic collision detection (can be improved)
for other_clip in self.timeline.clips:
if other_clip.id == self.dragging_clip.id: continue
if self.dragging_linked_clip and other_clip.id == self.dragging_linked_clip.id: continue
@@ -482,35 +514,27 @@ def mouseReleaseEvent(self, event: QMouseEvent):
self.dragging_playhead = False
if self.dragging_clip:
- move_data = []
- # Check main dragged clip
orig_start, orig_track = self.drag_original_clip_states[self.dragging_clip.id]
- if orig_start != self.dragging_clip.timeline_start_sec or orig_track != self.dragging_clip.track_index:
- move_data.append({
- 'clip_id': self.dragging_clip.id,
- 'old_start': orig_start, 'new_start': self.dragging_clip.timeline_start_sec,
- 'old_track': orig_track, 'new_track': self.dragging_clip.track_index
- })
- # Check linked clip
+ moved = (orig_start != self.dragging_clip.timeline_start_sec or
+ orig_track != self.dragging_clip.track_index)
+
if self.dragging_linked_clip:
orig_start_link, orig_track_link = self.drag_original_clip_states[self.dragging_linked_clip.id]
- if orig_start_link != self.dragging_linked_clip.timeline_start_sec or orig_track_link != self.dragging_linked_clip.track_index:
- move_data.append({
- 'clip_id': self.dragging_linked_clip.id,
- 'old_start': orig_start_link, 'new_start': self.dragging_linked_clip.timeline_start_sec,
- 'old_track': orig_track_link, 'new_track': self.dragging_linked_clip.track_index
- })
+ moved = moved or (orig_start_link != self.dragging_linked_clip.timeline_start_sec or
+ orig_track_link != self.dragging_linked_clip.track_index)
+
+ if moved:
+ self.window().finalize_clip_drag(self.drag_start_state)
- if move_data:
- self.clips_moved.emit(move_data)
-
self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
self.highlighted_track_info = None
+ self.highlighted_ghost_track_info = None
self.operation_finished.emit()
self.dragging_clip = None
self.dragging_linked_clip = None
self.drag_original_clip_states.clear()
+ self.drag_start_state = None
self.update()
@@ -522,6 +546,8 @@ def dragEnterEvent(self, event):
def dragLeaveEvent(self, event):
self.drag_over_active = False
+ self.highlighted_ghost_track_info = None
+ self.highlighted_track_info = None
self.update()
def dragMoveEvent(self, event):
@@ -543,6 +569,15 @@ def dragMoveEvent(self, event):
if track_info:
self.drag_over_active = True
track_type, track_index = track_info
+
+ is_ghost_track = (track_type == 'video' and track_index > self.timeline.num_video_tracks) or \
+ (track_type == 'audio' and track_index > self.timeline.num_audio_tracks)
+ if is_ghost_track:
+ self.highlighted_ghost_track_info = track_info
+ self.highlighted_track_info = None
+ else:
+ self.highlighted_ghost_track_info = None
+ self.highlighted_track_info = track_info
width = int(duration * self.PIXELS_PER_SECOND)
x = self.sec_to_x(start_sec)
@@ -564,11 +599,15 @@ def dragMoveEvent(self, event):
self.drag_over_rect = QRectF(x, video_y, width, self.TRACK_HEIGHT)
else:
self.drag_over_active = False
+ self.highlighted_ghost_track_info = None
+ self.highlighted_track_info = None
self.update()
def dropEvent(self, event):
self.drag_over_active = False
+ self.highlighted_ghost_track_info = None
+ self.highlighted_track_info = None
self.update()
mime_data = event.mimeData()
@@ -775,7 +814,7 @@ def __init__(self, project_to_load=None):
self.timeline = Timeline()
self.undo_stack = UndoStack()
self.media_pool = []
- self.media_properties = {} # To store duration, has_audio etc.
+ self.media_properties = {}
self.export_thread = None
self.export_worker = None
self.current_project_path = None
@@ -805,7 +844,6 @@ def __init__(self, project_to_load=None):
if project_to_load: QTimer.singleShot(100, lambda: self._load_project_from_path(project_to_load))
def _get_current_timeline_state(self):
- """Helper to capture a snapshot of the timeline for undo/redo."""
return (
copy.deepcopy(self.timeline.clips),
self.timeline.num_video_tracks,
@@ -886,7 +924,6 @@ def _connect_signals(self):
self.timeline_widget.split_requested.connect(self.split_clip_at_playhead)
self.timeline_widget.delete_clip_requested.connect(self.delete_clip)
self.timeline_widget.playhead_moved.connect(self.seek_preview)
- self.timeline_widget.clips_moved.connect(self.on_clips_moved)
self.timeline_widget.split_region_requested.connect(self.on_split_region)
self.timeline_widget.split_all_regions_requested.connect(self.on_split_all_regions)
self.timeline_widget.join_region_requested.connect(self.on_join_region)
@@ -911,10 +948,8 @@ def _connect_signals(self):
self.undo_stack.timeline_changed.connect(self.on_timeline_changed_by_undo)
def on_timeline_changed_by_undo(self):
- """Called when undo/redo modifies the timeline."""
self.prune_empty_tracks()
self.timeline_widget.update()
- # You may want to update other UI elements here as well
self.status_label.setText("Operation undone/redone.")
def update_undo_redo_actions(self):
@@ -924,13 +959,23 @@ def update_undo_redo_actions(self):
self.redo_action.setEnabled(self.undo_stack.can_redo())
self.redo_action.setText(f"Redo {self.undo_stack.redo_text()}" if self.undo_stack.can_redo() else "Redo")
- def on_clips_moved(self, move_data):
- command = MoveClipsCommand("Move Clip", self.timeline, move_data)
- # The change is already applied visually, so we push without redoing
- self.undo_stack.undo_stack.append(command)
- self.undo_stack.redo_stack.clear()
- self.undo_stack.history_changed.emit()
- self.status_label.setText("Clip moved.")
+ def finalize_clip_drag(self, old_state_tuple):
+ current_clips, _, _ = self._get_current_timeline_state()
+
+ max_v_idx = max([c.track_index for c in current_clips if c.track_type == 'video'] + [1])
+ max_a_idx = max([c.track_index for c in current_clips if c.track_type == 'audio'] + [1])
+
+ if max_v_idx > self.timeline.num_video_tracks:
+ self.timeline.num_video_tracks = max_v_idx
+
+ if max_a_idx > self.timeline.num_audio_tracks:
+ self.timeline.num_audio_tracks = max_a_idx
+
+ new_state_tuple = self._get_current_timeline_state()
+
+ command = TimelineStateChangeCommand("Move Clip", self.timeline, *old_state_tuple, *new_state_tuple)
+ command.undo()
+ self.undo_stack.push(command)
def on_add_to_timeline_at_playhead(self, file_path):
media_info = self.media_properties.get(file_path)
@@ -953,7 +998,6 @@ def on_add_to_timeline_at_playhead(self, file_path):
def prune_empty_tracks(self):
pruned_something = False
- # Prune Video Tracks
while self.timeline.num_video_tracks > 1:
highest_track_index = self.timeline.num_video_tracks
is_track_occupied = any(c for c in self.timeline.clips
@@ -963,8 +1007,7 @@ def prune_empty_tracks(self):
else:
self.timeline.num_video_tracks -= 1
pruned_something = True
-
- # Prune Audio Tracks
+
while self.timeline.num_audio_tracks > 1:
highest_track_index = self.timeline.num_audio_tracks
is_track_occupied = any(c for c in self.timeline.clips
@@ -988,7 +1031,6 @@ def add_track(self, track_type):
command = TimelineStateChangeCommand(f"Add {track_type.capitalize()} Track", self.timeline, *old_state, *new_state)
self.undo_stack.push(command)
- # Manually undo the direct change, because push() will redo() it.
command.undo()
self.undo_stack.push(command)
@@ -999,11 +1041,11 @@ def remove_track(self, track_type):
elif track_type == 'audio' and self.timeline.num_audio_tracks > 1:
self.timeline.num_audio_tracks -= 1
else:
- return # No change made
+ return
new_state = self._get_current_timeline_state()
command = TimelineStateChangeCommand(f"Remove {track_type.capitalize()} Track", self.timeline, *old_state, *new_state)
- command.undo() # Invert state because we already made the change
+ command.undo()
self.undo_stack.push(command)
@@ -1031,7 +1073,7 @@ def _create_menu_bar(self):
split_action = QAction("Split Clip at Playhead", self); split_action.triggered.connect(self.split_clip_at_playhead)
edit_menu.addAction(split_action)
- self.update_undo_redo_actions() # Set initial state
+ self.update_undo_redo_actions()
plugins_menu = menu_bar.addMenu("&Plugins")
for name, data in self.plugin_manager.plugins.items():
@@ -1048,7 +1090,7 @@ def _create_menu_bar(self):
self.windows_menu = menu_bar.addMenu("&Windows")
for key, data in self.managed_widgets.items():
- if data['widget'] is self.preview_widget: continue # Preview is part of splitter, not a dock
+ if data['widget'] is self.preview_widget: continue
action = QAction(data['name'], self, checkable=True)
if hasattr(data['widget'], 'visibilityChanged'):
action.toggled.connect(data['widget'].setVisible)
@@ -1095,10 +1137,6 @@ def _set_project_properties_from_clip(self, source_path):
return False
def get_frame_data_at_time(self, time_sec):
- """
- Plugin API: Extracts raw frame data at a specific time.
- Returns a tuple of (bytes, width, height) or (None, 0, 0) on failure.
- """
clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None)
if not clip_at_time:
return (None, 0, 0)
@@ -1201,7 +1239,7 @@ def _apply_loaded_settings(self):
for key, data in self.managed_widgets.items():
if data.get('plugin'): continue
is_visible = visibility_settings.get(key, True)
- if data['widget'] is not self.preview_widget: # preview handled by splitter state
+ if data['widget'] is not self.preview_widget:
data['widget'].setVisible(is_visible)
if data['action']: data['action'].setChecked(is_visible)
splitter_state = self.settings.get("splitter_state")
@@ -1250,7 +1288,7 @@ def open_project(self):
def _load_project_from_path(self, path):
try:
with open(path, "r") as f: project_data = json.load(f)
- self.new_project() # Clear existing state
+ self.new_project()
self.media_pool = project_data.get("media_pool", [])
for p in self.media_pool: self._add_media_to_pool(p)
@@ -1343,16 +1381,19 @@ def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, c
group_id = str(uuid.uuid4())
if video_track_index is not None:
+ if video_track_index > self.timeline.num_video_tracks:
+ self.timeline.num_video_tracks = video_track_index
video_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, video_track_index, 'video', group_id)
self.timeline.add_clip(video_clip)
if audio_track_index is not None:
+ if audio_track_index > self.timeline.num_audio_tracks:
+ self.timeline.num_audio_tracks = audio_track_index
audio_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, audio_track_index, 'audio', group_id)
self.timeline.add_clip(audio_clip)
new_state = self._get_current_timeline_state()
command = TimelineStateChangeCommand("Add Clip", self.timeline, *old_state, *new_state)
- # We already performed the action, so we need to undo it before letting the stack redo it.
command.undo()
self.undo_stack.push(command)
@@ -1409,27 +1450,21 @@ def delete_clip(self, clip_to_delete):
self.prune_empty_tracks()
def _perform_complex_timeline_change(self, description, change_function):
- """Wrapper for complex operations to make them undoable."""
old_state = self._get_current_timeline_state()
-
- # Execute the provided function which will modify the timeline
change_function()
new_state = self._get_current_timeline_state()
-
- # If no change occurred, don't add to undo stack
if old_state[0] == new_state[0] and old_state[1] == new_state[1] and old_state[2] == new_state[2]:
return
command = TimelineStateChangeCommand(description, self.timeline, *old_state, *new_state)
- # The change was already made, so we need to "undo" it before letting the stack "redo" it.
command.undo()
self.undo_stack.push(command)
def on_split_region(self, region):
def action():
start_sec, end_sec = region
- clips = list(self.timeline.clips) # Work on a copy
+ clips = list(self.timeline.clips)
for clip in clips: self._split_at_time(clip, end_sec)
for clip in clips: self._split_at_time(clip, start_sec)
self.timeline_widget.clear_region(region)
@@ -1457,18 +1492,15 @@ def action():
duration_to_remove = end_sec - start_sec
if duration_to_remove <= 0.01: return
- # Split any clips that cross the region boundaries
for point in [start_sec, end_sec]:
group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
for clip in list(self.timeline.clips):
if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id])
- # Remove clips fully inside the region
clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
for clip in clips_to_remove: self.timeline.clips.remove(clip)
- # Ripple delete: move subsequent clips to the left
for clip in self.timeline.clips:
if clip.timeline_start_sec >= end_sec:
clip.timeline_start_sec -= duration_to_remove
@@ -1479,26 +1511,21 @@ def action():
def on_join_all_regions(self, regions):
def action():
- # Process regions from end to start to avoid index issues
for region in sorted(regions, key=lambda r: r[0], reverse=True):
start_sec, end_sec = region
duration_to_remove = end_sec - start_sec
if duration_to_remove <= 0.01: continue
-
- # Split clips at boundaries
+
for point in [start_sec, end_sec]:
group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
for clip in list(self.timeline.clips):
if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id])
-
- # Remove clips inside region
clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
for clip in clips_to_remove:
try: self.timeline.clips.remove(clip)
- except ValueError: pass # Already removed
-
- # Ripple
+ except ValueError: pass
+
for clip in self.timeline.clips:
if clip.timeline_start_sec >= end_sec:
clip.timeline_start_sec -= duration_to_remove
@@ -1513,18 +1540,15 @@ def action():
duration_to_remove = end_sec - start_sec
if duration_to_remove <= 0.01: return
- # Split any clips that cross the region boundaries
for point in [start_sec, end_sec]:
group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
for clip in list(self.timeline.clips):
if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id])
- # Remove clips fully inside the region
clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
for clip in clips_to_remove: self.timeline.clips.remove(clip)
- # Ripple delete: move subsequent clips to the left
for clip in self.timeline.clips:
if clip.timeline_start_sec >= end_sec:
clip.timeline_start_sec -= duration_to_remove
@@ -1535,24 +1559,21 @@ def action():
def on_delete_all_regions(self, regions):
def action():
- # Process regions from end to start to avoid index issues
for region in sorted(regions, key=lambda r: r[0], reverse=True):
start_sec, end_sec = region
duration_to_remove = end_sec - start_sec
if duration_to_remove <= 0.01: continue
-
- # Split clips at boundaries
+
for point in [start_sec, end_sec]:
group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
for clip in list(self.timeline.clips):
if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id])
-
- # Remove clips inside region
+
clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
for clip in clips_to_remove:
try: self.timeline.clips.remove(clip)
- except ValueError: pass # Already removed
+ except ValueError: pass
# Ripple
for clip in self.timeline.clips:
From 354c0d1a40126733bf074a65230148d351504355 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Fri, 10 Oct 2025 08:33:02 +1100
Subject: [PATCH 093/155] fix ghost track bug when dragging from project media
---
videoeditor/main.py | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 561244779..34d9450c2 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -290,7 +290,7 @@ def draw_tracks_and_clips(self, painter):
y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT
if y != -1:
- highlight_rect = QRect(self.HEADER_WIDTH, y, self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT)
+ highlight_rect = QRect(self.HEADER_WIDTH, int(y), self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT)
painter.fillRect(highlight_rect, QColor(255, 255, 0, 40))
if self.highlighted_ghost_track_info:
@@ -302,10 +302,9 @@ def draw_tracks_and_clips(self, painter):
y = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT
if y != -1:
- highlight_rect = QRect(self.HEADER_WIDTH, y, self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT)
+ highlight_rect = QRect(self.HEADER_WIDTH, int(y), self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT)
painter.fillRect(highlight_rect, QColor(255, 255, 0, 40))
- # Draw clips
for clip in self.timeline.clips:
clip_rect = self.get_clip_rect(clip)
base_color = QColor("#46A") if clip.track_type == 'video' else QColor("#48C")
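Context for the `int(y)` casts in this fix: `QRect` in PyQt6 only accepts integer coordinates, so a `y` that picks up a float anywhere in the track-geometry arithmetic raises a `TypeError` when the ghost-track highlight is drawn during a drag from the project media list. A minimal reproduction of the failure and the applied fix (assuming PyQt6):

```python
from PyQt6.QtCore import QRect

y = 30 + 2 * 50.0                       # float sneaks in via the geometry arithmetic
try:
    QRect(120, y, 800, 50)              # raises TypeError: QRect wants ints
except TypeError:
    rect = QRect(120, int(y), 800, 50)  # the cast applied in this patch
```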
From 97f536d3cdb05917d5d8e50deae905718df09f34 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Fri, 10 Oct 2025 01:27:39 +0200
Subject: [PATCH 094/155] support more fp8 formats
---
models/wan/any2video.py | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/models/wan/any2video.py b/models/wan/any2video.py
index 7e0592b79..d2f678060 100644
--- a/models/wan/any2video.py
+++ b/models/wan/any2video.py
@@ -59,14 +59,20 @@ def timestep_transform(t, shift=5.0, num_timesteps=1000 ):
new_t = new_t * num_timesteps
return new_t
-def preprocess_sd(sd):
+def preprocess_sd_with_dtype(dtype, sd):
new_sd = {}
prefix_list = ["model.diffusion_model"]
+ end_list = [".norm3.bias", ".norm3.weight", ".norm_q.bias", ".norm_q.weight", ".norm_k.bias", ".norm_k.weight" ]
for k,v in sd.items():
for prefix in prefix_list:
if k.startswith(prefix):
k = k[len(prefix)+1:]
- continue
+ break
+ if v.dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+ for endfix in end_list:
+ if k.endswith(endfix):
+ v = v.to(dtype)
+ break
if not k.startswith("vae."):
new_sd[k] = v
return new_sd
@@ -143,6 +149,8 @@ def __init__(
source2 = model_def.get("source2", None)
module_source = model_def.get("module_source", None)
module_source2 = model_def.get("module_source2", None)
+ def preprocess_sd(sd):
+ return preprocess_sd_with_dtype(dtype, sd)
kwargs= { "modelClass": WanModel,"do_quantize": quantizeTransformer and not save_quantized, "defaultConfigPath": base_config_file , "ignore_unused_weights": ignore_unused_weights, "writable_tensors": False, "default_dtype": dtype, "preprocess_sd": preprocess_sd, "forcedConfigPath": forcedConfigPath, }
kwargs_light= { "modelClass": WanModel,"writable_tensors": False, "preprocess_sd": preprocess_sd , "forcedConfigPath" : base_config_file}
if module_source is not None:
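What the `dtype` closure above buys: the loader still calls `preprocess_sd(sd)` with a bare state dict, while fp8 tensors whose keys end in the listed norm suffixes get promoted to the model dtype (the bulk of the weights stay fp8). A standalone re-statement of that rule for illustration, assuming a PyTorch build with float8 dtypes; prefix stripping and the `vae.` filter are omitted here:

```python
import torch

NORM_SUFFIXES = (".norm3.bias", ".norm3.weight", ".norm_q.bias", ".norm_q.weight",
                 ".norm_k.bias", ".norm_k.weight")
FP8_DTYPES = (torch.float8_e5m2, torch.float8_e4m3fn)

def upcast_fp8_norms(sd, model_dtype):
    # Mirror of the rule added to preprocess_sd_with_dtype: promote only the
    # small norm tensors, leave other fp8 weights untouched.
    return {k: v.to(model_dtype) if v.dtype in FP8_DTYPES and k.endswith(NORM_SUFFIXES) else v
            for k, v in sd.items()}

sd = {"blocks.0.norm_q.weight": torch.zeros(8, dtype=torch.float8_e4m3fn),
      "blocks.0.attn.q.weight": torch.zeros(8, dtype=torch.float8_e4m3fn)}
out = upcast_fp8_norms(sd, torch.bfloat16)
print(out["blocks.0.norm_q.weight"].dtype)   # torch.bfloat16
print(out["blocks.0.attn.q.weight"].dtype)   # torch.float8_e4m3fn
```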
From cad961803f499533c9cb6fa95424dbb425699bb7 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Fri, 10 Oct 2025 12:00:08 +1100
Subject: [PATCH 095/155] add dynamic scroll and dynamic timescale
---
videoeditor/main.py | 166 +++++++++++++++++++++++++++++++++++++-------
1 file changed, 139 insertions(+), 27 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 34d9450c2..6aee83fd6 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -77,7 +77,6 @@ def run_export(self):
self.finished.emit(f"An exception occurred during export: {e}")
class TimelineWidget(QWidget):
- PIXELS_PER_SECOND = 50
TIMESCALE_HEIGHT = 30
HEADER_WIDTH = 120
TRACK_HEIGHT = 50
@@ -97,11 +96,18 @@ class TimelineWidget(QWidget):
operation_finished = pyqtSignal()
context_menu_requested = pyqtSignal(QMenu, 'QContextMenuEvent')
- def __init__(self, timeline_model, settings, parent=None):
+ def __init__(self, timeline_model, settings, project_fps, parent=None):
super().__init__(parent)
self.timeline = timeline_model
self.settings = settings
self.playhead_pos_sec = 0.0
+ self.scroll_area = None
+
+ self.pixels_per_second = 50.0
+ self.max_pixels_per_second = 1.0 # Will be updated by set_project_fps
+ self.project_fps = 25.0
+ self.set_project_fps(project_fps)
+
self.setMinimumHeight(300)
self.setMouseTracking(True)
self.setAcceptDrops(True)
@@ -132,8 +138,15 @@ def __init__(self, timeline_model, settings, parent=None):
self.drag_over_audio_rect = QRectF()
self.drag_has_audio = False
- def sec_to_x(self, sec): return self.HEADER_WIDTH + int(sec * self.PIXELS_PER_SECOND)
- def x_to_sec(self, x): return float(x - self.HEADER_WIDTH) / self.PIXELS_PER_SECOND if x > self.HEADER_WIDTH else 0.0
+ def set_project_fps(self, fps):
+ self.project_fps = fps if fps > 0 else 25.0
+ # Set max zoom to 20 pixels per frame
+ self.max_pixels_per_second = self.project_fps * 20
+ self.pixels_per_second = min(self.pixels_per_second, self.max_pixels_per_second)
+ self.update()
+
+ def sec_to_x(self, sec): return self.HEADER_WIDTH + int(sec * self.pixels_per_second)
+ def x_to_sec(self, x): return float(x - self.HEADER_WIDTH) / self.pixels_per_second if x > self.HEADER_WIDTH and self.pixels_per_second > 0 else 0.0
def paintEvent(self, event):
painter = QPainter(self)
@@ -156,7 +169,7 @@ def paintEvent(self, event):
self.draw_playhead(painter)
- total_width = self.sec_to_x(self.timeline.get_total_duration()) + 100
+ total_width = self.sec_to_x(self.timeline.get_total_duration()) + 200
total_height = self.calculate_total_height()
self.setMinimumSize(max(self.parent().width(), total_width), total_height)
@@ -225,30 +238,82 @@ def draw_headers(self, painter):
painter.drawText(self.add_audio_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "Add Track (+)")
painter.restore()
+
+ def _format_timecode(self, seconds):
+ if abs(seconds) < 1e-9: seconds = 0
+ sign = "-" if seconds < 0 else ""
+ seconds = abs(seconds)
+
+ h = int(seconds / 3600)
+ m = int((seconds % 3600) / 60)
+ s = seconds % 60
+
+ if h > 0: return f"{sign}{h}h:{m:02d}m"
+ if m > 0 or seconds >= 59.99: return f"{sign}{m}m:{int(round(s)):02d}s"
+
+ precision = 2 if s < 1 else 1 if s < 10 else 0
+ val = f"{s:.{precision}f}"
+ # Remove trailing .0 or .00
+ if '.' in val: val = val.rstrip('0').rstrip('.')
+ return f"{sign}{val}s"
def draw_timescale(self, painter):
painter.save()
painter.setPen(QColor("#AAA"))
painter.setFont(QFont("Arial", 8))
font_metrics = QFontMetrics(painter.font())
-
+
painter.fillRect(QRect(self.HEADER_WIDTH, 0, self.width() - self.HEADER_WIDTH, self.TIMESCALE_HEIGHT), QColor("#222"))
painter.drawLine(self.HEADER_WIDTH, self.TIMESCALE_HEIGHT - 1, self.width(), self.TIMESCALE_HEIGHT - 1)
- max_time_to_draw = self.x_to_sec(self.width())
- major_interval_sec = 5
- minor_interval_sec = 1
+ frame_dur = 1.0 / self.project_fps
+ intervals = [
+ frame_dur, 2*frame_dur, 5*frame_dur, 10*frame_dur,
+ 0.1, 0.2, 0.5, 1, 2, 5, 10, 15, 30,
+ 60, 120, 300, 600, 900, 1800,
+ 3600, 2*3600, 5*3600, 10*3600
+ ]
+
+ min_pixel_dist = 70
+ major_interval = next((i for i in intervals if i * self.pixels_per_second > min_pixel_dist), intervals[-1])
- for t_sec in range(int(max_time_to_draw) + 2):
+ minor_interval = 0
+ for divisor in [5, 4, 2]:
+ if (major_interval / divisor) * self.pixels_per_second > 10:
+ minor_interval = major_interval / divisor
+ break
+
+ start_sec = self.x_to_sec(self.HEADER_WIDTH)
+ end_sec = self.x_to_sec(self.width())
+
+ def draw_ticks(interval, height):
+ if interval <= 1e-6: return
+ start_tick_num = int(start_sec / interval)
+ end_tick_num = int(end_sec / interval) + 1
+ for i in range(start_tick_num, end_tick_num + 1):
+ t_sec = i * interval
+ x = self.sec_to_x(t_sec)
+ if x > self.width(): break
+ if x >= self.HEADER_WIDTH:
+ painter.drawLine(x, self.TIMESCALE_HEIGHT - height, x, self.TIMESCALE_HEIGHT)
+
+ if frame_dur * self.pixels_per_second > 4:
+ draw_ticks(frame_dur, 3)
+ if minor_interval > 0:
+ draw_ticks(minor_interval, 6)
+
+ start_major_tick = int(start_sec / major_interval)
+ end_major_tick = int(end_sec / major_interval) + 1
+ for i in range(start_major_tick, end_major_tick + 1):
+ t_sec = i * major_interval
x = self.sec_to_x(t_sec)
-
- if t_sec % major_interval_sec == 0:
- painter.drawLine(x, self.TIMESCALE_HEIGHT - 10, x, self.TIMESCALE_HEIGHT - 1)
- label = f"{t_sec}s"
+ if x > self.width() + 50: break
+ if x >= self.HEADER_WIDTH - 50:
+ painter.drawLine(x, self.TIMESCALE_HEIGHT - 12, x, self.TIMESCALE_HEIGHT)
+ label = self._format_timecode(t_sec)
label_width = font_metrics.horizontalAdvance(label)
- painter.drawText(x - label_width // 2, self.TIMESCALE_HEIGHT - 12, label)
- elif t_sec % minor_interval_sec == 0:
- painter.drawLine(x, self.TIMESCALE_HEIGHT - 5, x, self.TIMESCALE_HEIGHT - 1)
+ painter.drawText(x - label_width // 2, self.TIMESCALE_HEIGHT - 14, label)
+
painter.restore()
def get_clip_rect(self, clip):
@@ -260,7 +325,7 @@ def get_clip_rect(self, clip):
y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT
x = self.sec_to_x(clip.timeline_start_sec)
- w = int(clip.duration_sec * self.PIXELS_PER_SECOND)
+ w = int(clip.duration_sec * self.pixels_per_second)
clip_height = self.TRACK_HEIGHT - 10
y += (self.TRACK_HEIGHT - clip_height) / 2
return QRectF(x, y, w, clip_height)
@@ -323,7 +388,7 @@ def draw_tracks_and_clips(self, painter):
def draw_selections(self, painter):
for start_sec, end_sec in self.selection_regions:
x = self.sec_to_x(start_sec)
- w = int((end_sec - start_sec) * self.PIXELS_PER_SECOND)
+ w = int((end_sec - start_sec) * self.pixels_per_second)
selection_rect = QRectF(x, self.TIMESCALE_HEIGHT, w, self.height() - self.TIMESCALE_HEIGHT)
painter.fillRect(selection_rect, QColor(100, 100, 255, 80))
painter.setPen(QColor(150, 150, 255, 150))
@@ -371,6 +436,46 @@ def get_region_at_pos(self, pos: QPoint):
return region
return None
+ def wheelEvent(self, event: QMouseEvent):
+ if not self.scroll_area or event.position().x() < self.HEADER_WIDTH:
+ event.ignore()
+ return
+
+ scrollbar = self.scroll_area.horizontalScrollBar()
+ mouse_x_abs = event.position().x()
+
+ time_at_cursor = self.x_to_sec(mouse_x_abs)
+
+ delta = event.angleDelta().y()
+ zoom_factor = 1.15
+ old_pps = self.pixels_per_second
+
+ if delta > 0: new_pps = old_pps * zoom_factor
+ else: new_pps = old_pps / zoom_factor
+
+ min_pps = 1 / (3600 * 10) # Min zoom of 1px per 10 hours
+ new_pps = max(min_pps, min(new_pps, self.max_pixels_per_second))
+
+ if abs(new_pps - old_pps) < 1e-9:
+ return
+
+ self.pixels_per_second = new_pps
+ # This will trigger a paintEvent, which resizes the widget.
+ # The scroll area will update its scrollbar ranges in response.
+ self.update()
+
+ # The new absolute x-coordinate for the time under the cursor
+ new_mouse_x_abs = self.sec_to_x(time_at_cursor)
+
+ # The amount the content "shifted" at the cursor's location
+ shift_amount = new_mouse_x_abs - mouse_x_abs
+
+ # Adjust the scrollbar by this shift amount to keep the content under the cursor
+ new_scroll_value = scrollbar.value() + shift_amount
+ scrollbar.setValue(int(new_scroll_value))
+
+ event.accept()
+
def mousePressEvent(self, event: QMouseEvent):
if event.button() == Qt.MouseButton.LeftButton:
if event.pos().x() < self.HEADER_WIDTH:
@@ -435,7 +540,7 @@ def mouseMoveEvent(self, event: QMouseEvent):
if self.dragging_selection_region:
delta_x = event.pos().x() - self.drag_start_pos.x()
- time_delta = delta_x / self.PIXELS_PER_SECOND
+ time_delta = delta_x / self.pixels_per_second
original_start, original_end = self.drag_selection_start_values
duration = original_end - original_start
@@ -473,7 +578,7 @@ def mouseMoveEvent(self, event: QMouseEvent):
self.dragging_clip.track_index = new_track_index
delta_x = event.pos().x() - self.drag_start_pos.x()
- time_delta = delta_x / self.PIXELS_PER_SECOND
+ time_delta = delta_x / self.pixels_per_second
new_start_time = original_start_sec + time_delta
for other_clip in self.timeline.clips:
@@ -504,7 +609,7 @@ def mouseReleaseEvent(self, event: QMouseEvent):
self.creating_selection_region = False
if self.selection_regions:
start, end = self.selection_regions[-1]
- if (end - start) * self.PIXELS_PER_SECOND < 2:
+ if (end - start) * self.pixels_per_second < 2:
self.selection_regions.pop()
if self.dragging_selection_region:
@@ -578,7 +683,7 @@ def dragMoveEvent(self, event):
self.highlighted_ghost_track_info = None
self.highlighted_track_info = track_info
- width = int(duration * self.PIXELS_PER_SECOND)
+ width = int(duration * self.pixels_per_second)
x = self.sec_to_x(start_sec)
if track_type == 'video':
@@ -864,9 +969,11 @@ def _setup_ui(self):
self.preview_widget.setStyleSheet("background-color: black; color: white;")
self.splitter.addWidget(self.preview_widget)
- self.timeline_widget = TimelineWidget(self.timeline, self.settings, self)
+ self.timeline_widget = TimelineWidget(self.timeline, self.settings, self.project_fps, self)
self.timeline_scroll_area = QScrollArea()
- self.timeline_scroll_area.setWidgetResizable(True)
+ self.timeline_widget.scroll_area = self.timeline_scroll_area
+ self.timeline_scroll_area.setWidgetResizable(False)
+ self.timeline_widget.setMinimumWidth(2000)
self.timeline_scroll_area.setWidget(self.timeline_widget)
self.timeline_scroll_area.setFrameShape(QFrame.Shape.NoFrame)
self.timeline_scroll_area.setMinimumHeight(250)
@@ -1129,7 +1236,9 @@ def _set_project_properties_from_clip(self, source_path):
self.project_width = int(video_stream['width']); self.project_height = int(video_stream['height'])
if 'r_frame_rate' in video_stream and video_stream['r_frame_rate'] != '0/0':
num, den = map(int, video_stream['r_frame_rate'].split('/'))
- if den > 0: self.project_fps = num / den
+ if den > 0:
+ self.project_fps = num / den
+ self.timeline_widget.set_project_fps(self.project_fps)
print(f"Project properties set: {self.project_width}x{self.project_height} @ {self.project_fps:.2f} FPS")
return True
except Exception as e: print(f"Could not probe for project properties: {e}")
@@ -1260,7 +1369,10 @@ def open_settings_dialog(self):
def new_project(self):
self.timeline.clips.clear(); self.timeline.num_video_tracks = 1; self.timeline.num_audio_tracks = 1
self.media_pool.clear(); self.media_properties.clear(); self.project_media_widget.clear_list()
- self.current_project_path = None; self.stop_playback(); self.timeline_widget.update()
+ self.current_project_path = None; self.stop_playback()
+ self.project_fps = 25.0
+ self.timeline_widget.set_project_fps(self.project_fps)
+ self.timeline_widget.update()
self.undo_stack = UndoStack()
self.undo_stack.history_changed.connect(self.update_undo_redo_actions)
self.update_undo_redo_actions()
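Note: the wheel-zoom handler above keeps the timeline time under the cursor fixed by recomputing its x position at the new zoom level and shifting the scrollbar by the difference. A minimal sketch of that arithmetic, assuming `x_to_sec`/`sec_to_x` are just a fixed header-width offset plus a pixels-per-second scale (the names and constant below are illustrative, not the widget's API):

```python
# Sketch of the keep-time-under-cursor zoom math; names are illustrative.
HEADER_WIDTH = 120  # assumed left header width in widget pixels

def zoom_scroll_adjust(mouse_x_widget: float, scroll_value: int,
                       old_pps: float, new_pps: float) -> int:
    """Return the scrollbar value that keeps the same timeline time under the
    cursor when pixels-per-second changes from old_pps to new_pps."""
    time_at_cursor = (mouse_x_widget - HEADER_WIDTH) / old_pps   # x_to_sec
    new_x = HEADER_WIDTH + time_at_cursor * new_pps              # sec_to_x at new zoom
    shift = new_x - mouse_x_widget                               # content shift at the cursor
    return int(scroll_value + shift)

# Example: cursor at x=620, scrollbar at 300, zooming in from 50 to 57.5 px/s;
# the content moves right under the cursor, so the scrollbar advances to match.
print(zoom_scroll_adjust(620, 300, 50.0, 57.5))  # -> 375
```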
From 8b3552f9d85f9ecb2eaf243310762126bcf3277f Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Fri, 10 Oct 2025 12:11:29 +0200
Subject: [PATCH 096/155] fixed bugs and added Depth control for Qwen
---
models/qwen/qwen_handler.py | 2 +-
models/wan/wan_handler.py | 2 +-
wgp.py | 6 ++++--
3 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py
index 2a3c8487a..0505a5768 100644
--- a/models/qwen/qwen_handler.py
+++ b/models/qwen/qwen_handler.py
@@ -46,7 +46,7 @@ def query_model_def(base_model_type, model_def):
if base_model_type in ["qwen_image_edit_plus_20B"]:
extra_model_def["guide_preprocessing"] = {
- "selection": ["", "PV", "SV", "CV"],
+ "selection": ["", "PV", "DV", "SV", "CV"],
}
extra_model_def["mask_preprocessing"] = {
diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py
index bb6745189..f70e376b4 100644
--- a/models/wan/wan_handler.py
+++ b/models/wan/wan_handler.py
@@ -396,7 +396,7 @@ def query_model_def(base_model_type, model_def):
"visible" : False,
}
- if vace_class or base_model_type in ["animate"]:
+ if vace_class or base_model_type in ["animate", "t2v", "t2v_2_2"] :
image_prompt_types_allowed = "TVL"
elif base_model_type in ["infinitetalk"]:
image_prompt_types_allowed = "TSVL"
diff --git a/wgp.py b/wgp.py
index 3bc3e54d2..43e61bad5 100644
--- a/wgp.py
+++ b/wgp.py
@@ -2646,8 +2646,9 @@ def computeList(filename):
preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True)
for url in preload_URLs:
- filename = fl.get_download_location(url.split("/")[-1])
- if not os.path.isfile(filename ):
+ filename = fl.locate_file(os.path.basename(url))
+ if filename is None:
+ filename = fl.get_download_location(os.path.basename(url))
if not url.startswith("http"):
raise Exception(f"File '{filename}' to preload was not found locally and no URL was provided to download it. Please add a URL in the model definition file.")
try:
@@ -5206,6 +5207,7 @@ def remove_temp_filenames(temp_filenames_list):
if any_mask: save_video( src_mask2, "masks2.mp4", fps, value_range=(0, 1))
if video_guide is not None:
preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame)
+ preview_frame_no = min(src_video.shape[1] -1, preview_frame_no)
refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no)
if src_video2 is not None:
refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)]
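Note: the wgp.py hunk above switches preload resolution from "assume the default download location exists" to "look for an existing copy first, then fall back". A minimal sketch of that pattern, with the two lookups passed in as callables standing in for the repo's `fl.locate_file` / `fl.get_download_location` helpers:

```python
import os

def resolve_preload_path(url: str, locate_file, get_download_location) -> str:
    """Prefer an already-present copy of the file named in the URL; otherwise
    return the default download location where it would be fetched to."""
    name = os.path.basename(url)
    existing = locate_file(name)           # returns a path or None
    if existing is not None:
        return existing
    return get_download_location(name)     # may not exist yet; caller downloads it

# Usage sketch with dummy lookups:
path = resolve_preload_path(
    "https://example.com/models/weights.safetensors",
    locate_file=lambda n: None,
    get_download_location=lambda n: os.path.join("ckpts", n),
)
print(path)  # e.g. ckpts/weights.safetensors
```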
From e3e0c0325513cc6dba3fb6667486cbf25fba2011 Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Fri, 10 Oct 2025 13:02:43 +0200
Subject: [PATCH 097/155] fixed ref images ignored with multiple samples
---
models/wan/modules/model.py | 1 -
wgp.py | 5 +++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py
index 8c33a17a7..4f6d63e36 100644
--- a/models/wan/modules/model.py
+++ b/models/wan/modules/model.py
@@ -1339,7 +1339,6 @@ def forward(
):
# patch_dtype = self.patch_embedding.weight.dtype
modulation_dtype = self.time_projection[1].weight.dtype
-
if self.model_type == 'i2v':
assert clip_fea is not None and y is not None
# params
diff --git a/wgp.py b/wgp.py
index 43e61bad5..74af14e4a 100644
--- a/wgp.py
+++ b/wgp.py
@@ -4846,6 +4846,7 @@ def remove_temp_filenames(temp_filenames_list):
trans.cache = skip_steps_cache
if trans2 is not None: trans2.cache = skip_steps_cache
face_arc_embeds = None
+ src_ref_images = src_ref_masks = None
output_new_audio_data = None
output_new_audio_filepath = None
original_audio_guide = audio_guide
@@ -4940,7 +4941,7 @@ def remove_temp_filenames(temp_filenames_list):
first_window_video_length = current_video_length
original_prompts = prompts.copy()
- gen["sliding_window"] = sliding_window
+ gen["sliding_window"] = sliding_window
while not abort:
extra_generation += gen.get("extra_orders",0)
gen["extra_orders"] = 0
@@ -4950,7 +4951,7 @@ def remove_temp_filenames(temp_filenames_list):
if repeat_no >= total_generation: break
repeat_no +=1
gen["repeat_no"] = repeat_no
- src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = sparse_video_image = None
+ src_video = src_video2 = src_mask = src_mask2 = src_faces = sparse_video_image = None
prefix_video = pre_video_frame = None
source_video_overlap_frames_count = 0 # number of frames overlapped in source video for first window
source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before )
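Note: the fix above is purely about variable lifetime: `src_ref_images`/`src_ref_masks` become per-generation state initialized once before the repeat loop, instead of being wiped in the per-sample reset. A minimal sketch of that shape (the `prepare_refs` and `render_sample` helpers are hypothetical):

```python
def run_generation(total_samples: int, prepare_refs, render_sample):
    # Per-generation state: survives across repeats (the patch's change).
    src_ref_images = src_ref_masks = None
    for repeat_no in range(1, total_samples + 1):
        # Per-sample state: reset on every pass, as before.
        src_video = src_mask = None
        if src_ref_images is None:
            src_ref_images, src_ref_masks = prepare_refs()
        render_sample(repeat_no, src_ref_images, src_ref_masks, src_video, src_mask)
```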
From eacb49814b7b13845598377c41e785489670fbd9 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Fri, 10 Oct 2025 23:21:10 +1100
Subject: [PATCH 098/155] fix AI joiner to start polling for outputs during
 manual generation; stop selector sliding around
---
videoeditor/main.py | 51 +++++++++++----------
videoeditor/plugins/ai_frame_joiner/main.py | 12 ++---
2 files changed, 33 insertions(+), 30 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 6aee83fd6..0c08e7b9a 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -153,7 +153,11 @@ def paintEvent(self, event):
painter.setRenderHint(QPainter.RenderHint.Antialiasing)
painter.fillRect(self.rect(), QColor("#333"))
- self.draw_headers(painter)
+ h_offset = 0
+ if self.scroll_area and self.scroll_area.horizontalScrollBar():
+ h_offset = self.scroll_area.horizontalScrollBar().value()
+
+ self.draw_headers(painter, h_offset)
self.draw_timescale(painter)
self.draw_tracks_and_clips(painter)
self.draw_selections(painter)
@@ -178,7 +182,7 @@ def calculate_total_height(self):
audio_tracks_height = (self.timeline.num_audio_tracks + 1) * self.TRACK_HEIGHT
return self.TIMESCALE_HEIGHT + video_tracks_height + self.AUDIO_TRACKS_SEPARATOR_Y + audio_tracks_height + 20
- def draw_headers(self, painter):
+ def draw_headers(self, painter, h_offset):
painter.save()
painter.setPen(QColor("#AAA"))
header_font = QFont("Arial", 9, QFont.Weight.Bold)
@@ -187,9 +191,9 @@ def draw_headers(self, painter):
y_cursor = self.TIMESCALE_HEIGHT
rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT)
- painter.fillRect(rect, QColor("#3a3a3a"))
- painter.drawRect(rect)
- self.add_video_track_btn_rect = QRect(rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22)
+ painter.fillRect(rect.translated(h_offset, 0), QColor("#3a3a3a"))
+ painter.drawRect(rect.translated(h_offset, 0))
+ self.add_video_track_btn_rect = QRect(h_offset + rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22)
painter.setFont(button_font)
painter.fillRect(self.add_video_track_btn_rect, QColor("#454"))
painter.drawText(self.add_video_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "Add Track (+)")
@@ -199,13 +203,13 @@ def draw_headers(self, painter):
for i in range(self.timeline.num_video_tracks):
track_number = self.timeline.num_video_tracks - i
rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT)
- painter.fillRect(rect, QColor("#444"))
- painter.drawRect(rect)
+ painter.fillRect(rect.translated(h_offset, 0), QColor("#444"))
+ painter.drawRect(rect.translated(h_offset, 0))
painter.setFont(header_font)
- painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, f"Video {track_number}")
+ painter.drawText(rect.translated(h_offset, 0), Qt.AlignmentFlag.AlignCenter, f"Video {track_number}")
if track_number == self.timeline.num_video_tracks and self.timeline.num_video_tracks > 1:
- self.remove_video_track_btn_rect = QRect(rect.right() - 25, rect.top() + 5, 20, 20)
+ self.remove_video_track_btn_rect = QRect(h_offset + rect.right() - 25, rect.top() + 5, 20, 20)
painter.setFont(button_font)
painter.fillRect(self.remove_video_track_btn_rect, QColor("#833"))
painter.drawText(self.remove_video_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "-")
@@ -217,22 +221,22 @@ def draw_headers(self, painter):
for i in range(self.timeline.num_audio_tracks):
track_number = i + 1
rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT)
- painter.fillRect(rect, QColor("#444"))
- painter.drawRect(rect)
+ painter.fillRect(rect.translated(h_offset, 0), QColor("#444"))
+ painter.drawRect(rect.translated(h_offset, 0))
painter.setFont(header_font)
- painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, f"Audio {track_number}")
+ painter.drawText(rect.translated(h_offset, 0), Qt.AlignmentFlag.AlignCenter, f"Audio {track_number}")
if track_number == self.timeline.num_audio_tracks and self.timeline.num_audio_tracks > 1:
- self.remove_audio_track_btn_rect = QRect(rect.right() - 25, rect.top() + 5, 20, 20)
+ self.remove_audio_track_btn_rect = QRect(h_offset + rect.right() - 25, rect.top() + 5, 20, 20)
painter.setFont(button_font)
painter.fillRect(self.remove_audio_track_btn_rect, QColor("#833"))
painter.drawText(self.remove_audio_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "-")
y_cursor += self.TRACK_HEIGHT
rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT)
- painter.fillRect(rect, QColor("#3a3a3a"))
- painter.drawRect(rect)
- self.add_audio_track_btn_rect = QRect(rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22)
+ painter.fillRect(rect.translated(h_offset, 0), QColor("#3a3a3a"))
+ painter.drawRect(rect.translated(h_offset, 0))
+ self.add_audio_track_btn_rect = QRect(h_offset + rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22)
painter.setFont(button_font)
painter.fillRect(self.add_audio_track_btn_rect, QColor("#454"))
painter.drawText(self.add_audio_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "Add Track (+)")
@@ -477,15 +481,14 @@ def wheelEvent(self, event: QMouseEvent):
event.accept()
def mousePressEvent(self, event: QMouseEvent):
- if event.button() == Qt.MouseButton.LeftButton:
- if event.pos().x() < self.HEADER_WIDTH:
- # Click in headers area
- if self.add_video_track_btn_rect.contains(event.pos()): self.add_track.emit('video')
- elif self.remove_video_track_btn_rect.contains(event.pos()): self.remove_track.emit('video')
- elif self.add_audio_track_btn_rect.contains(event.pos()): self.add_track.emit('audio')
- elif self.remove_audio_track_btn_rect.contains(event.pos()): self.remove_track.emit('audio')
- return
+ if event.pos().x() < self.HEADER_WIDTH + self.scroll_area.horizontalScrollBar().value():
+ if self.add_video_track_btn_rect.contains(event.pos()): self.add_track.emit('video')
+ elif self.remove_video_track_btn_rect.contains(event.pos()): self.remove_track.emit('video')
+ elif self.add_audio_track_btn_rect.contains(event.pos()): self.add_track.emit('audio')
+ elif self.remove_audio_track_btn_rect.contains(event.pos()): self.remove_track.emit('audio')
+ return
+ if event.button() == Qt.MouseButton.LeftButton:
self.dragging_clip = None
self.dragging_linked_clip = None
self.dragging_playhead = False
diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py
index 56b94efe6..a7c7cd648 100644
--- a/videoeditor/plugins/ai_frame_joiner/main.py
+++ b/videoeditor/plugins/ai_frame_joiner/main.py
@@ -94,7 +94,6 @@ def set_previews(self, start_pixmap, end_pixmap):
def start_polling(self):
self.poll_timer.start()
- self.check_server_status()
def stop_polling(self):
self.poll_timer.stop()
@@ -111,9 +110,11 @@ def check_server_status(self):
try:
requests.get(f"{API_BASE_URL}/api/latest_output", timeout=1)
self.status_label.setText("Status: Connected to WanGP server.")
+ return True
except requests.exceptions.ConnectionError:
self.status_label.setText("Status: Error - Cannot connect to WanGP server.")
QMessageBox.critical(self, "Connection Error", f"Could not connect to the WanGP API at {API_BASE_URL}.\n\nPlease ensure wgptool.py is running.")
+ return False
def generate(self):
@@ -139,9 +140,9 @@ def generate(self):
if response.status_code == 200:
if payload['start_generation']:
self.status_label.setText("Status: Parameters set. Generation sent. Polling...")
- self.start_polling()
else:
- self.status_label.setText("Status: Parameters set. Waiting for manual start.")
+ self.status_label.setText("Status: Parameters set. Polling for manually started generation...")
+ self.start_polling()
else:
self.handle_api_error(response, "setting parameters")
except requests.exceptions.RequestException as e:
@@ -216,9 +217,8 @@ def setup_generator_for_region(self, region):
self._reset_state()
self.active_region = region
start_sec, end_sec = region
-
- self.client_widget.check_server_status()
-
+ if not self.client_widget.check_server_status():
+ return
start_data, w, h = self.app.get_frame_data_at_time(start_sec)
end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
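Note: with the scroll area no longer widget-resizable, the timeline widget scrolls under the viewport, so the patch pins the track headers by translating every header rect by the current horizontal scroll offset before drawing. A minimal sketch of that helper (not the widget's actual method):

```python
from PyQt6.QtCore import QRect

def pinned(rect: QRect, scroll_area) -> QRect:
    """Shift a header rect by the horizontal scroll offset so it stays glued
    to the viewport's left edge while the timeline content scrolls."""
    h_offset = 0
    if scroll_area is not None and scroll_area.horizontalScrollBar() is not None:
        h_offset = scroll_area.horizontalScrollBar().value()
    return rect.translated(h_offset, 0)

# In paintEvent this replaces direct use of the rect, e.g.:
#   painter.fillRect(pinned(rect, self.scroll_area), QColor("#444"))
```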
From 14fbfaf37b1d221407219aab14c75c0e1110d1c9 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 00:10:08 +1100
Subject: [PATCH 099/155] fix FPS conversions
---
main.py | 24 +++++----------------
videoeditor/plugins/ai_frame_joiner/main.py | 8 ++++---
2 files changed, 10 insertions(+), 22 deletions(-)
diff --git a/main.py b/main.py
index 06292b315..0dff699e6 100644
--- a/main.py
+++ b/main.py
@@ -355,7 +355,6 @@ def _api_generate(self, start_frame, end_frame, duration_sec, model_type, start_
if model_type:
self._api_set_model(model_type)
- # 2. Set frame inputs
if start_frame:
self.widgets['mode_s'].setChecked(True)
self.widgets['image_start'].setText(start_frame)
@@ -363,37 +362,24 @@ def _api_generate(self, start_frame, end_frame, duration_sec, model_type, start_
if end_frame:
self.widgets['image_end_checkbox'].setChecked(True)
self.widgets['image_end'].setText(end_frame)
-
- # 3. Calculate video length in frames based on duration and model FPS
+
if duration_sec is not None:
try:
duration = float(duration_sec)
-
- # Get base FPS by parsing the "Force FPS" dropdown's default text (e.g., "Model Default (16 fps)")
- base_fps = 16 # Fallback
+ base_fps = 16
fps_text = self.widgets['force_fps'].itemText(0)
match = re.search(r'\((\d+)\s*fps\)', fps_text)
if match:
base_fps = int(match.group(1))
- # Temporal upsampling creates more frames in post-processing, so we must account for it here.
- upsample_setting = self.widgets['temporal_upsampling'].currentData()
- multiplier = 1.0
- if upsample_setting == "rife2":
- multiplier = 2.0
- elif upsample_setting == "rife4":
- multiplier = 4.0
-
- # The number of frames the model needs to generate
- video_length_frames = int(duration * base_fps * multiplier)
+ video_length_frames = int(duration * base_fps)
self.widgets['video_length'].setValue(video_length_frames)
- print(f"API: Calculated video length: {video_length_frames} frames for {duration:.2f}s @ {base_fps*multiplier:.0f} effective FPS.")
+ print(f"API: Calculated video length: {video_length_frames} frames for {duration:.2f}s @ {base_fps} effective FPS.")
except (ValueError, TypeError) as e:
print(f"API Error: Invalid duration_sec '{duration_sec}': {e}")
-
- # 4. Conditionally start generation
+
if start_generation:
self.generate_btn.click()
print("API: Generation started.")
diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py
index a7c7cd648..1b20a3b36 100644
--- a/videoeditor/plugins/ai_frame_joiner/main.py
+++ b/videoeditor/plugins/ai_frame_joiner/main.py
@@ -272,10 +272,12 @@ def insert_generated_clip(self, video_path):
return
start_sec, end_sec = self.active_region
- duration = end_sec - start_sec
self.app.status_label.setText(f"Inserting AI clip: {os.path.basename(video_path)}")
try:
+ probe = ffmpeg.probe(video_path)
+ actual_duration = float(probe['format']['duration'])
+
for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec)
for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec)
@@ -291,7 +293,7 @@ def insert_generated_clip(self, video_path):
source_path=video_path,
timeline_start_sec=start_sec,
clip_start_sec=0,
- duration_sec=duration,
+ duration_sec=actual_duration,
video_track_index=1,
audio_track_index=None
)
@@ -301,7 +303,7 @@ def insert_generated_clip(self, video_path):
self.app.prune_empty_tracks()
except Exception as e:
- error_message = f"Error during clip insertion: {e}"
+ error_message = f"Error during clip insertion/probing: {e}"
self.app.status_label.setText(error_message)
self.client_widget.status_label.setText("Status: Failed to insert clip.")
print(error_message)
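Note: the FPS fix also stops trusting the selected region's length for the inserted clip and probes the rendered file instead, since the generated video's duration depends on the model's own frame rate. A minimal sketch of that probe using ffmpeg-python, as the plugin does:

```python
import ffmpeg  # ffmpeg-python, already used by the editor

def probed_duration_sec(video_path: str) -> float:
    """Return the container-reported duration of the rendered clip in seconds."""
    probe = ffmpeg.probe(video_path)
    return float(probe['format']['duration'])
```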
From e776c35a7b97b7a6ac93eda633f59c31f2eb2c4e Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 00:15:46 +1100
Subject: [PATCH 100/155] fix missing import
---
videoeditor/plugins/ai_frame_joiner/main.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py
index 1b20a3b36..c81b98fa3 100644
--- a/videoeditor/plugins/ai_frame_joiner/main.py
+++ b/videoeditor/plugins/ai_frame_joiner/main.py
@@ -4,6 +4,7 @@
import shutil
import requests
from pathlib import Path
+import ffmpeg
from PyQt6.QtWidgets import (
QWidget, QVBoxLayout, QHBoxLayout, QFormLayout,
From 77ca09d3eec6efe522244e0876626c1aa1f13a2f Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 00:50:18 +1100
Subject: [PATCH 101/155] add ability to generate into a new track; move
 statuses to main editor
---
videoeditor/plugins/ai_frame_joiner/main.py | 117 +++++++++++---------
1 file changed, 66 insertions(+), 51 deletions(-)
diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py
index c81b98fa3..3c8ef6157 100644
--- a/videoeditor/plugins/ai_frame_joiner/main.py
+++ b/videoeditor/plugins/ai_frame_joiner/main.py
@@ -21,6 +21,7 @@
class WgpClientWidget(QWidget):
generation_complete = pyqtSignal(str)
+ status_updated = pyqtSignal(str)
def __init__(self):
super().__init__()
@@ -60,10 +61,6 @@ def __init__(self):
self.autostart_checkbox.setChecked(True)
self.generate_button = QPushButton("Generate")
-
- self.status_label = QLabel("Status: Idle")
- self.output_label = QLabel("Latest Output File: None")
- self.output_label.setWordWrap(True)
layout.addLayout(previews_layout)
form_layout.addRow("Model Type:", self.model_input)
@@ -71,8 +68,7 @@ def __init__(self):
form_layout.addRow(self.generate_button)
layout.addLayout(form_layout)
- layout.addWidget(self.status_label)
- layout.addWidget(self.output_label)
+ layout.addStretch()
self.generate_button.clicked.connect(self.generate)
@@ -104,16 +100,16 @@ def handle_api_error(self, response, action="performing action"):
error_msg = response.json().get("error", "Unknown error")
except requests.exceptions.JSONDecodeError:
error_msg = response.text
- self.status_label.setText(f"Status: Error {action}: {error_msg}")
+ self.status_updated.emit(f"AI Joiner Error: {error_msg}")
QMessageBox.warning(self, "API Error", f"Failed while {action}.\n\nServer response:\n{error_msg}")
def check_server_status(self):
try:
requests.get(f"{API_BASE_URL}/api/latest_output", timeout=1)
- self.status_label.setText("Status: Connected to WanGP server.")
+ self.status_updated.emit("AI Joiner: Connected to WanGP server.")
return True
except requests.exceptions.ConnectionError:
- self.status_label.setText("Status: Error - Cannot connect to WanGP server.")
+ self.status_updated.emit("AI Joiner Error: Cannot connect to WanGP server.")
QMessageBox.critical(self, "Connection Error", f"Could not connect to the WanGP API at {API_BASE_URL}.\n\nPlease ensure wgptool.py is running.")
return False
@@ -135,19 +131,19 @@ def generate(self):
payload['start_generation'] = self.autostart_checkbox.isChecked()
- self.status_label.setText("Status: Sending parameters...")
+ self.status_updated.emit("AI Joiner: Sending parameters...")
try:
response = requests.post(f"{API_BASE_URL}/api/generate", json=payload)
if response.status_code == 200:
if payload['start_generation']:
- self.status_label.setText("Status: Parameters set. Generation sent. Polling...")
+ self.status_updated.emit("AI Joiner: Generation sent to server. Polling for output...")
else:
- self.status_label.setText("Status: Parameters set. Polling for manually started generation...")
+ self.status_updated.emit("AI Joiner: Parameters set. Polling for manually started generation...")
self.start_polling()
else:
self.handle_api_error(response, "setting parameters")
except requests.exceptions.RequestException as e:
- self.status_label.setText(f"Status: Connection error: {e}")
+ self.status_updated.emit(f"AI Joiner Error: Connection error: {e}")
def poll_for_output(self):
try:
@@ -159,15 +155,12 @@ def poll_for_output(self):
if latest_path and latest_path != self.last_known_output:
self.stop_polling()
self.last_known_output = latest_path
- self.output_label.setText(f"Latest Output File:\n{latest_path}")
- self.status_label.setText("Status: New output received! Inserting clip...")
+ self.status_updated.emit(f"AI Joiner: New output received! Inserting clip...")
self.generation_complete.emit(latest_path)
else:
- if "Error" not in self.status_label.text() and "waiting" not in self.status_label.text().lower():
- self.status_label.setText("Status: Polling for output...")
+ self.status_updated.emit("AI Joiner: Polling for output...")
except requests.exceptions.RequestException:
- if "Error" not in self.status_label.text():
- self.status_label.setText("Status: Polling... (Connection issue)")
+ self.status_updated.emit("AI Joiner: Polling... (Connection issue)")
class Plugin(VideoEditorPlugin):
def initialize(self):
@@ -177,7 +170,9 @@ def initialize(self):
self.dock_widget = None
self.active_region = None
self.temp_dir = None
+ self.insert_on_new_track = False
self.client_widget.generation_complete.connect(self.insert_generated_clip)
+ self.client_widget.status_updated.connect(self.update_main_status)
def enable(self):
if not self.dock_widget:
@@ -195,6 +190,9 @@ def disable(self):
self._cleanup_temp_dir()
self.client_widget.stop_polling()
+ def update_main_status(self, message):
+ self.app.status_label.setText(message)
+
def _cleanup_temp_dir(self):
if self.temp_dir and os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir)
@@ -202,9 +200,9 @@ def _cleanup_temp_dir(self):
def _reset_state(self):
self.active_region = None
+ self.insert_on_new_track = False
self._cleanup_temp_dir()
- self.client_widget.status_label.setText("Status: Idle")
- self.client_widget.output_label.setText("Latest Output File: None")
+ self.update_main_status("AI Joiner: Idle")
self.client_widget.set_previews(None, None)
def on_timeline_context_menu(self, menu, event):
@@ -212,14 +210,20 @@ def on_timeline_context_menu(self, menu, event):
if region:
menu.addSeparator()
action = menu.addAction("Join Frames With AI")
- action.triggered.connect(lambda: self.setup_generator_for_region(region))
+ action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False))
+
+ action_new_track = menu.addAction("Join Frames With AI (New Track)")
+ action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True))
- def setup_generator_for_region(self, region):
+ def setup_generator_for_region(self, region, on_new_track=False):
self._reset_state()
self.active_region = region
+ self.insert_on_new_track = on_new_track
+
start_sec, end_sec = region
if not self.client_widget.check_server_status():
return
+
start_data, w, h = self.app.get_frame_data_at_time(start_sec)
end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
@@ -258,55 +262,66 @@ def setup_generator_for_region(self, region):
self.client_widget.start_frame_input.setText(start_img_path)
self.client_widget.end_frame_input.setText(end_img_path)
- self.client_widget.status_label.setText(f"Status: Ready for region {start_sec:.2f}s - {end_sec:.2f}s")
+ self.update_main_status(f"AI Joiner: Ready for region {start_sec:.2f}s - {end_sec:.2f}s")
self.dock_widget.show()
self.dock_widget.raise_()
def insert_generated_clip(self, video_path):
if not self.active_region:
- self.client_widget.status_label.setText("Status: Error - No active region to insert into.")
+ self.update_main_status("AI Joiner Error: No active region to insert into.")
return
if not os.path.exists(video_path):
- self.client_widget.status_label.setText(f"Status: Error - Output file not found: {video_path}")
+ self.update_main_status(f"AI Joiner Error: Output file not found: {video_path}")
return
start_sec, end_sec = self.active_region
- self.app.status_label.setText(f"Inserting AI clip: {os.path.basename(video_path)}")
+ self.update_main_status(f"AI Joiner: Inserting clip {os.path.basename(video_path)}")
try:
probe = ffmpeg.probe(video_path)
actual_duration = float(probe['format']['duration'])
- for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec)
- for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec)
+ if self.insert_on_new_track:
+ self.app.add_track('video')
+ new_track_index = self.app.timeline.num_video_tracks
+ self.app._add_clip_to_timeline(
+ source_path=video_path,
+ timeline_start_sec=start_sec,
+ clip_start_sec=0,
+ duration_sec=actual_duration,
+ video_track_index=new_track_index,
+ audio_track_index=None
+ )
+ self.update_main_status("AI clip inserted successfully on new track.")
+ else:
+ for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec)
+ for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec)
+
+ clips_to_remove = [
+ c for c in self.app.timeline.clips
+ if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec
+ ]
+ for clip in clips_to_remove:
+ if clip in self.app.timeline.clips:
+ self.app.timeline.clips.remove(clip)
+
+ self.app._add_clip_to_timeline(
+ source_path=video_path,
+ timeline_start_sec=start_sec,
+ clip_start_sec=0,
+ duration_sec=actual_duration,
+ video_track_index=1,
+ audio_track_index=None
+ )
+ self.update_main_status("AI clip inserted successfully.")
- clips_to_remove = [
- c for c in self.app.timeline.clips
- if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec
- ]
- for clip in clips_to_remove:
- if clip in self.app.timeline.clips:
- self.app.timeline.clips.remove(clip)
-
- self.app._add_clip_to_timeline(
- source_path=video_path,
- timeline_start_sec=start_sec,
- clip_start_sec=0,
- duration_sec=actual_duration,
- video_track_index=1,
- audio_track_index=None
- )
-
- self.app.status_label.setText("AI clip inserted successfully.")
- self.client_widget.status_label.setText("Status: Success! Clip inserted.")
self.app.prune_empty_tracks()
except Exception as e:
- error_message = f"Error during clip insertion/probing: {e}"
- self.app.status_label.setText(error_message)
- self.client_widget.status_label.setText("Status: Failed to insert clip.")
+ error_message = f"AI Joiner Error during clip insertion/probing: {e}"
+ self.update_main_status(error_message)
print(error_message)
finally:
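Note: the refactor above replaces the client widget's private status labels with a `status_updated` signal that the plugin relays to the editor's main status bar, so all feedback lands in one place. A minimal sketch of that wiring (simplified, not the plugin's full classes):

```python
from PyQt6.QtCore import QObject, pyqtSignal

class ClientStatus(QObject):
    status_updated = pyqtSignal(str)

    def report(self, message: str) -> None:
        self.status_updated.emit(f"AI Joiner: {message}")

class EditorRelay:
    def __init__(self, status_label, client: ClientStatus):
        self.status_label = status_label
        client.status_updated.connect(self.update_main_status)

    def update_main_status(self, message: str) -> None:
        # Single sink for every plugin status message.
        self.status_label.setText(message)
```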
From 3438e8fd833392c0b6e13770685ac511f1a45b94 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 01:30:07 +1100
Subject: [PATCH 102/155] add image and audio media types
---
videoeditor/main.py | 291 ++++++++++++++------
videoeditor/plugins/ai_frame_joiner/main.py | 2 +
2 files changed, 209 insertions(+), 84 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 0c08e7b9a..49ac27b13 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -19,7 +19,7 @@
from undo import UndoStack, TimelineStateChangeCommand, MoveClipsCommand
class TimelineClip:
- def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec, track_index, track_type, group_id):
+ def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec, track_index, track_type, media_type, group_id):
self.id = str(uuid.uuid4())
self.source_path = source_path
self.timeline_start_sec = timeline_start_sec
@@ -27,6 +27,7 @@ def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec
self.duration_sec = duration_sec
self.track_index = track_index
self.track_type = track_type
+ self.media_type = media_type
self.group_id = group_id
@property
@@ -136,7 +137,6 @@ def __init__(self, timeline_model, settings, project_fps, parent=None):
self.drag_over_active = False
self.drag_over_rect = QRectF()
self.drag_over_audio_rect = QRectF()
- self.drag_has_audio = False
def set_project_fps(self, fps):
self.project_fps = fps if fps > 0 else 25.0
@@ -163,12 +163,12 @@ def paintEvent(self, event):
self.draw_selections(painter)
if self.drag_over_active:
- painter.fillRect(self.drag_over_rect, QColor(0, 255, 0, 80))
- if self.drag_has_audio:
- painter.fillRect(self.drag_over_audio_rect, QColor(0, 255, 0, 80))
painter.setPen(QColor(0, 255, 0, 150))
- painter.drawRect(self.drag_over_rect)
- if self.drag_has_audio:
+ if not self.drag_over_rect.isNull():
+ painter.fillRect(self.drag_over_rect, QColor(0, 255, 0, 80))
+ painter.drawRect(self.drag_over_rect)
+ if not self.drag_over_audio_rect.isNull():
+ painter.fillRect(self.drag_over_audio_rect, QColor(0, 255, 0, 80))
painter.drawRect(self.drag_over_audio_rect)
self.draw_playhead(painter)
@@ -376,7 +376,12 @@ def draw_tracks_and_clips(self, painter):
for clip in self.timeline.clips:
clip_rect = self.get_clip_rect(clip)
- base_color = QColor("#46A") if clip.track_type == 'video' else QColor("#48C")
+ base_color = QColor("#46A") # Default video
+ if clip.media_type == 'image':
+ base_color = QColor("#4A6") # Greenish for images
+ elif clip.track_type == 'audio':
+ base_color = QColor("#48C") # Bluish for audio
+
color = QColor("#5A9") if self.dragging_clip and self.dragging_clip.id == clip.id else base_color
painter.fillRect(clip_rect, color)
painter.setPen(QPen(QColor("#FFF"), 1))
@@ -667,47 +672,55 @@ def dragMoveEvent(self, event):
json_data = json.loads(mime_data.data('application/x-vnd.video.filepath').data().decode('utf-8'))
duration = json_data['duration']
- self.drag_has_audio = json_data['has_audio']
+ media_type = json_data['media_type']
+ has_audio = json_data['has_audio']
pos = event.position()
- track_info = self.y_to_track_info(pos.y())
start_sec = self.x_to_sec(pos.x())
+ track_info = self.y_to_track_info(pos.y())
+
+ self.drag_over_rect = QRectF()
+ self.drag_over_audio_rect = QRectF()
+ self.drag_over_active = False
+ self.highlighted_ghost_track_info = None
+ self.highlighted_track_info = None
if track_info:
self.drag_over_active = True
track_type, track_index = track_info
-
+
is_ghost_track = (track_type == 'video' and track_index > self.timeline.num_video_tracks) or \
(track_type == 'audio' and track_index > self.timeline.num_audio_tracks)
if is_ghost_track:
self.highlighted_ghost_track_info = track_info
- self.highlighted_track_info = None
else:
- self.highlighted_ghost_track_info = None
self.highlighted_track_info = track_info
width = int(duration * self.pixels_per_second)
x = self.sec_to_x(start_sec)
- if track_type == 'video':
- visual_index = self.timeline.num_video_tracks - track_index
- y = self.video_tracks_y_start + visual_index * self.TRACK_HEIGHT
- self.drag_over_rect = QRectF(x, y, width, self.TRACK_HEIGHT)
- if self.drag_has_audio:
- audio_y = self.audio_tracks_y_start # Default to track 1
- self.drag_over_audio_rect = QRectF(x, audio_y, width, self.TRACK_HEIGHT)
+ video_y, audio_y = -1, -1
+
+ if media_type in ['video', 'image']:
+ if track_type == 'video':
+ visual_index = self.timeline.num_video_tracks - track_index
+ video_y = self.video_tracks_y_start + visual_index * self.TRACK_HEIGHT
+ if has_audio:
+ audio_y = self.audio_tracks_y_start
+ elif track_type == 'audio' and has_audio:
+ visual_index = track_index - 1
+ audio_y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT
+ video_y = self.video_tracks_y_start + (self.timeline.num_video_tracks - 1) * self.TRACK_HEIGHT
+
+ elif media_type == 'audio':
+ if track_type == 'audio':
+ visual_index = track_index - 1
+ audio_y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT
- elif track_type == 'audio':
- visual_index = track_index - 1
- y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT
- self.drag_over_audio_rect = QRectF(x, y, width, self.TRACK_HEIGHT)
- # For now, just show video on track 1 if dragging onto audio
- video_y = self.video_tracks_y_start + (self.timeline.num_video_tracks - 1) * self.TRACK_HEIGHT
+ if video_y != -1:
self.drag_over_rect = QRectF(x, video_y, width, self.TRACK_HEIGHT)
- else:
- self.drag_over_active = False
- self.highlighted_ghost_track_info = None
- self.highlighted_track_info = None
+ if audio_y != -1:
+ self.drag_over_audio_rect = QRectF(x, audio_y, width, self.TRACK_HEIGHT)
self.update()
@@ -725,6 +738,7 @@ def dropEvent(self, event):
file_path = json_data['path']
duration = json_data['duration']
has_audio = json_data['has_audio']
+ media_type = json_data['media_type']
pos = event.position()
start_sec = self.x_to_sec(pos.x())
@@ -734,19 +748,32 @@ def dropEvent(self, event):
return
drop_track_type, drop_track_index = track_info
- video_track_idx = 1
- audio_track_idx = 1 if has_audio else None
-
- if drop_track_type == 'video':
- video_track_idx = drop_track_index
- elif drop_track_type == 'audio':
- audio_track_idx = drop_track_index
+ video_track_idx = None
+ audio_track_idx = None
+
+ if media_type == 'image':
+ if drop_track_type == 'video':
+ video_track_idx = drop_track_index
+ elif media_type == 'audio':
+ if drop_track_type == 'audio':
+ audio_track_idx = drop_track_index
+ elif media_type == 'video':
+ if drop_track_type == 'video':
+ video_track_idx = drop_track_index
+ if has_audio: audio_track_idx = 1
+ elif drop_track_type == 'audio' and has_audio:
+ audio_track_idx = drop_track_index
+ video_track_idx = 1
+
+ if video_track_idx is None and audio_track_idx is None:
+ return
main_window = self.window()
main_window._add_clip_to_timeline(
source_path=file_path,
timeline_start_sec=start_sec,
duration_sec=duration,
+ media_type=media_type,
clip_start_sec=0,
video_track_index=video_track_idx,
audio_track_index=audio_track_idx
@@ -847,7 +874,8 @@ def startDrag(self, supportedActions):
payload = {
"path": path,
"duration": media_info['duration'],
- "has_audio": media_info['has_audio']
+ "has_audio": media_info['has_audio'],
+ "media_type": media_info['media_type']
}
mime_data.setData('application/x-vnd.video.filepath', QByteArray(json.dumps(payload).encode('utf-8')))
@@ -1049,7 +1077,7 @@ def _connect_signals(self):
self.frame_forward_button.clicked.connect(lambda: self.step_frame(1))
self.playback_timer.timeout.connect(self.advance_playback_frame)
- self.project_media_widget.add_media_requested.connect(self.add_video_clip)
+ self.project_media_widget.add_media_requested.connect(self.add_media_files)
self.project_media_widget.media_removed.connect(self.on_media_removed_from_pool)
self.project_media_widget.add_to_timeline_requested.connect(self.on_add_to_timeline_at_playhead)
@@ -1095,14 +1123,19 @@ def on_add_to_timeline_at_playhead(self, file_path):
playhead_pos = self.timeline_widget.playhead_pos_sec
duration = media_info['duration']
has_audio = media_info['has_audio']
+ media_type = media_info['media_type']
+
+ video_track = 1 if media_type in ['video', 'image'] else None
+ audio_track = 1 if has_audio else None
self._add_clip_to_timeline(
source_path=file_path,
timeline_start_sec=playhead_pos,
duration_sec=duration,
+ media_type=media_type,
clip_start_sec=0.0,
- video_track_index=1,
- audio_track_index=1 if has_audio else None
+ video_track_index=video_track,
+ audio_track_index=audio_track
)
def prune_empty_tracks(self):
@@ -1165,13 +1198,13 @@ def _create_menu_bar(self):
open_action = QAction("&Open Project...", self); open_action.triggered.connect(self.open_project)
self.recent_menu = file_menu.addMenu("Recent")
save_action = QAction("&Save Project As...", self); save_action.triggered.connect(self.save_project_as)
- add_video_action = QAction("&Add Media to Project...", self); add_video_action.triggered.connect(self.add_video_clip)
+ add_media_action = QAction("&Add Media to Project...", self); add_media_action.triggered.connect(self.add_media_files)
export_action = QAction("&Export Video...", self); export_action.triggered.connect(self.export_video)
settings_action = QAction("Se&ttings...", self); settings_action.triggered.connect(self.open_settings_dialog)
exit_action = QAction("E&xit", self); exit_action.triggered.connect(self.close)
file_menu.addAction(new_action); file_menu.addAction(open_action); file_menu.addSeparator()
file_menu.addAction(save_action)
- file_menu.addSeparator(); file_menu.addAction(add_video_action); file_menu.addAction(export_action)
+ file_menu.addSeparator(); file_menu.addAction(add_media_action); file_menu.addAction(export_action)
file_menu.addSeparator(); file_menu.addAction(settings_action); file_menu.addSeparator(); file_menu.addAction(exit_action)
self._update_recent_files_menu()
@@ -1252,13 +1285,24 @@ def get_frame_data_at_time(self, time_sec):
if not clip_at_time:
return (None, 0, 0)
try:
- clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec
- out, _ = (
- ffmpeg
- .input(clip_at_time.source_path, ss=clip_time)
- .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
- .run(capture_stdout=True, quiet=True)
- )
+ w, h = self.project_width, self.project_height
+ if clip_at_time.media_type == 'image':
+ out, _ = (
+ ffmpeg
+ .input(clip_at_time.source_path)
+ .filter('scale', w, h, force_original_aspect_ratio='decrease')
+ .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
+ .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
+ .run(capture_stdout=True, quiet=True)
+ )
+ else:
+ clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec
+ out, _ = (
+ ffmpeg
+ .input(clip_at_time.source_path, ss=clip_time)
+ .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
+ .run(capture_stdout=True, quiet=True)
+ )
return (out, self.project_width, self.project_height)
except ffmpeg.Error as e:
print(f"Error extracting frame data: {e.stderr}")
@@ -1269,8 +1313,19 @@ def get_frame_at_time(self, time_sec):
black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black"))
if not clip_at_time: return black_pixmap
try:
- clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec
- out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time).output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24').run(capture_stdout=True, quiet=True))
+ w, h = self.project_width, self.project_height
+ if clip_at_time.media_type == 'image':
+ out, _ = (
+ ffmpeg.input(clip_at_time.source_path)
+ .filter('scale', w, h, force_original_aspect_ratio='decrease')
+ .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
+ .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
+ .run(capture_stdout=True, quiet=True)
+ )
+ else:
+ clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec
+ out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time).output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24').run(capture_stdout=True, quiet=True))
+
image = QImage(out, self.project_width, self.project_height, QImage.Format.Format_RGB888)
return QPixmap.fromImage(image)
except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return black_pixmap
@@ -1303,14 +1358,32 @@ def advance_playback_frame(self):
frame_duration = 1.0 / self.project_fps
new_time = self.timeline_widget.playhead_pos_sec + frame_duration
if new_time > self.timeline.get_total_duration(): self.stop_playback(); return
- self.timeline_widget.playhead_pos_sec = new_time; self.timeline_widget.update()
+
+ self.timeline_widget.playhead_pos_sec = new_time
+ self.timeline_widget.update()
+
clip_at_new_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= new_time < c.timeline_end_sec), None)
+
if not clip_at_new_time:
self._stop_playback_stream()
black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black"))
scaled_pixmap = black_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
- self.preview_widget.setPixmap(scaled_pixmap); return
- if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: self._start_playback_stream_at(new_time)
+ self.preview_widget.setPixmap(scaled_pixmap)
+ return
+
+ if clip_at_new_time.media_type == 'image':
+ if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id:
+ self._stop_playback_stream()
+ self.playback_clip = clip_at_new_time
+ frame_pixmap = self.get_frame_at_time(new_time)
+ if frame_pixmap:
+ scaled_pixmap = frame_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
+ self.preview_widget.setPixmap(scaled_pixmap)
+ return
+
+ if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id:
+ self._start_playback_stream_at(new_time)
+
if self.playback_process:
frame_size = self.project_width * self.project_height * 3
frame_bytes = self.playback_process.stdout.read(frame_size)
@@ -1319,7 +1392,8 @@ def advance_playback_frame(self):
pixmap = QPixmap.fromImage(image)
scaled_pixmap = pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
self.preview_widget.setPixmap(scaled_pixmap)
- else: self._stop_playback_stream()
+ else:
+ self._stop_playback_stream()
def _load_settings(self):
self.settings_file_was_loaded = False
@@ -1386,7 +1460,7 @@ def save_project_as(self):
if not path: return
project_data = {
"media_pool": self.media_pool,
- "clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "group_id": c.group_id} for c in self.timeline.clips],
+ "clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "media_type": c.media_type, "group_id": c.group_id} for c in self.timeline.clips],
"settings": {"num_video_tracks": self.timeline.num_video_tracks, "num_audio_tracks": self.timeline.num_audio_tracks}
}
try:
@@ -1410,6 +1484,14 @@ def _load_project_from_path(self, path):
for clip_data in project_data["clips"]:
if not os.path.exists(clip_data["source_path"]):
self.status_label.setText(f"Error: Missing media file {clip_data['source_path']}"); self.new_project(); return
+
+ if 'media_type' not in clip_data:
+ ext = os.path.splitext(clip_data['source_path'])[1].lower()
+ if ext in ['.mp3', '.wav', '.m4a', '.aac']:
+ clip_data['media_type'] = 'audio'
+ else:
+ clip_data['media_type'] = 'video'
+
self.timeline.add_clip(TimelineClip(**clip_data))
project_settings = project_data.get("settings", {})
@@ -1417,7 +1499,9 @@ def _load_project_from_path(self, path):
self.timeline.num_audio_tracks = project_settings.get("num_audio_tracks", 1)
self.current_project_path = path
- if self.timeline.clips: self._set_project_properties_from_clip(self.timeline.clips[0].source_path)
+ video_clips = [c for c in self.timeline.clips if c.media_type == 'video']
+ if video_clips:
+ self._set_project_properties_from_clip(video_clips[0].source_path)
self.prune_empty_tracks()
self.timeline_widget.update(); self.stop_playback()
self.status_label.setText(f"Project '{os.path.basename(path)}' loaded.")
@@ -1445,14 +1529,31 @@ def _add_media_to_pool(self, file_path):
if file_path in self.media_pool: return True
try:
self.status_label.setText(f"Probing {os.path.basename(file_path)}..."); QApplication.processEvents()
- probe = ffmpeg.probe(file_path)
- video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
- if not video_stream: raise ValueError("No video stream found.")
- duration = float(video_stream.get('duration', probe['format'].get('duration', 0)))
- has_audio = any(s['codec_type'] == 'audio' for s in probe['streams'])
+ file_ext = os.path.splitext(file_path)[1].lower()
+ media_info = {}
- self.media_properties[file_path] = {'duration': duration, 'has_audio': has_audio}
+ if file_ext in ['.png', '.jpg', '.jpeg']:
+ media_info['media_type'] = 'image'
+ media_info['duration'] = 5.0
+ media_info['has_audio'] = False
+ else:
+ probe = ffmpeg.probe(file_path)
+ video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
+ audio_stream = next((s for s in probe['streams'] if s['codec_type'] == 'audio'), None)
+
+ if video_stream:
+ media_info['media_type'] = 'video'
+ media_info['duration'] = float(video_stream.get('duration', probe['format'].get('duration', 0)))
+ media_info['has_audio'] = audio_stream is not None
+ elif audio_stream:
+ media_info['media_type'] = 'audio'
+ media_info['duration'] = float(audio_stream.get('duration', probe['format'].get('duration', 0)))
+ media_info['has_audio'] = True
+ else:
+ raise ValueError("No video or audio stream found.")
+
+ self.media_properties[file_path] = media_info
self.media_pool.append(file_path)
self.project_media_widget.add_media_item(file_path)
@@ -1477,33 +1578,42 @@ def on_media_removed_from_pool(self, file_path):
self.undo_stack.push(command)
- def add_video_clip(self):
- file_path, _ = QFileDialog.getOpenFileName(self, "Open Video File", "", "Video Files (*.mp4 *.mov *.avi)")
- if not file_path: return
+ def add_media_files(self):
+ file_paths, _ = QFileDialog.getOpenFileNames(self, "Open Media Files", "", "All Supported Files (*.mp4 *.mov *.avi *.png *.jpg *.jpeg *.mp3 *.wav);;Video Files (*.mp4 *.mov *.avi);;Image Files (*.png *.jpg *.jpeg);;Audio Files (*.mp3 *.wav)")
+ if not file_paths: return
+ self.media_dock.show()
- if not self.timeline.clips and not self.media_pool:
- if not self._set_project_properties_from_clip(file_path):
- self.status_label.setText("Error: Could not determine video properties from file."); return
-
- if self._add_media_to_pool(file_path):
- self.status_label.setText(f"Added {os.path.basename(file_path)} to project media.")
- else:
- self.status_label.setText(f"Failed to add {os.path.basename(file_path)}.")
+ first_video_added = any(c.media_type == 'video' for c in self.timeline.clips)
+
+ for file_path in file_paths:
+ if not self.timeline.clips and not self.media_pool and not first_video_added:
+ ext = os.path.splitext(file_path)[1].lower()
+ if ext not in ['.png', '.jpg', '.jpeg', '.mp3', '.wav']:
+ if self._set_project_properties_from_clip(file_path):
+ first_video_added = True
+ else:
+ self.status_label.setText("Error: Could not determine video properties from file.");
+ continue
+
+ if self._add_media_to_pool(file_path):
+ self.status_label.setText(f"Added {os.path.basename(file_path)} to project media.")
+ else:
+ self.status_label.setText(f"Failed to add {os.path.basename(file_path)}.")
- def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, clip_start_sec=0.0, video_track_index=None, audio_track_index=None):
+ def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, media_type, clip_start_sec=0.0, video_track_index=None, audio_track_index=None):
old_state = self._get_current_timeline_state()
group_id = str(uuid.uuid4())
if video_track_index is not None:
if video_track_index > self.timeline.num_video_tracks:
self.timeline.num_video_tracks = video_track_index
- video_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, video_track_index, 'video', group_id)
+ video_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, video_track_index, 'video', media_type, group_id)
self.timeline.add_clip(video_clip)
if audio_track_index is not None:
if audio_track_index > self.timeline.num_audio_tracks:
self.timeline.num_audio_tracks = audio_track_index
- audio_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, audio_track_index, 'audio', group_id)
+ audio_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, audio_track_index, 'audio', media_type, group_id)
self.timeline.add_clip(audio_clip)
new_state = self._get_current_timeline_state()
@@ -1517,7 +1627,7 @@ def _split_at_time(self, clip_to_split, time_sec, new_group_id=None):
orig_dur = clip_to_split.duration_sec
group_id_for_new_clip = new_group_id if new_group_id is not None else clip_to_split.group_id
- new_clip = TimelineClip(clip_to_split.source_path, time_sec, clip_to_split.clip_start_sec + split_point, orig_dur - split_point, clip_to_split.track_index, clip_to_split.track_type, group_id_for_new_clip)
+ new_clip = TimelineClip(clip_to_split.source_path, time_sec, clip_to_split.clip_start_sec + split_point, orig_dur - split_point, clip_to_split.track_index, clip_to_split.track_type, clip_to_split.media_type, group_id_for_new_clip)
clip_to_split.duration_sec = split_point
self.timeline.add_clip(new_clip)
return True
@@ -1707,8 +1817,8 @@ def export_video(self):
w, h, fr_str, total_dur = self.project_width, self.project_height, str(self.project_fps), self.timeline.get_total_duration()
sample_rate, channel_layout = '44100', 'stereo'
- input_files = {clip.source_path: ffmpeg.input(clip.source_path) for clip in self.timeline.clips}
-
+ input_streams = {}
+
last_video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur}', f='lavfi')
for i in range(self.timeline.num_video_tracks):
track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'video' and c.track_index == i + 1], key=lambda c: c.timeline_start_sec)
@@ -1722,8 +1832,18 @@ def export_video(self):
if gap > 0.01:
track_segments.append(ffmpeg.input(f'color=c=black@0.0:s={w}x{h}:r={fr_str}:d={gap}', f='lavfi').filter('format', pix_fmts='rgba'))
- v_seg = (input_files[clip.source_path].video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS')
- .filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black').filter('format', pix_fmts='rgba'))
+ if clip.source_path not in input_streams:
+ if clip.media_type == 'image':
+ input_streams[clip.source_path] = ffmpeg.input(clip.source_path, loop=1, framerate=self.project_fps)
+ else:
+ input_streams[clip.source_path] = ffmpeg.input(clip.source_path)
+
+ if clip.media_type == 'image':
+ v_seg = (input_streams[clip.source_path].video.trim(duration=clip.duration_sec).setpts('PTS-STARTPTS'))
+ else: # video
+ v_seg = (input_streams[clip.source_path].video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS'))
+
+ v_seg = (v_seg.filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black').filter('format', pix_fmts='rgba'))
track_segments.append(v_seg)
last_end = clip.timeline_end_sec
@@ -1740,11 +1860,14 @@ def export_video(self):
last_end = track_clips[0].timeline_start_sec
for clip in track_clips:
+ if clip.source_path not in input_streams:
+ input_streams[clip.source_path] = ffmpeg.input(clip.source_path)
+
gap = clip.timeline_start_sec - last_end
if gap > 0.01:
track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap}', f='lavfi'))
- a_seg = input_files[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS')
+ a_seg = input_streams[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS')
track_segments.append(a_seg)
last_end = clip.timeline_end_sec
diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py
index 3c8ef6157..d66dda353 100644
--- a/videoeditor/plugins/ai_frame_joiner/main.py
+++ b/videoeditor/plugins/ai_frame_joiner/main.py
@@ -291,6 +291,7 @@ def insert_generated_clip(self, video_path):
timeline_start_sec=start_sec,
clip_start_sec=0,
duration_sec=actual_duration,
+ media_type='video',
video_track_index=new_track_index,
audio_track_index=None
)
@@ -312,6 +313,7 @@ def insert_generated_clip(self, video_path):
timeline_start_sec=start_sec,
clip_start_sec=0,
duration_sec=actual_duration,
+ media_type='video',
video_track_index=1,
audio_track_index=None
)
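Note: the media-type support above hinges on a single classification step: image extensions get a fixed 5-second default duration, and everything else is classified from its ffprobe streams. A minimal standalone sketch mirroring the `_add_media_to_pool` hunk (the function name is illustrative):

```python
import os
import ffmpeg

IMAGE_EXTS = {'.png', '.jpg', '.jpeg'}

def classify_media(file_path: str) -> dict:
    """Return {'media_type', 'duration', 'has_audio'} for a media file."""
    ext = os.path.splitext(file_path)[1].lower()
    if ext in IMAGE_EXTS:
        # Still images have no intrinsic duration; the patch defaults to 5 s.
        return {'media_type': 'image', 'duration': 5.0, 'has_audio': False}
    probe = ffmpeg.probe(file_path)
    video = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
    audio = next((s for s in probe['streams'] if s['codec_type'] == 'audio'), None)
    if video:
        duration = float(video.get('duration', probe['format'].get('duration', 0)))
        return {'media_type': 'video', 'duration': duration, 'has_audio': audio is not None}
    if audio:
        duration = float(audio.get('duration', probe['format'].get('duration', 0)))
        return {'media_type': 'audio', 'duration': duration, 'has_audio': True}
    raise ValueError("No video or audio stream found.")
```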
From a5b51dfb1207ebd64f04f96e61d569029aa23294 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 02:10:56 +1100
Subject: [PATCH 103/155] add ability to resize clips
---
videoeditor/main.py | 123 ++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 120 insertions(+), 3 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 49ac27b13..3ce1bccde 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -11,7 +11,7 @@
QScrollArea, QFrame, QProgressBar, QDialog,
QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget, QListWidget, QListWidgetItem, QMessageBox)
from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction,
- QPixmap, QImage, QDrag)
+ QPixmap, QImage, QDrag, QCursor)
from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread,
pyqtSignal, QTimer, QByteArray, QMimeData)
@@ -82,6 +82,7 @@ class TimelineWidget(QWidget):
HEADER_WIDTH = 120
TRACK_HEIGHT = 50
AUDIO_TRACKS_SEPARATOR_Y = 15
+ RESIZE_HANDLE_WIDTH = 8
split_requested = pyqtSignal(object)
delete_clip_requested = pyqtSignal(object)
@@ -124,6 +125,10 @@ def __init__(self, timeline_model, settings, project_fps, parent=None):
self.drag_selection_start_values = None
self.drag_start_state = None
+ self.resizing_clip = None
+ self.resize_edge = None # 'left' or 'right'
+ self.resize_start_pos = QPoint()
+
self.highlighted_track_info = None
self.highlighted_ghost_track_info = None
self.add_video_track_btn_rect = QRect()
@@ -499,8 +504,28 @@ def mousePressEvent(self, event: QMouseEvent):
self.dragging_playhead = False
self.dragging_selection_region = None
self.creating_selection_region = False
+ self.resizing_clip = None
+ self.resize_edge = None
self.drag_original_clip_states.clear()
+ # Check for resize handles first
+ for clip in reversed(self.timeline.clips):
+ clip_rect = self.get_clip_rect(clip)
+ if abs(event.pos().x() - clip_rect.left()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.left(), event.pos().y())):
+ self.resizing_clip = clip
+ self.resize_edge = 'left'
+ break
+ elif abs(event.pos().x() - clip_rect.right()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.right(), event.pos().y())):
+ self.resizing_clip = clip
+ self.resize_edge = 'right'
+ break
+
+ if self.resizing_clip:
+ self.drag_start_state = self.window()._get_current_timeline_state()
+ self.resize_start_pos = event.pos()
+ self.update()
+ return
+
for clip in reversed(self.timeline.clips):
clip_rect = self.get_clip_rect(clip)
if clip_rect.contains(QPointF(event.pos())):
@@ -538,6 +563,77 @@ def mousePressEvent(self, event: QMouseEvent):
self.update()
def mouseMoveEvent(self, event: QMouseEvent):
+ if self.resizing_clip:
+ delta_x = event.pos().x() - self.resize_start_pos.x()
+ time_delta = delta_x / self.pixels_per_second
+ min_duration = 1.0 / self.project_fps
+
+ media_props = self.window().media_properties.get(self.resizing_clip.source_path)
+ source_duration = media_props['duration'] if media_props else float('inf')
+
+ if self.resize_edge == 'left':
+ original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_sec
+ original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec
+ original_clip_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].clip_start_sec
+
+ new_start_sec = original_start + time_delta
+
+ # Clamp to not move past the end of the clip
+ if new_start_sec > original_start + original_duration - min_duration:
+ new_start_sec = original_start + original_duration - min_duration
+
+ # Clamp to zero
+ new_start_sec = max(0, new_start_sec)
+
+ # For video/audio, don't allow extending beyond the source start
+ if self.resizing_clip.media_type != 'image':
+ if new_start_sec < original_start - original_clip_start:
+ new_start_sec = original_start - original_clip_start
+
+ new_duration = (original_start + original_duration) - new_start_sec
+ # FIX: Correctly calculate the new clip start by ADDING the time shift
+ new_clip_start = original_clip_start + (new_start_sec - original_start)
+
+ if new_duration < min_duration:
+ new_duration = min_duration
+ new_start_sec = (original_start + original_duration) - new_duration
+ # FIX: Apply the same corrected logic here after clamping duration
+ new_clip_start = original_clip_start + (new_start_sec - original_start)
+
+ self.resizing_clip.timeline_start_sec = new_start_sec
+ self.resizing_clip.duration_sec = new_duration
+ self.resizing_clip.clip_start_sec = new_clip_start
+
+ elif self.resize_edge == 'right':
+ original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec
+ new_duration = original_duration + time_delta
+
+ if new_duration < min_duration:
+ new_duration = min_duration
+
+ # For video/audio, don't allow extending beyond the source end
+ if self.resizing_clip.media_type != 'image':
+ if self.resizing_clip.clip_start_sec + new_duration > source_duration:
+ new_duration = source_duration - self.resizing_clip.clip_start_sec
+
+ self.resizing_clip.duration_sec = new_duration
+
+ self.update()
+ return
+
+ # Update cursor for resizing
+ if not self.dragging_clip and not self.dragging_playhead and not self.creating_selection_region:
+ cursor_set = False
+ for clip in self.timeline.clips:
+ clip_rect = self.get_clip_rect(clip)
+ if (abs(event.pos().x() - clip_rect.left()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.left(), event.pos().y()))) or \
+ (abs(event.pos().x() - clip_rect.right()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.right(), event.pos().y()))):
+ self.setCursor(Qt.CursorShape.SizeHorCursor)
+ cursor_set = True
+ break
+ if not cursor_set:
+ self.unsetCursor()
+
if self.creating_selection_region:
current_sec = self.x_to_sec(event.pos().x())
start = min(self.selection_drag_start_sec, current_sec)
@@ -613,6 +709,17 @@ def mouseMoveEvent(self, event: QMouseEvent):
def mouseReleaseEvent(self, event: QMouseEvent):
if event.button() == Qt.MouseButton.LeftButton:
+ if self.resizing_clip:
+ new_state = self.window()._get_current_timeline_state()
+ command = TimelineStateChangeCommand("Resize Clip", self.timeline, *self.drag_start_state, *new_state)
+ command.undo()
+ self.window().undo_stack.push(command)
+ self.resizing_clip = None
+ self.resize_edge = None
+ self.drag_start_state = None
+ self.update()
+ return
+
if self.creating_selection_region:
self.creating_selection_region = False
if self.selection_regions:
@@ -1249,8 +1356,12 @@ def _start_playback_stream_at(self, time_sec):
if not clip: return
self.playback_clip = clip
clip_time = time_sec - clip.timeline_start_sec + clip.clip_start_sec
+ w, h = self.project_width, self.project_height
try:
- args = (ffmpeg.input(self.playback_clip.source_path, ss=clip_time).output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile())
+ args = (ffmpeg.input(self.playback_clip.source_path, ss=clip_time)
+ .filter('scale', w, h, force_original_aspect_ratio='decrease')
+ .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
+ .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile())
self.playback_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
except Exception as e:
print(f"Failed to start playback stream: {e}"); self._stop_playback_stream()
@@ -1300,6 +1411,8 @@ def get_frame_data_at_time(self, time_sec):
out, _ = (
ffmpeg
.input(clip_at_time.source_path, ss=clip_time)
+ .filter('scale', w, h, force_original_aspect_ratio='decrease')
+ .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
.output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
.run(capture_stdout=True, quiet=True)
)
@@ -1324,7 +1437,11 @@ def get_frame_at_time(self, time_sec):
)
else:
clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec
- out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time).output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24').run(capture_stdout=True, quiet=True))
+ out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time)
+ .filter('scale', w, h, force_original_aspect_ratio='decrease')
+ .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
+ .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
+ .run(capture_stdout=True, quiet=True))
image = QImage(out, self.project_width, self.project_height, QImage.Format.Format_RGB888)
return QPixmap.fromImage(image)
From 25bcc33827d44f00070ab945fac9715e550cf228 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 02:37:49 +1100
Subject: [PATCH 104/155] add playhead snapping for all functions
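Every snap added in this patch boils down to the same test: convert the 8-pixel threshold into seconds at the current zoom and compare the candidate time against the playhead. Roughly (illustrative helper; the patch inlines this in the mouse handlers):

```python
def snap_to_playhead(candidate_sec: float, playhead_sec: float,
                     pixels_per_second: float, threshold_px: int = 8) -> float:
    """Snap candidate_sec to the playhead when it lands within the pixel threshold."""
    snap_window_sec = threshold_px / pixels_per_second  # pixel distance -> seconds at this zoom
    return playhead_sec if abs(candidate_sec - playhead_sec) < snap_window_sec else candidate_sec
```

Clip drags check both edges of the dragged clip against the playhead and snap whichever one falls inside the window.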
---
videoeditor/main.py | 87 +++++++++++++++++++++++++++++++--------------
1 file changed, 61 insertions(+), 26 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 3ce1bccde..b3d3828e4 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -83,6 +83,7 @@ class TimelineWidget(QWidget):
TRACK_HEIGHT = 50
AUDIO_TRACKS_SEPARATOR_Y = 15
RESIZE_HANDLE_WIDTH = 8
+ SNAP_THRESHOLD_PIXELS = 8
split_requested = pyqtSignal(object)
delete_clip_requested = pyqtSignal(object)
@@ -553,7 +554,11 @@ def mousePressEvent(self, event: QMouseEvent):
if is_in_track_area:
self.creating_selection_region = True
- self.selection_drag_start_sec = self.x_to_sec(event.pos().x())
+ playhead_x = self.sec_to_x(self.playhead_pos_sec)
+ if abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS:
+ self.selection_drag_start_sec = self.playhead_pos_sec
+ else:
+ self.selection_drag_start_sec = self.x_to_sec(event.pos().x())
self.selection_regions.append([self.selection_drag_start_sec, self.selection_drag_start_sec])
elif is_on_timescale:
self.playhead_pos_sec = max(0, self.x_to_sec(event.pos().x()))
@@ -567,6 +572,8 @@ def mouseMoveEvent(self, event: QMouseEvent):
delta_x = event.pos().x() - self.resize_start_pos.x()
time_delta = delta_x / self.pixels_per_second
min_duration = 1.0 / self.project_fps
+ playhead_time = self.playhead_pos_sec
+ snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_second
media_props = self.window().media_properties.get(self.resizing_clip.source_path)
source_duration = media_props['duration'] if media_props else float('inf')
@@ -575,29 +582,27 @@ def mouseMoveEvent(self, event: QMouseEvent):
original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_sec
original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec
original_clip_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].clip_start_sec
-
- new_start_sec = original_start + time_delta
-
- # Clamp to not move past the end of the clip
+ true_new_start_sec = original_start + time_delta
+ if abs(true_new_start_sec - playhead_time) < snap_time_delta:
+ new_start_sec = playhead_time
+ else:
+ new_start_sec = true_new_start_sec
+
if new_start_sec > original_start + original_duration - min_duration:
new_start_sec = original_start + original_duration - min_duration
-
- # Clamp to zero
+
new_start_sec = max(0, new_start_sec)
- # For video/audio, don't allow extending beyond the source start
if self.resizing_clip.media_type != 'image':
if new_start_sec < original_start - original_clip_start:
new_start_sec = original_start - original_clip_start
new_duration = (original_start + original_duration) - new_start_sec
- # FIX: Correctly calculate the new clip start by ADDING the time shift
new_clip_start = original_clip_start + (new_start_sec - original_start)
if new_duration < min_duration:
new_duration = min_duration
new_start_sec = (original_start + original_duration) - new_duration
- # FIX: Apply the same corrected logic here after clamping duration
new_clip_start = original_clip_start + (new_start_sec - original_start)
self.resizing_clip.timeline_start_sec = new_start_sec
@@ -605,13 +610,20 @@ def mouseMoveEvent(self, event: QMouseEvent):
self.resizing_clip.clip_start_sec = new_clip_start
elif self.resize_edge == 'right':
+ original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_sec
original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec
- new_duration = original_duration + time_delta
+
+ true_new_duration = original_duration + time_delta
+ true_new_end_time = original_start + true_new_duration
+
+ if abs(true_new_end_time - playhead_time) < snap_time_delta:
+ new_duration = playhead_time - original_start
+ else:
+ new_duration = true_new_duration
if new_duration < min_duration:
new_duration = min_duration
-
- # For video/audio, don't allow extending beyond the source end
+
if self.resizing_clip.media_type != 'image':
if self.resizing_clip.clip_start_sec + new_duration > source_duration:
new_duration = source_duration - self.resizing_clip.clip_start_sec
@@ -620,17 +632,26 @@ def mouseMoveEvent(self, event: QMouseEvent):
self.update()
return
-
- # Update cursor for resizing
+
if not self.dragging_clip and not self.dragging_playhead and not self.creating_selection_region:
cursor_set = False
- for clip in self.timeline.clips:
- clip_rect = self.get_clip_rect(clip)
- if (abs(event.pos().x() - clip_rect.left()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.left(), event.pos().y()))) or \
- (abs(event.pos().x() - clip_rect.right()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.right(), event.pos().y()))):
- self.setCursor(Qt.CursorShape.SizeHorCursor)
- cursor_set = True
- break
+ ### ADD BEGIN ###
+ # Check for playhead proximity for selection snapping
+ playhead_x = self.sec_to_x(self.playhead_pos_sec)
+ is_in_track_area = event.pos().y() > self.TIMESCALE_HEIGHT and event.pos().x() > self.HEADER_WIDTH
+ if is_in_track_area and abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS:
+ self.setCursor(Qt.CursorShape.SizeHorCursor)
+ cursor_set = True
+
+ if not cursor_set:
+ ### ADD END ###
+ for clip in self.timeline.clips:
+ clip_rect = self.get_clip_rect(clip)
+ if (abs(event.pos().x() - clip_rect.left()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.left(), event.pos().y()))) or \
+ (abs(event.pos().x() - clip_rect.right()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.right(), event.pos().y()))):
+ self.setCursor(Qt.CursorShape.SizeHorCursor)
+ cursor_set = True
+ break
if not cursor_set:
self.unsetCursor()
@@ -683,7 +704,18 @@ def mouseMoveEvent(self, event: QMouseEvent):
delta_x = event.pos().x() - self.drag_start_pos.x()
time_delta = delta_x / self.pixels_per_second
- new_start_time = original_start_sec + time_delta
+ true_new_start_time = original_start_sec + time_delta
+
+ playhead_time = self.playhead_pos_sec
+ snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_second
+
+ new_start_time = true_new_start_time
+ true_new_end_time = true_new_start_time + self.dragging_clip.duration_sec
+
+ if abs(true_new_start_time - playhead_time) < snap_time_delta:
+ new_start_time = playhead_time
+ elif abs(true_new_end_time - playhead_time) < snap_time_delta:
+ new_start_time = playhead_time - self.dragging_clip.duration_sec
for other_clip in self.timeline.clips:
if other_clip.id == self.dragging_clip.id: continue
@@ -696,8 +728,11 @@ def mouseMoveEvent(self, event: QMouseEvent):
new_start_time + self.dragging_clip.duration_sec > other_clip.timeline_start_sec)
if is_overlapping:
- if time_delta > 0: new_start_time = other_clip.timeline_start_sec - self.dragging_clip.duration_sec
- else: new_start_time = other_clip.timeline_end_sec
+ movement_direction = true_new_start_time - original_start_sec
+ if movement_direction > 0:
+ new_start_time = other_clip.timeline_start_sec - self.dragging_clip.duration_sec
+ else:
+ new_start_time = other_clip.timeline_end_sec
break
final_start_time = max(0, new_start_time)
@@ -725,7 +760,7 @@ def mouseReleaseEvent(self, event: QMouseEvent):
if self.selection_regions:
start, end = self.selection_regions[-1]
if (end - start) * self.pixels_per_second < 2:
- self.selection_regions.pop()
+ self.clear_all_regions()
if self.dragging_selection_region:
self.dragging_selection_region = None
From dc68e8682ad27a1d4e728dee4f34d76ef9096cc5 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 03:11:58 +1100
Subject: [PATCH 105/155] added feature to unlink and relink audio tracks
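Video and audio clips imported from the same file share a group_id, so the link is purely id-based. Unlinking is nothing more than handing each clip a fresh id (sketch):

```python
import uuid

def unlink_pair(video_clip, audio_clip) -> None:
    # Linked clips share a group_id; giving each a new id breaks the pairing.
    video_clip.group_id = str(uuid.uuid4())
    audio_clip.group_id = str(uuid.uuid4())
```

Relinking goes the other way: a new audio TimelineClip is created with the video clip's group_id, after any conflicting audio clips on the target track have been bumped to the next free track.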
---
videoeditor/main.py | 88 +++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 85 insertions(+), 3 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index b3d3828e4..8d9e27d30 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -569,6 +569,7 @@ def mousePressEvent(self, event: QMouseEvent):
def mouseMoveEvent(self, event: QMouseEvent):
if self.resizing_clip:
+ linked_clip = next((c for c in self.timeline.clips if c.group_id == self.resizing_clip.group_id and c.id != self.resizing_clip.id), None)
delta_x = event.pos().x() - self.resize_start_pos.x()
time_delta = delta_x / self.pixels_per_second
min_duration = 1.0 / self.project_fps
@@ -608,6 +609,10 @@ def mouseMoveEvent(self, event: QMouseEvent):
self.resizing_clip.timeline_start_sec = new_start_sec
self.resizing_clip.duration_sec = new_duration
self.resizing_clip.clip_start_sec = new_clip_start
+ if linked_clip:
+ linked_clip.timeline_start_sec = new_start_sec
+ linked_clip.duration_sec = new_duration
+ linked_clip.clip_start_sec = new_clip_start
elif self.resize_edge == 'right':
original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_sec
@@ -629,14 +634,14 @@ def mouseMoveEvent(self, event: QMouseEvent):
new_duration = source_duration - self.resizing_clip.clip_start_sec
self.resizing_clip.duration_sec = new_duration
+ if linked_clip:
+ linked_clip.duration_sec = new_duration
self.update()
return
if not self.dragging_clip and not self.dragging_playhead and not self.creating_selection_region:
cursor_set = False
- ### ADD BEGIN ###
- # Check for playhead proximity for selection snapping
playhead_x = self.sec_to_x(self.playhead_pos_sec)
is_in_track_area = event.pos().y() > self.TIMESCALE_HEIGHT and event.pos().x() > self.HEADER_WIDTH
if is_in_track_area and abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS:
@@ -644,7 +649,6 @@ def mouseMoveEvent(self, event: QMouseEvent):
cursor_set = True
if not cursor_set:
- ### ADD END ###
for clip in self.timeline.clips:
clip_rect = self.get_clip_rect(clip)
if (abs(event.pos().x() - clip_rect.left()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.left(), event.pos().y()))) or \
@@ -952,6 +956,18 @@ def contextMenuEvent(self, event: 'QContextMenuEvent'):
if clip_at_pos:
if not menu.isEmpty(): menu.addSeparator()
+
+ linked_clip = next((c for c in self.timeline.clips if c.group_id == clip_at_pos.group_id and c.id != clip_at_pos.id), None)
+ if linked_clip:
+ unlink_action = menu.addAction("Unlink Audio Track")
+ unlink_action.triggered.connect(lambda: self.window().unlink_clip_pair(clip_at_pos))
+ else:
+ media_info = self.window().media_properties.get(clip_at_pos.source_path)
+ if (clip_at_pos.track_type == 'video' and
+ media_info and media_info.get('has_audio')):
+ relink_action = menu.addAction("Relink Audio Track")
+ relink_action.triggered.connect(lambda: self.window().relink_clip_audio(clip_at_pos))
+
split_action = menu.addAction("Split Clip")
delete_action = menu.addAction("Delete Clip")
playhead_time = self.playhead_pos_sec
@@ -1825,6 +1841,72 @@ def delete_clip(self, clip_to_delete):
self.undo_stack.push(command)
self.prune_empty_tracks()
+ def unlink_clip_pair(self, clip_to_unlink):
+ old_state = self._get_current_timeline_state()
+
+ linked_clip = next((c for c in self.timeline.clips if c.group_id == clip_to_unlink.group_id and c.id != clip_to_unlink.id), None)
+
+ if linked_clip:
+ clip_to_unlink.group_id = str(uuid.uuid4())
+ linked_clip.group_id = str(uuid.uuid4())
+
+ new_state = self._get_current_timeline_state()
+ command = TimelineStateChangeCommand("Unlink Clips", self.timeline, *old_state, *new_state)
+ command.undo()
+ self.undo_stack.push(command)
+ self.status_label.setText("Clips unlinked.")
+ else:
+ self.status_label.setText("Could not find a clip to unlink.")
+
+ def relink_clip_audio(self, video_clip):
+ def action():
+ media_info = self.media_properties.get(video_clip.source_path)
+ if not media_info or not media_info.get('has_audio'):
+ self.status_label.setText("Source media has no audio to relink.")
+ return
+
+ target_audio_track = 1
+ new_audio_start = video_clip.timeline_start_sec
+ new_audio_end = video_clip.timeline_end_sec
+
+ conflicting_clips = [
+ c for c in self.timeline.clips
+ if c.track_type == 'audio' and c.track_index == target_audio_track and
+ c.timeline_start_sec < new_audio_end and c.timeline_end_sec > new_audio_start
+ ]
+
+ for conflict_clip in conflicting_clips:
+ found_spot = False
+ for check_track_idx in range(target_audio_track + 1, self.timeline.num_audio_tracks + 2):
+ is_occupied = any(
+ other.timeline_start_sec < conflict_clip.timeline_end_sec and other.timeline_end_sec > conflict_clip.timeline_start_sec
+ for other in self.timeline.clips
+ if other.id != conflict_clip.id and other.track_type == 'audio' and other.track_index == check_track_idx
+ )
+
+ if not is_occupied:
+ if check_track_idx > self.timeline.num_audio_tracks:
+ self.timeline.num_audio_tracks = check_track_idx
+
+ conflict_clip.track_index = check_track_idx
+ found_spot = True
+ break
+
+ new_audio_clip = TimelineClip(
+ source_path=video_clip.source_path,
+ timeline_start_sec=video_clip.timeline_start_sec,
+ clip_start_sec=video_clip.clip_start_sec,
+ duration_sec=video_clip.duration_sec,
+ track_index=target_audio_track,
+ track_type='audio',
+ media_type=video_clip.media_type,
+ group_id=video_clip.group_id
+ )
+ self.timeline.add_clip(new_audio_clip)
+ self.status_label.setText("Audio relinked.")
+
+ self._perform_complex_timeline_change("Relink Audio", action)
+
def _perform_complex_timeline_change(self, description, change_function):
old_state = self._get_current_timeline_state()
change_function()
From bea0f48c70276e6472ba0a57fbb3953d9854e869 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 03:43:35 +1100
Subject: [PATCH 106/155] add ability to add media directly to timeline
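Placement at the playhead reuses one search for both video and audio: walk the track indices from 1 upward and take the first one with no clip overlapping the new span; index num_tracks + 1 stands for a ghost track that is created on demand. Condensed (illustrative helper, not a function that exists in the diff):

```python
def first_free_track(clips, track_type: str, num_tracks: int,
                     start_sec: float, end_sec: float) -> int:
    """Lowest 1-based track index of the given type that is free over [start_sec, end_sec)."""
    for index in range(1, num_tracks + 2):       # num_tracks + 1 == brand-new track
        occupied = any(
            c.track_type == track_type and c.track_index == index
            and c.timeline_start_sec < end_sec and c.timeline_end_sec > start_sec
            for c in clips
        )
        if not occupied:
            return index
    return num_tracks + 1
```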
---
videoeditor/main.py | 134 +++++++++++++++++++++++++++++++-------------
1 file changed, 96 insertions(+), 38 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 8d9e27d30..c4d3e9268 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -107,7 +107,7 @@ def __init__(self, timeline_model, settings, project_fps, parent=None):
self.scroll_area = None
self.pixels_per_second = 50.0
- self.max_pixels_per_second = 1.0 # Will be updated by set_project_fps
+ self.max_pixels_per_second = 1.0
self.project_fps = 25.0
self.set_project_fps(project_fps)
@@ -121,13 +121,13 @@ def __init__(self, timeline_model, settings, project_fps, parent=None):
self.creating_selection_region = False
self.dragging_selection_region = None
self.drag_start_pos = QPoint()
- self.drag_original_clip_states = {} # Store {'clip_id': (start_sec, track_index)}
+ self.drag_original_clip_states = {}
self.selection_drag_start_sec = 0.0
self.drag_selection_start_values = None
self.drag_start_state = None
self.resizing_clip = None
- self.resize_edge = None # 'left' or 'right'
+ self.resize_edge = None
self.resize_start_pos = QPoint()
self.highlighted_track_info = None
@@ -146,7 +146,6 @@ def __init__(self, timeline_model, settings, project_fps, parent=None):
def set_project_fps(self, fps):
self.project_fps = fps if fps > 0 else 25.0
- # Set max zoom to be 10 pixels per frame
self.max_pixels_per_second = self.project_fps * 20
self.pixels_per_second = min(self.pixels_per_second, self.max_pixels_per_second)
self.update()
@@ -263,7 +262,6 @@ def _format_timecode(self, seconds):
precision = 2 if s < 1 else 1 if s < 10 else 0
val = f"{s:.{precision}f}"
- # Remove trailing .0 or .00
if '.' in val: val = val.rstrip('0').rstrip('.')
return f"{sign}{val}s"
@@ -382,11 +380,11 @@ def draw_tracks_and_clips(self, painter):
for clip in self.timeline.clips:
clip_rect = self.get_clip_rect(clip)
- base_color = QColor("#46A") # Default video
+ base_color = QColor("#46A")
if clip.media_type == 'image':
- base_color = QColor("#4A6") # Greenish for images
+ base_color = QColor("#4A6")
elif clip.track_type == 'audio':
- base_color = QColor("#48C") # Bluish for audio
+ base_color = QColor("#48C")
color = QColor("#5A9") if self.dragging_clip and self.dragging_clip.id == clip.id else base_color
painter.fillRect(clip_rect, color)
@@ -415,25 +413,21 @@ def draw_playhead(self, painter):
painter.drawLine(playhead_x, 0, playhead_x, self.height())
def y_to_track_info(self, y):
- # Check for ghost video track
if self.TIMESCALE_HEIGHT <= y < self.video_tracks_y_start:
return ('video', self.timeline.num_video_tracks + 1)
- # Check for existing video tracks
video_tracks_end_y = self.video_tracks_y_start + self.timeline.num_video_tracks * self.TRACK_HEIGHT
if self.video_tracks_y_start <= y < video_tracks_end_y:
visual_index = (y - self.video_tracks_y_start) // self.TRACK_HEIGHT
track_index = self.timeline.num_video_tracks - visual_index
return ('video', track_index)
-
- # Check for existing audio tracks
+
audio_tracks_end_y = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT
if self.audio_tracks_y_start <= y < audio_tracks_end_y:
visual_index = (y - self.audio_tracks_y_start) // self.TRACK_HEIGHT
track_index = visual_index + 1
return ('audio', track_index)
-
- # Check for ghost audio track
+
add_audio_btn_y_start = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT
add_audio_btn_y_end = add_audio_btn_y_start + self.TRACK_HEIGHT
if add_audio_btn_y_start <= y < add_audio_btn_y_end:
@@ -468,27 +462,19 @@ def wheelEvent(self, event: QMouseEvent):
if delta > 0: new_pps = old_pps * zoom_factor
else: new_pps = old_pps / zoom_factor
- min_pps = 1 / (3600 * 10) # Min zoom of 1px per 10 hours
+ min_pps = 1 / (3600 * 10)
new_pps = max(min_pps, min(new_pps, self.max_pixels_per_second))
if abs(new_pps - old_pps) < 1e-9:
return
self.pixels_per_second = new_pps
- # This will trigger a paintEvent, which resizes the widget.
- # The scroll area will update its scrollbar ranges in response.
self.update()
-
- # The new absolute x-coordinate for the time under the cursor
+
new_mouse_x_abs = self.sec_to_x(time_at_cursor)
-
- # The amount the content "shifted" at the cursor's location
shift_amount = new_mouse_x_abs - mouse_x_abs
-
- # Adjust the scrollbar by this shift amount to keep the content under the cursor
new_scroll_value = scrollbar.value() + shift_amount
scrollbar.setValue(int(new_scroll_value))
-
event.accept()
def mousePressEvent(self, event: QMouseEvent):
@@ -509,7 +495,6 @@ def mousePressEvent(self, event: QMouseEvent):
self.resize_edge = None
self.drag_original_clip_states.clear()
- # Check for resize handles first
for clip in reversed(self.timeline.clips):
clip_rect = self.get_clip_rect(clip)
if abs(event.pos().x() - clip_rect.left()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.left(), event.pos().y())):
@@ -1356,13 +1341,18 @@ def _create_menu_bar(self):
open_action = QAction("&Open Project...", self); open_action.triggered.connect(self.open_project)
self.recent_menu = file_menu.addMenu("Recent")
save_action = QAction("&Save Project As...", self); save_action.triggered.connect(self.save_project_as)
+ add_media_to_timeline_action = QAction("Add Media to &Timeline...", self)
+ add_media_to_timeline_action.triggered.connect(self.add_media_to_timeline)
add_media_action = QAction("&Add Media to Project...", self); add_media_action.triggered.connect(self.add_media_files)
export_action = QAction("&Export Video...", self); export_action.triggered.connect(self.export_video)
settings_action = QAction("Se&ttings...", self); settings_action.triggered.connect(self.open_settings_dialog)
exit_action = QAction("E&xit", self); exit_action.triggered.connect(self.close)
file_menu.addAction(new_action); file_menu.addAction(open_action); file_menu.addSeparator()
file_menu.addAction(save_action)
- file_menu.addSeparator(); file_menu.addAction(add_media_action); file_menu.addAction(export_action)
+ file_menu.addSeparator()
+ file_menu.addAction(add_media_to_timeline_action)
+ file_menu.addAction(add_media_action)
+ file_menu.addAction(export_action)
file_menu.addSeparator(); file_menu.addAction(settings_action); file_menu.addSeparator(); file_menu.addAction(exit_action)
self._update_recent_files_menu()
@@ -1565,7 +1555,7 @@ def advance_playback_frame(self):
def _load_settings(self):
self.settings_file_was_loaded = False
- defaults = {"window_visibility": {"project_media": True}, "splitter_state": None, "enabled_plugins": [], "recent_files": [], "confirm_on_exit": True}
+ defaults = {"window_visibility": {"project_media": False}, "splitter_state": None, "enabled_plugins": [], "recent_files": [], "confirm_on_exit": True}
if os.path.exists(self.settings_file):
try:
with open(self.settings_file, "r") as f: self.settings = json.load(f)
@@ -1591,7 +1581,7 @@ def _apply_loaded_settings(self):
visibility_settings = self.settings.get("window_visibility", {})
for key, data in self.managed_widgets.items():
if data.get('plugin'): continue
- is_visible = visibility_settings.get(key, True)
+ is_visible = visibility_settings.get(key, False if key == 'project_media' else True)
if data['widget'] is not self.preview_widget:
data['widget'].setVisible(is_visible)
if data['action']: data['action'].setChecked(is_visible)
@@ -1745,12 +1735,13 @@ def on_media_removed_from_pool(self, file_path):
command.undo()
self.undo_stack.push(command)
+ def _add_media_files_to_project(self, file_paths):
+ if not file_paths:
+ return []
- def add_media_files(self):
- file_paths, _ = QFileDialog.getOpenFileNames(self, "Open Media Files", "", "All Supported Files (*.mp4 *.mov *.avi *.png *.jpg *.jpeg *.mp3 *.wav);;Video Files (*.mp4 *.mov *.avi);;Image Files (*.png *.jpg *.jpeg);;Audio Files (*.mp3 *.wav)")
- if not file_paths: return
self.media_dock.show()
+ added_files = []
first_video_added = any(c.media_type == 'video' for c in self.timeline.clips)
for file_path in file_paths:
@@ -1760,13 +1751,81 @@ def add_media_files(self):
if self._set_project_properties_from_clip(file_path):
first_video_added = True
else:
- self.status_label.setText("Error: Could not determine video properties from file.");
+ self.status_label.setText("Error: Could not determine video properties from file.")
continue
if self._add_media_to_pool(file_path):
- self.status_label.setText(f"Added {os.path.basename(file_path)} to project media.")
- else:
- self.status_label.setText(f"Failed to add {os.path.basename(file_path)}.")
+ added_files.append(file_path)
+
+ return added_files
+
+ def add_media_to_timeline(self):
+ file_paths, _ = QFileDialog.getOpenFileNames(self, "Add Media to Timeline", "", "All Supported Files (*.mp4 *.mov *.avi *.png *.jpg *.jpeg *.mp3 *.wav);;Video Files (*.mp4 *.mov *.avi);;Image Files (*.png *.jpg *.jpeg);;Audio Files (*.mp3 *.wav)")
+ if not file_paths:
+ return
+
+ added_files = self._add_media_files_to_project(file_paths)
+ if not added_files:
+ return
+
+ playhead_pos = self.timeline_widget.playhead_pos_sec
+
+ def add_clips_action():
+ for file_path in added_files:
+ media_info = self.media_properties.get(file_path)
+ if not media_info: continue
+
+ duration = media_info['duration']
+ media_type = media_info['media_type']
+ has_audio = media_info['has_audio']
+
+ clip_start_time = playhead_pos
+ clip_end_time = playhead_pos + duration
+
+ video_track_index = None
+ audio_track_index = None
+
+ if media_type in ['video', 'image']:
+ for i in range(1, self.timeline.num_video_tracks + 2):
+ is_occupied = any(
+ c.timeline_start_sec < clip_end_time and c.timeline_end_sec > clip_start_time
+ for c in self.timeline.clips if c.track_type == 'video' and c.track_index == i
+ )
+ if not is_occupied:
+ video_track_index = i
+ break
+
+ if has_audio:
+ for i in range(1, self.timeline.num_audio_tracks + 2):
+ is_occupied = any(
+ c.timeline_start_sec < clip_end_time and c.timeline_end_sec > clip_start_time
+ for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i
+ )
+ if not is_occupied:
+ audio_track_index = i
+ break
+
+ group_id = str(uuid.uuid4())
+ if video_track_index is not None:
+ if video_track_index > self.timeline.num_video_tracks:
+ self.timeline.num_video_tracks = video_track_index
+ video_clip = TimelineClip(file_path, clip_start_time, 0.0, duration, video_track_index, 'video', media_type, group_id)
+ self.timeline.add_clip(video_clip)
+
+ if audio_track_index is not None:
+ if audio_track_index > self.timeline.num_audio_tracks:
+ self.timeline.num_audio_tracks = audio_track_index
+ audio_clip = TimelineClip(file_path, clip_start_time, 0.0, duration, audio_track_index, 'audio', media_type, group_id)
+ self.timeline.add_clip(audio_clip)
+
+ self.status_label.setText(f"Added {len(added_files)} file(s) to timeline.")
+
+ self._perform_complex_timeline_change("Add Media to Timeline", add_clips_action)
+
+ def add_media_files(self):
+ file_paths, _ = QFileDialog.getOpenFileNames(self, "Open Media Files", "", "All Supported Files (*.mp4 *.mov *.avi *.png *.jpg *.jpeg *.mp3 *.wav);;Video Files (*.mp4 *.mov *.avi);;Image Files (*.png *.jpg *.jpeg);;Audio Files (*.mp3 *.wav)")
+ if file_paths:
+ self._add_media_files_to_project(file_paths)
def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, media_type, clip_start_sec=0.0, video_track_index=None, audio_track_index=None):
old_state = self._get_current_timeline_state()
@@ -2032,8 +2091,7 @@ def action():
for clip in clips_to_remove:
try: self.timeline.clips.remove(clip)
except ValueError: pass
-
- # Ripple
+
for clip in self.timeline.clips:
if clip.timeline_start_sec >= end_sec:
clip.timeline_start_sec -= duration_to_remove
@@ -2074,7 +2132,7 @@ def export_video(self):
if clip.media_type == 'image':
v_seg = (input_streams[clip.source_path].video.trim(duration=clip.duration_sec).setpts('PTS-STARTPTS'))
- else: # video
+ else:
v_seg = (input_streams[clip.source_path].video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS'))
v_seg = (v_seg.filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black').filter('format', pix_fmts='rgba'))
From afc567a09cd0efb6fdf79e6bc935339880741674 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 04:04:40 +1100
Subject: [PATCH 107/155] add ability to drag and drop from OS into timeline or
project media
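Drops from the OS arrive as plain file URLs rather than the editor's internal MIME type, so the file has to be probed during dragMoveEvent; the per-drag cache keeps ffprobe from being hit on every mouse move. The lookup, condensed into one illustrative helper (the patch inlines this in dragMoveEvent):

```python
import json

def media_props_for_drag(mime_data, drag_url_cache: dict, probe_file):
    """Return {'duration', 'media_type', 'has_audio'} for the dragged payload, or None."""
    if mime_data.hasUrls():
        urls = mime_data.urls()
        if not urls:
            return None
        file_path = urls[0].toLocalFile()
        if file_path not in drag_url_cache:
            props = probe_file(file_path)        # ffprobe-backed, so cache it per drag
            if props:
                drag_url_cache[file_path] = props
        return drag_url_cache.get(file_path)
    if mime_data.hasFormat('application/x-vnd.video.filepath'):
        raw = mime_data.data('application/x-vnd.video.filepath').data()
        return json.loads(raw.decode('utf-8'))
    return None
```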
---
videoeditor/main.py | 170 +++++++++++++++++++++++++++++++++++++++++---
1 file changed, 162 insertions(+), 8 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index c4d3e9268..0ff19f672 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -143,6 +143,7 @@ def __init__(self, timeline_model, settings, project_fps, parent=None):
self.drag_over_active = False
self.drag_over_rect = QRectF()
self.drag_over_audio_rect = QRectF()
+ self.drag_url_cache = {}
def set_project_fps(self, fps):
self.project_fps = fps if fps > 0 else 25.0
@@ -782,7 +783,9 @@ def mouseReleaseEvent(self, event: QMouseEvent):
self.update()
def dragEnterEvent(self, event):
- if event.mimeData().hasFormat('application/x-vnd.video.filepath'):
+ if event.mimeData().hasFormat('application/x-vnd.video.filepath') or event.mimeData().hasUrls():
+ if event.mimeData().hasUrls():
+ self.drag_url_cache.clear()
event.acceptProposedAction()
else:
event.ignore()
@@ -791,20 +794,65 @@ def dragLeaveEvent(self, event):
self.drag_over_active = False
self.highlighted_ghost_track_info = None
self.highlighted_track_info = None
+ self.drag_url_cache.clear()
self.update()
def dragMoveEvent(self, event):
mime_data = event.mimeData()
- if not mime_data.hasFormat('application/x-vnd.video.filepath'):
- event.ignore()
- return
+ media_props = None
+ if mime_data.hasUrls():
+ urls = mime_data.urls()
+ if not urls:
+ event.ignore()
+ return
+
+ file_path = urls[0].toLocalFile()
+
+ if file_path in self.drag_url_cache:
+ media_props = self.drag_url_cache[file_path]
+ else:
+ probed_props = self.window()._probe_for_drag(file_path)
+ if probed_props:
+ self.drag_url_cache[file_path] = probed_props
+ media_props = probed_props
+
+ elif mime_data.hasFormat('application/x-vnd.video.filepath'):
+ json_data_bytes = mime_data.data('application/x-vnd.video.filepath').data()
+ media_props = json.loads(json_data_bytes.decode('utf-8'))
+
+ if not media_props:
+ if mime_data.hasUrls():
+ event.acceptProposedAction()
+ pos = event.position()
+ track_info = self.y_to_track_info(pos.y())
+
+ self.drag_over_rect = QRectF()
+ self.drag_over_audio_rect = QRectF()
+ self.drag_over_active = False
+ self.highlighted_ghost_track_info = None
+ self.highlighted_track_info = None
+
+ if track_info:
+ self.drag_over_active = True
+ track_type, track_index = track_info
+ is_ghost_track = (track_type == 'video' and track_index > self.timeline.num_video_tracks) or \
+ (track_type == 'audio' and track_index > self.timeline.num_audio_tracks)
+ if is_ghost_track:
+ self.highlighted_ghost_track_info = track_info
+ else:
+ self.highlighted_track_info = track_info
+
+ self.update()
+ else:
+ event.ignore()
+ return
+
event.acceptProposedAction()
- json_data = json.loads(mime_data.data('application/x-vnd.video.filepath').data().decode('utf-8'))
- duration = json_data['duration']
- media_type = json_data['media_type']
- has_audio = json_data['has_audio']
+ duration = media_props['duration']
+ media_type = media_props['media_type']
+ has_audio = media_props['has_audio']
pos = event.position()
start_sec = self.x_to_sec(pos.x())
@@ -859,9 +907,68 @@ def dropEvent(self, event):
self.drag_over_active = False
self.highlighted_ghost_track_info = None
self.highlighted_track_info = None
+ self.drag_url_cache.clear()
self.update()
mime_data = event.mimeData()
+ if mime_data.hasUrls():
+ file_paths = [url.toLocalFile() for url in mime_data.urls()]
+ main_window = self.window()
+ added_files = main_window._add_media_files_to_project(file_paths)
+ if not added_files:
+ event.ignore()
+ return
+
+ pos = event.position()
+ start_sec = self.x_to_sec(pos.x())
+ track_info = self.y_to_track_info(pos.y())
+ if not track_info:
+ event.ignore()
+ return
+
+ current_timeline_pos = start_sec
+
+ for file_path in added_files:
+ media_info = main_window.media_properties.get(file_path)
+ if not media_info: continue
+
+ duration = media_info['duration']
+ has_audio = media_info['has_audio']
+ media_type = media_info['media_type']
+
+ drop_track_type, drop_track_index = track_info
+ video_track_idx = None
+ audio_track_idx = None
+
+ if media_type == 'image':
+ if drop_track_type == 'video': video_track_idx = drop_track_index
+ elif media_type == 'audio':
+ if drop_track_type == 'audio': audio_track_idx = drop_track_index
+ elif media_type == 'video':
+ if drop_track_type == 'video':
+ video_track_idx = drop_track_index
+ if has_audio: audio_track_idx = 1
+ elif drop_track_type == 'audio' and has_audio:
+ audio_track_idx = drop_track_index
+ video_track_idx = 1
+
+ if video_track_idx is None and audio_track_idx is None:
+ continue
+
+ main_window._add_clip_to_timeline(
+ source_path=file_path,
+ timeline_start_sec=current_timeline_pos,
+ duration_sec=duration,
+ media_type=media_type,
+ clip_start_sec=0,
+ video_track_index=video_track_idx,
+ audio_track_index=audio_track_idx
+ )
+ current_timeline_pos += duration
+
+ event.acceptProposedAction()
+ return
+
if not mime_data.hasFormat('application/x-vnd.video.filepath'):
return
@@ -1033,6 +1140,7 @@ class ProjectMediaWidget(QWidget):
def __init__(self, parent=None):
super().__init__(parent)
self.main_window = parent
+ self.setAcceptDrops(True)
layout = QVBoxLayout(self)
layout.setContentsMargins(2, 2, 2, 2)
@@ -1050,6 +1158,21 @@ def __init__(self, parent=None):
add_button.clicked.connect(self.add_media_requested.emit)
remove_button.clicked.connect(self.remove_selected_media)
+
+ def dragEnterEvent(self, event):
+ if event.mimeData().hasUrls():
+ event.acceptProposedAction()
+ else:
+ event.ignore()
+
+ def dropEvent(self, event):
+ if event.mimeData().hasUrls():
+ file_paths = [url.toLocalFile() for url in event.mimeData().urls()]
+ self.main_window._add_media_files_to_project(file_paths)
+ event.acceptProposedAction()
+ else:
+ event.ignore()
+
def show_context_menu(self, pos):
item = self.media_list.itemAt(pos)
if not item:
@@ -1432,6 +1555,37 @@ def _set_project_properties_from_clip(self, source_path):
except Exception as e: print(f"Could not probe for project properties: {e}")
return False
+ def _probe_for_drag(self, file_path):
+ if file_path in self.media_properties:
+ return self.media_properties[file_path]
+ try:
+ file_ext = os.path.splitext(file_path)[1].lower()
+ media_info = {}
+
+ if file_ext in ['.png', '.jpg', '.jpeg']:
+ media_info['media_type'] = 'image'
+ media_info['duration'] = 5.0
+ media_info['has_audio'] = False
+ else:
+ probe = ffmpeg.probe(file_path)
+ video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
+ audio_stream = next((s for s in probe['streams'] if s['codec_type'] == 'audio'), None)
+
+ if video_stream:
+ media_info['media_type'] = 'video'
+ media_info['duration'] = float(video_stream.get('duration', probe['format'].get('duration', 0)))
+ media_info['has_audio'] = audio_stream is not None
+ elif audio_stream:
+ media_info['media_type'] = 'audio'
+ media_info['duration'] = float(audio_stream.get('duration', probe['format'].get('duration', 0)))
+ media_info['has_audio'] = True
+ else:
+ return None
+ return media_info
+ except Exception as e:
+ print(f"Failed to probe dragged file {os.path.basename(file_path)}: {e}")
+ return None
+
def get_frame_data_at_time(self, time_sec):
clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None)
if not clip_at_time:
From 17567341dc0e5b4b4c77573ccb0176a7ef00b16a Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 04:31:33 +1100
Subject: [PATCH 108/155] fix AI joiner not inserting on new track with new
stream structure
---
videoeditor/main.py | 82 +++++++++++++++++++--------------------------
1 file changed, 34 insertions(+), 48 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 0ff19f672..e9981dc7f 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -1539,25 +1539,11 @@ def _stop_playback_stream(self):
self.playback_process = None
self.playback_clip = None
- def _set_project_properties_from_clip(self, source_path):
- try:
- probe = ffmpeg.probe(source_path)
- video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
- if video_stream:
- self.project_width = int(video_stream['width']); self.project_height = int(video_stream['height'])
- if 'r_frame_rate' in video_stream and video_stream['r_frame_rate'] != '0/0':
- num, den = map(int, video_stream['r_frame_rate'].split('/'))
- if den > 0:
- self.project_fps = num / den
- self.timeline_widget.set_project_fps(self.project_fps)
- print(f"Project properties set: {self.project_width}x{self.project_height} @ {self.project_fps:.2f} FPS")
- return True
- except Exception as e: print(f"Could not probe for project properties: {e}")
- return False
-
- def _probe_for_drag(self, file_path):
+ def _get_media_properties(self, file_path):
+ """Probes a file to get its media properties. Returns a dict or None."""
if file_path in self.media_properties:
return self.media_properties[file_path]
+
try:
file_ext = os.path.splitext(file_path)[1].lower()
media_info = {}
@@ -1581,11 +1567,31 @@ def _probe_for_drag(self, file_path):
media_info['has_audio'] = True
else:
return None
+
return media_info
except Exception as e:
- print(f"Failed to probe dragged file {os.path.basename(file_path)}: {e}")
+ print(f"Failed to probe file {os.path.basename(file_path)}: {e}")
return None
+ def _set_project_properties_from_clip(self, source_path):
+ try:
+ probe = ffmpeg.probe(source_path)
+ video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
+ if video_stream:
+ self.project_width = int(video_stream['width']); self.project_height = int(video_stream['height'])
+ if 'r_frame_rate' in video_stream and video_stream['r_frame_rate'] != '0/0':
+ num, den = map(int, video_stream['r_frame_rate'].split('/'))
+ if den > 0:
+ self.project_fps = num / den
+ self.timeline_widget.set_project_fps(self.project_fps)
+ print(f"Project properties set: {self.project_width}x{self.project_height} @ {self.project_fps:.2f} FPS")
+ return True
+ except Exception as e: print(f"Could not probe for project properties: {e}")
+ return False
+
+ def _probe_for_drag(self, file_path):
+ return self._get_media_properties(file_path)
+
def get_frame_data_at_time(self, time_sec):
clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None)
if not clip_at_time:
@@ -1838,41 +1844,21 @@ def _update_recent_files_menu(self):
self.recent_menu.addAction(action)
def _add_media_to_pool(self, file_path):
- if file_path in self.media_pool: return True
- try:
- self.status_label.setText(f"Probing {os.path.basename(file_path)}..."); QApplication.processEvents()
-
- file_ext = os.path.splitext(file_path)[1].lower()
- media_info = {}
-
- if file_ext in ['.png', '.jpg', '.jpeg']:
- media_info['media_type'] = 'image'
- media_info['duration'] = 5.0
- media_info['has_audio'] = False
- else:
- probe = ffmpeg.probe(file_path)
- video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
- audio_stream = next((s for s in probe['streams'] if s['codec_type'] == 'audio'), None)
-
- if video_stream:
- media_info['media_type'] = 'video'
- media_info['duration'] = float(video_stream.get('duration', probe['format'].get('duration', 0)))
- media_info['has_audio'] = audio_stream is not None
- elif audio_stream:
- media_info['media_type'] = 'audio'
- media_info['duration'] = float(audio_stream.get('duration', probe['format'].get('duration', 0)))
- media_info['has_audio'] = True
- else:
- raise ValueError("No video or audio stream found.")
-
+ if file_path in self.media_pool:
+ return True
+
+ self.status_label.setText(f"Probing {os.path.basename(file_path)}..."); QApplication.processEvents()
+
+ media_info = self._get_media_properties(file_path)
+
+ if media_info:
self.media_properties[file_path] = media_info
self.media_pool.append(file_path)
self.project_media_widget.add_media_item(file_path)
-
self.status_label.setText(f"Added {os.path.basename(file_path)} to project.")
return True
- except Exception as e:
- self.status_label.setText(f"Error probing file: {e}")
+ else:
+ self.status_label.setText(f"Error probing file: {os.path.basename(file_path)}")
return False
def on_media_removed_from_pool(self, file_path):
From a223c1e8df4d0104a84965db09ae789c35123da1 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 04:32:01 +1100
Subject: [PATCH 109/155] fix AI joiner not inserting on new track with new
stream structure (2)
---
videoeditor/plugins/ai_frame_joiner/main.py | 34 ++++++++++++++-------
1 file changed, 23 insertions(+), 11 deletions(-)
diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py
index d66dda353..85fd7a4d3 100644
--- a/videoeditor/plugins/ai_frame_joiner/main.py
+++ b/videoeditor/plugins/ai_frame_joiner/main.py
@@ -5,6 +5,7 @@
import requests
from pathlib import Path
import ffmpeg
+import uuid
from PyQt6.QtWidgets import (
QWidget, QVBoxLayout, QHBoxLayout, QFormLayout,
@@ -268,6 +269,8 @@ def setup_generator_for_region(self, region, on_new_track=False):
self.dock_widget.raise_()
def insert_generated_clip(self, video_path):
+ from main import TimelineClip
+
if not self.active_region:
self.update_main_status("AI Joiner Error: No active region to insert into.")
return
@@ -277,25 +280,29 @@ def insert_generated_clip(self, video_path):
return
start_sec, end_sec = self.active_region
+ is_new_track_mode = self.insert_on_new_track
+
self.update_main_status(f"AI Joiner: Inserting clip {os.path.basename(video_path)}")
- try:
+ def complex_insertion_action():
probe = ffmpeg.probe(video_path)
actual_duration = float(probe['format']['duration'])
- if self.insert_on_new_track:
- self.app.add_track('video')
+ if is_new_track_mode:
+ self.app.timeline.num_video_tracks += 1
new_track_index = self.app.timeline.num_video_tracks
- self.app._add_clip_to_timeline(
+
+ new_clip = TimelineClip(
source_path=video_path,
timeline_start_sec=start_sec,
clip_start_sec=0,
duration_sec=actual_duration,
+ track_index=new_track_index,
+ track_type='video',
media_type='video',
- video_track_index=new_track_index,
- audio_track_index=None
+ group_id=str(uuid.uuid4())
)
- self.update_main_status("AI clip inserted successfully on new track.")
+ self.app.timeline.add_clip(new_clip)
else:
for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec)
for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec)
@@ -308,18 +315,23 @@ def insert_generated_clip(self, video_path):
if clip in self.app.timeline.clips:
self.app.timeline.clips.remove(clip)
- self.app._add_clip_to_timeline(
+ new_clip = TimelineClip(
source_path=video_path,
timeline_start_sec=start_sec,
clip_start_sec=0,
duration_sec=actual_duration,
+ track_index=1,
+ track_type='video',
media_type='video',
- video_track_index=1,
- audio_track_index=None
+ group_id=str(uuid.uuid4())
)
- self.update_main_status("AI clip inserted successfully.")
+ self.app.timeline.add_clip(new_clip)
+ try:
+ self.app._perform_complex_timeline_change("Insert AI Clip", complex_insertion_action)
+
self.app.prune_empty_tracks()
+ self.update_main_status("AI clip inserted successfully.")
except Exception as e:
error_message = f"AI Joiner Error during clip insertion/probing: {e}"
From 8472a8a92a1b7cff5168f5df9f0873dca2a3703e Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 06:05:18 +1100
Subject: [PATCH 110/155] add Create mode for AI plugin
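The new results list shows a hover preview of where each generated clip would land; the overlay rectangle is just the clip's time span mapped into timeline pixels. Roughly (illustrative helper; the widget computes this inline in enterEvent):

```python
from PyQt6.QtCore import QRectF

def hover_preview_rect(timeline, start_sec: float, duration_sec: float, track_y: float) -> QRectF:
    """Map a (start, duration) span onto the timeline widget as an overlay rectangle."""
    x = timeline.sec_to_x(start_sec)
    w = int(duration_sec * timeline.pixels_per_second)
    return QRectF(x, track_y, w, timeline.TRACK_HEIGHT)
```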
---
main.py | 46 ++--
videoeditor/main.py | 17 +-
videoeditor/plugins/ai_frame_joiner/main.py | 239 +++++++++++++++++---
3 files changed, 238 insertions(+), 64 deletions(-)
diff --git a/main.py b/main.py
index 0dff699e6..773bcc9fa 100644
--- a/main.py
+++ b/main.py
@@ -312,56 +312,49 @@ def __init__(self):
self.apply_initial_config()
self.connect_signals()
self.init_wgp_state()
-
- # --- CHANGE: Removed setModelSignal as it's no longer needed ---
+
self.api_bridge.generateSignal.connect(self._api_generate)
@pyqtSlot(str)
def _api_set_model(self, model_type):
"""This slot is executed in the main GUI thread."""
if not model_type or not wgp: return
-
- # 1. Check if the model is valid by looking it up in the master model definition dictionary.
+
if model_type not in wgp.models_def:
print(f"API Error: Model type '{model_type}' is not a valid model.")
return
- # 2. Check if already selected to avoid unnecessary UI refreshes.
if self.state.get('model_type') == model_type:
print(f"API: Model is already set to {model_type}.")
return
- # 3. Redraw all model dropdowns to ensure the correct hierarchy is displayed
- # and the target model is selected. This function handles finding the
- # correct Family and Base model for the given finetune model_type.
self.update_model_dropdowns(model_type)
-
- # 4. Manually trigger the logic that normally runs when the user selects a model.
- # This is necessary because update_model_dropdowns blocks signals.
+
self._on_model_changed()
- # 5. Final check to see if the model was actually set.
if self.state.get('model_type') == model_type:
print(f"API: Successfully set model to {model_type}.")
else:
- # This could happen if update_model_dropdowns silently fails to find the model.
print(f"API Error: Failed to set model to '{model_type}'. The model might be hidden by your current configuration.")
-
- # --- CHANGE: Slot now accepts model_type and duration_sec, and calculates frame count ---
+
@pyqtSlot(object, object, object, object, bool)
def _api_generate(self, start_frame, end_frame, duration_sec, model_type, start_generation):
"""This slot is executed in the main GUI thread."""
- # 1. Set model if a new one is provided
if model_type:
self._api_set_model(model_type)
-
if start_frame:
self.widgets['mode_s'].setChecked(True)
self.widgets['image_start'].setText(start_frame)
+ else:
+ self.widgets['mode_t'].setChecked(True)
+ self.widgets['image_start'].clear()
if end_frame:
self.widgets['image_end_checkbox'].setChecked(True)
self.widgets['image_end'].setText(end_frame)
+ else:
+ self.widgets['image_end_checkbox'].setChecked(False)
+ self.widgets['image_end'].clear()
if duration_sec is not None:
try:
@@ -1707,7 +1700,7 @@ def _add_task_to_queue_and_update_ui(self):
def _add_task_to_queue(self):
queue_size_before = len(self.state["gen"]["queue"])
all_inputs = self.collect_inputs()
- keys_to_remove = ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality']
+ keys_to_remove = ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality', 'base_model_type']
for key in keys_to_remove:
all_inputs.pop(key, None)
@@ -1938,7 +1931,6 @@ def run_api_server():
api_server.run(port=5100, host='127.0.0.1', debug=False)
if FLASK_AVAILABLE:
- # --- CHANGE: Removed /api/set_model endpoint ---
@api_server.route('/api/generate', methods=['POST'])
def generate():
@@ -1964,13 +1956,13 @@ def generate():
return jsonify({"message": "Parameters set without starting generation."})
- @api_server.route('/api/latest_output', methods=['GET'])
- def get_latest_output():
+ @api_server.route('/api/outputs', methods=['GET'])
+ def get_outputs():
if not main_window_instance:
return jsonify({"error": "Application not ready"}), 503
-
- path = main_window_instance.latest_output_path
- return jsonify({"latest_output_path": path})
+
+ file_list = main_window_instance.state.get('gen', {}).get('file_list', [])
+ return jsonify({"outputs": file_list})
# =====================================================================
# --- END OF API SERVER ADDITION ---
@@ -1978,13 +1970,11 @@ def get_latest_output():
if __name__ == '__main__':
app = QApplication(sys.argv)
-
- # Create and show the main window
+
window = MainWindow()
- main_window_instance = window # Assign to global for API access
+ main_window_instance = window
window.show()
- # Start the Flask API server in a separate thread
if FLASK_AVAILABLE:
api_thread = threading.Thread(target=run_api_server, daemon=True)
api_thread.start()
diff --git a/videoeditor/main.py b/videoeditor/main.py
index e9981dc7f..d1bd5ce8d 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -139,12 +139,18 @@ def __init__(self, timeline_model, settings, project_fps, parent=None):
self.video_tracks_y_start = 0
self.audio_tracks_y_start = 0
-
+ self.hover_preview_rect = None
+ self.hover_preview_audio_rect = None
self.drag_over_active = False
self.drag_over_rect = QRectF()
self.drag_over_audio_rect = QRectF()
self.drag_url_cache = {}
+ def set_hover_preview_rects(self, video_rect, audio_rect):
+ self.hover_preview_rect = video_rect
+ self.hover_preview_audio_rect = audio_rect
+ self.update()
+
def set_project_fps(self, fps):
self.project_fps = fps if fps > 0 else 25.0
self.max_pixels_per_second = self.project_fps * 20
@@ -177,6 +183,15 @@ def paintEvent(self, event):
painter.fillRect(self.drag_over_audio_rect, QColor(0, 255, 0, 80))
painter.drawRect(self.drag_over_audio_rect)
+ if self.hover_preview_rect:
+ painter.setPen(QPen(QColor(0, 255, 255, 180), 2, Qt.PenStyle.DashLine))
+ painter.fillRect(self.hover_preview_rect, QColor(0, 255, 255, 60))
+ painter.drawRect(self.hover_preview_rect)
+ if self.hover_preview_audio_rect:
+ painter.setPen(QPen(QColor(0, 255, 255, 180), 2, Qt.PenStyle.DashLine))
+ painter.fillRect(self.hover_preview_audio_rect, QColor(0, 255, 255, 60))
+ painter.drawRect(self.hover_preview_audio_rect)
+
self.draw_playhead(painter)
total_width = self.sec_to_x(self.timeline.get_total_duration()) + 200
diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py
index 85fd7a4d3..410333d28 100644
--- a/videoeditor/plugins/ai_frame_joiner/main.py
+++ b/videoeditor/plugins/ai_frame_joiner/main.py
@@ -9,33 +9,134 @@
from PyQt6.QtWidgets import (
QWidget, QVBoxLayout, QHBoxLayout, QFormLayout,
- QLineEdit, QPushButton, QLabel, QMessageBox, QCheckBox
+ QLineEdit, QPushButton, QLabel, QMessageBox, QCheckBox, QListWidget,
+ QListWidgetItem, QGroupBox
)
from PyQt6.QtGui import QImage, QPixmap
-from PyQt6.QtCore import QTimer, pyqtSignal, Qt, QSize
+from PyQt6.QtCore import pyqtSignal, Qt, QSize, QRectF, QUrl, QTimer
+from PyQt6.QtMultimedia import QMediaPlayer
+from PyQt6.QtMultimediaWidgets import QVideoWidget
sys.path.append(str(Path(__file__).parent.parent.parent))
from plugins import VideoEditorPlugin
-
API_BASE_URL = "http://127.0.0.1:5100"
+class VideoResultItemWidget(QWidget):
+ def __init__(self, video_path, plugin, parent=None):
+ super().__init__(parent)
+ self.video_path = video_path
+ self.plugin = plugin
+ self.app = plugin.app
+ self.duration = 0.0
+ self.has_audio = False
+
+ self.setMinimumSize(200, 180)
+ self.setMaximumHeight(190)
+
+ layout = QVBoxLayout(self)
+ layout.setContentsMargins(5, 5, 5, 5)
+ layout.setSpacing(5)
+
+ self.media_player = QMediaPlayer()
+ self.video_widget = QVideoWidget()
+ self.video_widget.setFixedSize(160, 90)
+ self.media_player.setVideoOutput(self.video_widget)
+ self.media_player.setSource(QUrl.fromLocalFile(self.video_path))
+ self.media_player.setLoops(QMediaPlayer.Loops.Infinite)
+
+ self.info_label = QLabel(os.path.basename(video_path))
+ self.info_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
+ self.info_label.setWordWrap(True)
+
+ self.insert_button = QPushButton("Insert into Timeline")
+ self.insert_button.clicked.connect(self.on_insert)
+
+ h_layout = QHBoxLayout()
+ h_layout.addStretch()
+ h_layout.addWidget(self.video_widget)
+ h_layout.addStretch()
+
+ layout.addLayout(h_layout)
+ layout.addWidget(self.info_label)
+ layout.addWidget(self.insert_button)
+
+ self.probe_video()
+
+ def probe_video(self):
+ try:
+ probe = ffmpeg.probe(self.video_path)
+ self.duration = float(probe['format']['duration'])
+ self.has_audio = any(s['codec_type'] == 'audio' for s in probe.get('streams', []))
+
+ self.info_label.setText(f"{os.path.basename(self.video_path)}\n({self.duration:.2f}s)")
+
+ except Exception as e:
+ self.info_label.setText(f"Error probing:\n{os.path.basename(self.video_path)}")
+ print(f"Error probing video {self.video_path}: {e}")
+
+ def enterEvent(self, event):
+ super().enterEvent(event)
+ self.media_player.play()
+
+ if not self.plugin.active_region or self.duration == 0:
+ return
+
+ start_sec, _ = self.plugin.active_region
+ timeline = self.app.timeline_widget
+
+ video_rect, audio_rect = None, None
+
+ x = timeline.sec_to_x(start_sec)
+ w = int(self.duration * timeline.pixels_per_second)
+
+ if self.plugin.insert_on_new_track:
+ video_y = timeline.TIMESCALE_HEIGHT
+ video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT)
+ if self.has_audio:
+ audio_y = timeline.audio_tracks_y_start + self.app.timeline.num_audio_tracks * timeline.TRACK_HEIGHT
+ audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT)
+ else:
+ v_track_idx = 1
+ visual_v_idx = self.app.timeline.num_video_tracks - v_track_idx
+ video_y = timeline.video_tracks_y_start + visual_v_idx * timeline.TRACK_HEIGHT
+ video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT)
+ if self.has_audio:
+ a_track_idx = 1
+ visual_a_idx = a_track_idx - 1
+ audio_y = timeline.audio_tracks_y_start + visual_a_idx * timeline.TRACK_HEIGHT
+ audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT)
+
+ timeline.set_hover_preview_rects(video_rect, audio_rect)
+
+ def leaveEvent(self, event):
+ super().leaveEvent(event)
+ self.media_player.pause()
+ self.media_player.setPosition(0)
+ self.app.timeline_widget.set_hover_preview_rects(None, None)
+
+ def on_insert(self):
+ self.media_player.stop()
+ self.media_player.setSource(QUrl())
+ self.media_player.setVideoOutput(None)
+ self.app.timeline_widget.set_hover_preview_rects(None, None)
+ self.plugin.insert_generated_clip(self.video_path)
+
class WgpClientWidget(QWidget):
- generation_complete = pyqtSignal(str)
status_updated = pyqtSignal(str)
- def __init__(self):
+ def __init__(self, plugin):
super().__init__()
-
- self.last_known_output = None
+ self.plugin = plugin
+ self.processed_files = set()
layout = QVBoxLayout(self)
form_layout = QFormLayout()
- self.model_input = QLineEdit()
- self.model_input.setPlaceholderText("Optional (default: i2v_2_2)")
-
- previews_layout = QHBoxLayout()
+ self.previews_widget = QWidget()
+ previews_layout = QHBoxLayout(self.previews_widget)
+ previews_layout.setContentsMargins(0, 0, 0, 0)
+
start_preview_layout = QVBoxLayout()
start_preview_layout.addWidget(QLabel("Start Frame"))
self.start_frame_preview = QLabel("N/A")
@@ -63,12 +164,24 @@ def __init__(self):
self.generate_button = QPushButton("Generate")
- layout.addLayout(previews_layout)
- form_layout.addRow("Model Type:", self.model_input)
+ layout.addWidget(self.previews_widget)
form_layout.addRow(self.autostart_checkbox)
form_layout.addRow(self.generate_button)
layout.addLayout(form_layout)
+
+ # --- Results Area ---
+ results_group = QGroupBox("Generated Clips (Hover to play, Click button to insert)")
+ results_layout = QVBoxLayout()
+ self.results_list = QListWidget()
+ self.results_list.setFlow(QListWidget.Flow.LeftToRight)
+ self.results_list.setWrapping(True)
+ self.results_list.setResizeMode(QListWidget.ResizeMode.Adjust)
+ self.results_list.setSpacing(10)
+ results_layout.addWidget(self.results_list)
+ results_group.setLayout(results_layout)
+ layout.addWidget(results_group)
+
layout.addStretch()
self.generate_button.clicked.connect(self.generate)
@@ -106,7 +219,7 @@ def handle_api_error(self, response, action="performing action"):
def check_server_status(self):
try:
- requests.get(f"{API_BASE_URL}/api/latest_output", timeout=1)
+ requests.get(f"{API_BASE_URL}/api/outputs", timeout=1)
self.status_updated.emit("AI Joiner: Connected to WanGP server.")
return True
except requests.exceptions.ConnectionError:
@@ -114,11 +227,10 @@ def check_server_status(self):
QMessageBox.critical(self, "Connection Error", f"Could not connect to the WanGP API at {API_BASE_URL}.\n\nPlease ensure wgptool.py is running.")
return False
-
def generate(self):
payload = {}
- model_type = self.model_input.text().strip()
+ model_type = self.model_type_to_use
if model_type:
payload['model_type'] = model_type
@@ -148,31 +260,44 @@ def generate(self):
def poll_for_output(self):
try:
- response = requests.get(f"{API_BASE_URL}/api/latest_output")
+ response = requests.get(f"{API_BASE_URL}/api/outputs")
if response.status_code == 200:
data = response.json()
- latest_path = data.get("latest_output_path")
+ output_files = data.get("outputs", [])
- if latest_path and latest_path != self.last_known_output:
- self.stop_polling()
- self.last_known_output = latest_path
- self.status_updated.emit(f"AI Joiner: New output received! Inserting clip...")
- self.generation_complete.emit(latest_path)
+ new_files = set(output_files) - self.processed_files
+ if new_files:
+ self.status_updated.emit(f"AI Joiner: Received {len(new_files)} new clip(s).")
+ for file_path in sorted(list(new_files)):
+ self.add_result_item(file_path)
+ self.processed_files.add(file_path)
+ else:
+ self.status_updated.emit("AI Joiner: Polling for output...")
else:
self.status_updated.emit("AI Joiner: Polling for output...")
except requests.exceptions.RequestException:
self.status_updated.emit("AI Joiner: Polling... (Connection issue)")
+
+ def add_result_item(self, video_path):
+ item_widget = VideoResultItemWidget(video_path, self.plugin)
+ list_item = QListWidgetItem(self.results_list)
+ list_item.setSizeHint(item_widget.sizeHint())
+ self.results_list.addItem(list_item)
+ self.results_list.setItemWidget(list_item, item_widget)
+
+ def clear_results(self):
+ self.results_list.clear()
+ self.processed_files.clear()
class Plugin(VideoEditorPlugin):
def initialize(self):
self.name = "AI Frame Joiner"
self.description = "Uses a local AI server to generate a video between two frames."
- self.client_widget = WgpClientWidget()
+ self.client_widget = WgpClientWidget(self)
self.dock_widget = None
self.active_region = None
self.temp_dir = None
self.insert_on_new_track = False
- self.client_widget.generation_complete.connect(self.insert_generated_clip)
self.client_widget.status_updated.connect(self.update_main_status)
def enable(self):
@@ -202,24 +327,41 @@ def _cleanup_temp_dir(self):
def _reset_state(self):
self.active_region = None
self.insert_on_new_track = False
+ self.client_widget.stop_polling()
+ self.client_widget.clear_results()
self._cleanup_temp_dir()
self.update_main_status("AI Joiner: Idle")
self.client_widget.set_previews(None, None)
+ self.client_widget.model_type_to_use = None
def on_timeline_context_menu(self, menu, event):
region = self.app.timeline_widget.get_region_at_pos(event.pos())
if region:
menu.addSeparator()
- action = menu.addAction("Join Frames With AI")
- action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False))
-
- action_new_track = menu.addAction("Join Frames With AI (New Track)")
- action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True))
+
+ start_sec, end_sec = region
+ start_data, _, _ = self.app.get_frame_data_at_time(start_sec)
+ end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
+
+ if start_data and end_data:
+ join_action = menu.addAction("Join Frames With AI")
+ join_action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False))
+
+ join_action_new_track = menu.addAction("Join Frames With AI (New Track)")
+ join_action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True))
+
+ create_action = menu.addAction("Create Frames With AI")
+ create_action.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=False))
+
+ create_action_new_track = menu.addAction("Create Frames With AI (New Track)")
+ create_action_new_track.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=True))
def setup_generator_for_region(self, region, on_new_track=False):
self._reset_state()
+ self.client_widget.model_type_to_use = "i2v_2_2"
self.active_region = region
self.insert_on_new_track = on_new_track
+ self.client_widget.previews_widget.setVisible(True)
start_sec, end_sec = region
if not self.client_widget.check_server_status():
@@ -268,6 +410,30 @@ def setup_generator_for_region(self, region, on_new_track=False):
self.dock_widget.show()
self.dock_widget.raise_()
+ def setup_creator_for_region(self, region, on_new_track=False):
+ self._reset_state()
+ self.client_widget.model_type_to_use = "t2v_2_2"
+ self.active_region = region
+ self.insert_on_new_track = on_new_track
+ self.client_widget.previews_widget.setVisible(False)
+
+ start_sec, end_sec = region
+ if not self.client_widget.check_server_status():
+ return
+
+ self.client_widget.set_previews(None, None)
+
+ duration_sec = end_sec - start_sec
+ self.client_widget.duration_input.setText(str(duration_sec))
+
+ self.client_widget.start_frame_input.clear()
+ self.client_widget.end_frame_input.clear()
+
+ self.update_main_status(f"AI Creator: Ready for region {start_sec:.2f}s - {end_sec:.2f}s")
+
+ self.dock_widget.show()
+ self.dock_widget.raise_()
+
def insert_generated_clip(self, video_path):
from main import TimelineClip
@@ -333,11 +499,14 @@ def complex_insertion_action():
self.app.prune_empty_tracks()
self.update_main_status("AI clip inserted successfully.")
+ for i in range(self.client_widget.results_list.count()):
+ item = self.client_widget.results_list.item(i)
+ widget = self.client_widget.results_list.itemWidget(item)
+ if widget and widget.video_path == video_path:
+ self.client_widget.results_list.takeItem(i)
+ break
+
except Exception as e:
error_message = f"AI Joiner Error during clip insertion/probing: {e}"
self.update_main_status(error_message)
- print(error_message)
-
- finally:
- self._cleanup_temp_dir()
- self.active_region = None
\ No newline at end of file
+ print(error_message)
\ No newline at end of file
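The result widgets in this patch start looping playback in `enterEvent` and pause/rewind in `leaveEvent`. A condensed, standalone sketch of that hover-to-preview pattern, assuming PyQt6 with the QtMultimedia bindings installed; the file path is a placeholder, not part of the plugin:

```python
# Minimal hover-to-preview widget: play on enter, pause and rewind on leave.
import sys
from PyQt6.QtCore import QUrl
from PyQt6.QtMultimedia import QMediaPlayer
from PyQt6.QtMultimediaWidgets import QVideoWidget
from PyQt6.QtWidgets import QApplication, QVBoxLayout, QWidget

class HoverPreview(QWidget):
    def __init__(self, video_path):
        super().__init__()
        self.player = QMediaPlayer()
        self.video_widget = QVideoWidget()
        self.player.setVideoOutput(self.video_widget)
        self.player.setSource(QUrl.fromLocalFile(video_path))
        self.player.setLoops(QMediaPlayer.Loops.Infinite)
        layout = QVBoxLayout(self)
        layout.addWidget(self.video_widget)

    def enterEvent(self, event):
        super().enterEvent(event)
        self.player.play()           # start looping playback on hover

    def leaveEvent(self, event):
        super().leaveEvent(event)
        self.player.pause()          # stop and rewind when the cursor leaves
        self.player.setPosition(0)

if __name__ == "__main__":
    app = QApplication(sys.argv)
    w = HoverPreview("example.mp4")  # placeholder path
    w.show()
    sys.exit(app.exec())
```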
From deef0107e600743d060abb2a42b01f445eba2f81 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 06:38:04 +1100
Subject: [PATCH 111/155] add snapping between clips for resize
---
videoeditor/main.py | 31 +++++++++++++++++++++----------
1 file changed, 21 insertions(+), 10 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index d1bd5ce8d..8a5418ba8 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -574,9 +574,15 @@ def mouseMoveEvent(self, event: QMouseEvent):
delta_x = event.pos().x() - self.resize_start_pos.x()
time_delta = delta_x / self.pixels_per_second
min_duration = 1.0 / self.project_fps
- playhead_time = self.playhead_pos_sec
snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_second
+ snap_points = [self.playhead_pos_sec]
+ for clip in self.timeline.clips:
+ if clip.id == self.resizing_clip.id: continue
+ if linked_clip and clip.id == linked_clip.id: continue
+ snap_points.append(clip.timeline_start_sec)
+ snap_points.append(clip.timeline_end_sec)
+
media_props = self.window().media_properties.get(self.resizing_clip.source_path)
source_duration = media_props['duration'] if media_props else float('inf')
@@ -585,10 +591,12 @@ def mouseMoveEvent(self, event: QMouseEvent):
original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec
original_clip_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].clip_start_sec
true_new_start_sec = original_start + time_delta
- if abs(true_new_start_sec - playhead_time) < snap_time_delta:
- new_start_sec = playhead_time
- else:
- new_start_sec = true_new_start_sec
+
+ new_start_sec = true_new_start_sec
+ for snap_point in snap_points:
+ if abs(true_new_start_sec - snap_point) < snap_time_delta:
+ new_start_sec = snap_point
+ break
if new_start_sec > original_start + original_duration - min_duration:
new_start_sec = original_start + original_duration - min_duration
@@ -621,11 +629,14 @@ def mouseMoveEvent(self, event: QMouseEvent):
true_new_duration = original_duration + time_delta
true_new_end_time = original_start + true_new_duration
-
- if abs(true_new_end_time - playhead_time) < snap_time_delta:
- new_duration = playhead_time - original_start
- else:
- new_duration = true_new_duration
+
+ new_end_time = true_new_end_time
+ for snap_point in snap_points:
+ if abs(true_new_end_time - snap_point) < snap_time_delta:
+ new_end_time = snap_point
+ break
+
+ new_duration = new_end_time - original_start
if new_duration < min_duration:
new_duration = min_duration
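Patch 111 replaces the playhead-only snap with a list of snap points built from the playhead plus the edges of every other clip. A minimal, self-contained sketch of that snapping rule, using illustrative names rather than the editor's actual API:

```python
# Snap a proposed time to the first candidate point within a pixel threshold,
# converted to seconds via the current zoom level.
def snap_time(proposed_sec, snap_points, pixels_per_second,
              snap_threshold_pixels=10):
    threshold_sec = snap_threshold_pixels / pixels_per_second
    for point in snap_points:
        if abs(proposed_sec - point) < threshold_sec:
            return point
    return proposed_sec

# Example: at 50 px/s a 10 px threshold is 0.2 s, so 4.15 s snaps to 4.0 s.
assert snap_time(4.15, [0.0, 4.0, 9.5], 50) == 4.0
```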
From da346301cbdfb94ff690c7b1c4ea80abb29781ae Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 07:28:48 +1100
Subject: [PATCH 112/155] fixed minimize closing all windows
---
videoeditor/main.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 8a5418ba8..90e36fad5 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -1482,6 +1482,10 @@ def remove_track(self, track_type):
command.undo()
self.undo_stack.push(command)
+ def on_dock_visibility_changed(self, action, visible):
+ if self.isMinimized():
+ return
+ action.setChecked(visible)
def _create_menu_bar(self):
menu_bar = self.menuBar()
@@ -1533,7 +1537,7 @@ def _create_menu_bar(self):
action = QAction(data['name'], self, checkable=True)
if hasattr(data['widget'], 'visibilityChanged'):
action.toggled.connect(data['widget'].setVisible)
- data['widget'].visibilityChanged.connect(action.setChecked)
+ data['widget'].visibilityChanged.connect(lambda visible, a=action: self.on_dock_visibility_changed(a, visible))
else:
action.toggled.connect(lambda checked, k=key: self.toggle_widget_visibility(k, checked))
@@ -1766,7 +1770,6 @@ def _save_settings(self):
def _apply_loaded_settings(self):
visibility_settings = self.settings.get("window_visibility", {})
for key, data in self.managed_widgets.items():
- if data.get('plugin'): continue
is_visible = visibility_settings.get(key, False if key == 'project_media' else True)
if data['widget'] is not self.preview_widget:
data['widget'].setVisible(is_visible)
@@ -2373,7 +2376,7 @@ def add_dock_widget(self, plugin_instance, widget, title, area=Qt.DockWidgetArea
dock.setVisible(initial_visibility)
action = QAction(title, self, checkable=True)
action.toggled.connect(dock.setVisible)
- dock.visibilityChanged.connect(action.setChecked)
+ dock.visibilityChanged.connect(lambda visible, a=action: self.on_dock_visibility_changed(a, visible))
action.setChecked(dock.isVisible())
self.windows_menu.addAction(action)
self.managed_widgets[widget_key] = {'widget': dock, 'name': title, 'action': action, 'plugin': plugin_instance.name}
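Both connections in patch 112 capture the action through a default argument (`a=action`) instead of referencing it directly. Without that, every handler created in the menu-building loop would see the last action because of Python's late binding. A small illustration, with `print` standing in for the real handler; this is not editor code:

```python
def make_handlers_buggy(actions):
    # Late binding: each lambda looks up `action` when called, so all of
    # them end up reporting the last action in the list.
    return [lambda visible: print(action, visible) for action in actions]

def make_handlers_fixed(actions):
    # The default argument freezes the current `action` per lambda, which is
    # exactly what the `a=action` capture in the patch does.
    return [lambda visible, a=action: print(a, visible) for action in actions]

handlers = make_handlers_fixed(["View", "Timeline", "Preview"])
handlers[0](True)   # prints "View True"
```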
From 8b51fc48d8d26d3a2ab6c80ec7119e5a37a7e503 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 07:33:49 +1100
Subject: [PATCH 113/155] fixed export with images and blank canvas
---
videoeditor/main.py | 75 ++++++++++++++++++++++++++-------------------
1 file changed, 43 insertions(+), 32 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 90e36fad5..d1cda9fa9 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -2277,40 +2277,51 @@ def export_video(self):
w, h, fr_str, total_dur = self.project_width, self.project_height, str(self.project_fps), self.timeline.get_total_duration()
sample_rate, channel_layout = '44100', 'stereo'
-
- input_streams = {}
- last_video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur}', f='lavfi')
- for i in range(self.timeline.num_video_tracks):
- track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'video' and c.track_index == i + 1], key=lambda c: c.timeline_start_sec)
- if not track_clips: continue
-
- track_segments = []
- last_end = track_clips[0].timeline_start_sec
-
- for clip in track_clips:
- gap = clip.timeline_start_sec - last_end
- if gap > 0.01:
- track_segments.append(ffmpeg.input(f'color=c=black@0.0:s={w}x{h}:r={fr_str}:d={gap}', f='lavfi').filter('format', pix_fmts='rgba'))
-
- if clip.source_path not in input_streams:
- if clip.media_type == 'image':
- input_streams[clip.source_path] = ffmpeg.input(clip.source_path, loop=1, framerate=self.project_fps)
- else:
- input_streams[clip.source_path] = ffmpeg.input(clip.source_path)
+ video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur}', f='lavfi')
+
+ all_video_clips = sorted(
+ [c for c in self.timeline.clips if c.track_type == 'video'],
+ key=lambda c: c.track_index
+ )
+ input_nodes = {}
+
+ for clip in all_video_clips:
+ if clip.source_path not in input_nodes:
if clip.media_type == 'image':
- v_seg = (input_streams[clip.source_path].video.trim(duration=clip.duration_sec).setpts('PTS-STARTPTS'))
+ input_nodes[clip.source_path] = ffmpeg.input(clip.source_path, loop=1, framerate=self.project_fps)
else:
- v_seg = (input_streams[clip.source_path].video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS'))
-
- v_seg = (v_seg.filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black').filter('format', pix_fmts='rgba'))
- track_segments.append(v_seg)
- last_end = clip.timeline_end_sec
+ input_nodes[clip.source_path] = ffmpeg.input(clip.source_path)
- track_stream = ffmpeg.concat(*track_segments, v=1, a=0).filter('setpts', f'PTS-STARTPTS+{track_clips[0].timeline_start_sec}/TB')
- last_video_stream = ffmpeg.overlay(last_video_stream, track_stream)
- final_video = last_video_stream.filter('format', pix_fmts='yuv420p').filter('fps', fps=self.project_fps)
+ clip_source_node = input_nodes[clip.source_path]
+
+ if clip.media_type == 'image':
+ segment_stream = (
+ clip_source_node.video
+ .trim(duration=clip.duration_sec)
+ .setpts('PTS-STARTPTS')
+ )
+ else:
+ segment_stream = (
+ clip_source_node.video
+ .trim(start=clip.clip_start_sec, duration=clip.duration_sec)
+ .setpts('PTS-STARTPTS')
+ )
+
+ processed_segment = (
+ segment_stream
+ .filter('scale', w, h, force_original_aspect_ratio='decrease')
+ .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
+ )
+
+ video_stream = ffmpeg.overlay(
+ video_stream,
+ processed_segment,
+ enable=f'between(t,{clip.timeline_start_sec},{clip.timeline_end_sec})'
+ )
+
+ final_video = video_stream.filter('format', pix_fmts='yuv420p').filter('fps', fps=self.project_fps)
track_audio_streams = []
for i in range(self.timeline.num_audio_tracks):
@@ -2321,14 +2332,14 @@ def export_video(self):
last_end = track_clips[0].timeline_start_sec
for clip in track_clips:
- if clip.source_path not in input_streams:
- input_streams[clip.source_path] = ffmpeg.input(clip.source_path)
+ if clip.source_path not in input_nodes:
+ input_nodes[clip.source_path] = ffmpeg.input(clip.source_path)
gap = clip.timeline_start_sec - last_end
if gap > 0.01:
track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap}', f='lavfi'))
- a_seg = input_streams[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS')
+ a_seg = input_nodes[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS')
track_segments.append(a_seg)
last_end = clip.timeline_end_sec
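Patch 113 rebuilds the export graph by overlaying every video clip onto a black lavfi base and gating each overlay with `enable='between(t,start,end)'`, instead of concatenating per-track segments with filler gaps. A standalone ffmpeg-python sketch of that compositing approach, with placeholder paths, sizes, and times:

```python
# One clip composited onto a black base for the window 1.0 s - 4.0 s.
import ffmpeg

w, h, fps, total_dur = 1280, 720, 25, 10.0
base = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fps}:d={total_dur}', f='lavfi')

clip = (
    ffmpeg.input('clip.mp4')                 # placeholder source
    .video
    .trim(start=2.0, duration=3.0)           # source in-point and length
    .setpts('PTS-STARTPTS')
    .filter('scale', w, h, force_original_aspect_ratio='decrease')
    .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
)

# The enable expression keeps the overlay visible only during its slot.
composited = ffmpeg.overlay(base, clip, enable='between(t,1.0,4.0)')
out = composited.filter('format', pix_fmts='yuv420p').output('out.mp4', r=fps)
print(' '.join(out.compile()))               # inspect the generated command
```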
From 836a172ff28b1397b3f42f39cf7fc0469058fb93 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 08:03:38 +1100
Subject: [PATCH 114/155] fix multi-track preview rendering
---
videoeditor/main.py | 179 ++++++++++++++++++++++++++++++--------------
1 file changed, 124 insertions(+), 55 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index d1cda9fa9..9cfca35b0 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -1258,7 +1258,6 @@ def __init__(self, project_to_load=None):
self.project_height = 720
self.playback_timer = QTimer(self)
self.playback_process = None
- self.playback_clip = None
self._setup_ui()
self._connect_signals()
@@ -1546,19 +1545,64 @@ def _create_menu_bar(self):
def _start_playback_stream_at(self, time_sec):
self._stop_playback_stream()
- clip = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None)
- if not clip: return
- self.playback_clip = clip
- clip_time = time_sec - clip.timeline_start_sec + clip.clip_start_sec
- w, h = self.project_width, self.project_height
+ if not self.timeline.clips:
+ return
+
+ w, h, fr = self.project_width, self.project_height, self.project_fps
+ total_dur = self.timeline.get_total_duration()
+
+ video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr}:d={total_dur}', f='lavfi')
+
+ all_video_clips = sorted(
+ [c for c in self.timeline.clips if c.track_type == 'video'],
+ key=lambda c: c.track_index
+ )
+
+ input_nodes = {}
+ for clip in all_video_clips:
+ if clip.source_path not in input_nodes:
+ if clip.media_type == 'image':
+ input_nodes[clip.source_path] = ffmpeg.input(clip.source_path, loop=1, framerate=fr)
+ else:
+ input_nodes[clip.source_path] = ffmpeg.input(clip.source_path)
+
+ clip_source_node = input_nodes[clip.source_path]
+
+ if clip.media_type == 'image':
+ segment_stream = (
+ clip_source_node.video
+ .trim(duration=clip.duration_sec)
+ .setpts('PTS-STARTPTS')
+ )
+ else:
+ segment_stream = (
+ clip_source_node.video
+ .trim(start=clip.clip_start_sec, duration=clip.duration_sec)
+ .setpts('PTS-STARTPTS')
+ )
+
+ processed_segment = (
+ segment_stream
+ .filter('scale', w, h, force_original_aspect_ratio='decrease')
+ .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
+ )
+
+ video_stream = ffmpeg.overlay(
+ video_stream,
+ processed_segment,
+ enable=f'between(t,{clip.timeline_start_sec},{clip.timeline_end_sec})'
+ )
+
try:
- args = (ffmpeg.input(self.playback_clip.source_path, ss=clip_time)
- .filter('scale', w, h, force_original_aspect_ratio='decrease')
- .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
- .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile())
+ args = (
+ video_stream
+ .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=fr, ss=time_sec)
+ .compile()
+ )
self.playback_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
except Exception as e:
- print(f"Failed to start playback stream: {e}"); self._stop_playback_stream()
+ print(f"Failed to start playback stream: {e}")
+ self._stop_playback_stream()
def _stop_playback_stream(self):
if self.playback_process:
@@ -1567,7 +1611,6 @@ def _stop_playback_stream(self):
try: self.playback_process.wait(timeout=0.5)
except subprocess.TimeoutExpired: self.playback_process.kill(); self.playback_process.wait()
self.playback_process = None
- self.playback_clip = None
def _get_media_properties(self, file_path):
"""Probes a file to get its media properties. Returns a dict or None."""
@@ -1623,25 +1666,33 @@ def _probe_for_drag(self, file_path):
return self._get_media_properties(file_path)
def get_frame_data_at_time(self, time_sec):
- clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None)
- if not clip_at_time:
+ clips_at_time = [
+ c for c in self.timeline.clips
+ if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec
+ ]
+
+ if not clips_at_time:
return (None, 0, 0)
+
+ clips_at_time.sort(key=lambda c: c.track_index, reverse=True)
+ clip_to_render = clips_at_time[0]
+
try:
w, h = self.project_width, self.project_height
- if clip_at_time.media_type == 'image':
+ if clip_to_render.media_type == 'image':
out, _ = (
ffmpeg
- .input(clip_at_time.source_path)
+ .input(clip_to_render.source_path)
.filter('scale', w, h, force_original_aspect_ratio='decrease')
.filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
.output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
.run(capture_stdout=True, quiet=True)
)
else:
- clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec
+ clip_time = time_sec - clip_to_render.timeline_start_sec + clip_to_render.clip_start_sec
out, _ = (
ffmpeg
- .input(clip_at_time.source_path, ss=clip_time)
+ .input(clip_to_render.source_path, ss=clip_time)
.filter('scale', w, h, force_original_aspect_ratio='decrease')
.filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
.output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
@@ -1653,22 +1704,33 @@ def get_frame_data_at_time(self, time_sec):
return (None, 0, 0)
def get_frame_at_time(self, time_sec):
- clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None)
- black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black"))
- if not clip_at_time: return black_pixmap
+ clips_at_time = [
+ c for c in self.timeline.clips
+ if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec
+ ]
+
+ black_pixmap = QPixmap(self.project_width, self.project_height)
+ black_pixmap.fill(QColor("black"))
+
+ if not clips_at_time:
+ return black_pixmap
+
+ clips_at_time.sort(key=lambda c: c.track_index, reverse=True)
+ clip_to_render = clips_at_time[0]
+
try:
w, h = self.project_width, self.project_height
- if clip_at_time.media_type == 'image':
+ if clip_to_render.media_type == 'image':
out, _ = (
- ffmpeg.input(clip_at_time.source_path)
+ ffmpeg.input(clip_to_render.source_path)
.filter('scale', w, h, force_original_aspect_ratio='decrease')
.filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
.output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
.run(capture_stdout=True, quiet=True)
)
else:
- clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec
- out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time)
+ clip_time = time_sec - clip_to_render.timeline_start_sec + clip_to_render.clip_start_sec
+ out, _ = (ffmpeg.input(clip_to_render.source_path, ss=clip_time)
.filter('scale', w, h, force_original_aspect_ratio='decrease')
.filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
.output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
@@ -1676,10 +1738,16 @@ def get_frame_at_time(self, time_sec):
image = QImage(out, self.project_width, self.project_height, QImage.Format.Format_RGB888)
return QPixmap.fromImage(image)
- except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return black_pixmap
+ except ffmpeg.Error as e:
+ print(f"Error extracting frame: {e.stderr}")
+ return black_pixmap
def seek_preview(self, time_sec):
- self._stop_playback_stream()
+ if self.playback_timer.isActive():
+ self.playback_timer.stop()
+ self._stop_playback_stream()
+ self.play_pause_button.setText("Play")
+
self.timeline_widget.playhead_pos_sec = time_sec
self.timeline_widget.update()
frame_pixmap = self.get_frame_at_time(time_sec)
@@ -1688,16 +1756,35 @@ def seek_preview(self, time_sec):
self.preview_widget.setPixmap(scaled_pixmap)
def toggle_playback(self):
- if self.playback_timer.isActive(): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play")
+ if self.playback_timer.isActive():
+ self.playback_timer.stop()
+ self._stop_playback_stream()
+ self.play_pause_button.setText("Play")
else:
if not self.timeline.clips: return
- if self.timeline_widget.playhead_pos_sec >= self.timeline.get_total_duration(): self.timeline_widget.playhead_pos_sec = 0.0
- self.playback_timer.start(int(1000 / self.project_fps)); self.play_pause_button.setText("Pause")
+
+ playhead_time = self.timeline_widget.playhead_pos_sec
+ if playhead_time >= self.timeline.get_total_duration():
+ playhead_time = 0.0
+ self.timeline_widget.playhead_pos_sec = 0.0
+
+ self._start_playback_stream_at(playhead_time)
+
+ if self.playback_process:
+ self.playback_timer.start(int(1000 / self.project_fps))
+ self.play_pause_button.setText("Pause")
+
+ def stop_playback(self):
+ self.playback_timer.stop()
+ self._stop_playback_stream()
+ self.play_pause_button.setText("Play")
+ self.seek_preview(0.0)
- def stop_playback(self): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play"); self.seek_preview(0.0)
def step_frame(self, direction):
if not self.timeline.clips: return
- self.playback_timer.stop(); self.play_pause_button.setText("Play"); self._stop_playback_stream()
+ self.playback_timer.stop()
+ self._stop_playback_stream()
+ self.play_pause_button.setText("Play")
frame_duration = 1.0 / self.project_fps
new_time = self.timeline_widget.playhead_pos_sec + (direction * frame_duration)
self.seek_preview(max(0, min(new_time, self.timeline.get_total_duration())))
@@ -1705,33 +1792,13 @@ def step_frame(self, direction):
def advance_playback_frame(self):
frame_duration = 1.0 / self.project_fps
new_time = self.timeline_widget.playhead_pos_sec + frame_duration
- if new_time > self.timeline.get_total_duration(): self.stop_playback(); return
+ if new_time > self.timeline.get_total_duration():
+ self.stop_playback()
+ return
self.timeline_widget.playhead_pos_sec = new_time
self.timeline_widget.update()
- clip_at_new_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= new_time < c.timeline_end_sec), None)
-
- if not clip_at_new_time:
- self._stop_playback_stream()
- black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black"))
- scaled_pixmap = black_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
- self.preview_widget.setPixmap(scaled_pixmap)
- return
-
- if clip_at_new_time.media_type == 'image':
- if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id:
- self._stop_playback_stream()
- self.playback_clip = clip_at_new_time
- frame_pixmap = self.get_frame_at_time(new_time)
- if frame_pixmap:
- scaled_pixmap = frame_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
- self.preview_widget.setPixmap(scaled_pixmap)
- return
-
- if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id:
- self._start_playback_stream_at(new_time)
-
if self.playback_process:
frame_size = self.project_width * self.project_height * 3
frame_bytes = self.playback_process.stdout.read(frame_size)
@@ -1742,6 +1809,8 @@ def advance_playback_frame(self):
self.preview_widget.setPixmap(scaled_pixmap)
else:
self._stop_playback_stream()
+ if self.playback_timer.isActive():
+ self.stop_playback()
def _load_settings(self):
self.settings_file_was_loaded = False
From 109a2aa0468f9e332bbaccb7ad3675cbddcfd10c Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 08:51:51 +1100
Subject: [PATCH 115/155] fix multi-track preview rendering (2)
---
videoeditor/main.py | 178 +++++++++++++++-----------------------------
1 file changed, 59 insertions(+), 119 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 9cfca35b0..49f97c8cc 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -1258,6 +1258,7 @@ def __init__(self, project_to_load=None):
self.project_height = 720
self.playback_timer = QTimer(self)
self.playback_process = None
+ self.playback_clip = None
self._setup_ui()
self._connect_signals()
@@ -1545,64 +1546,22 @@ def _create_menu_bar(self):
def _start_playback_stream_at(self, time_sec):
self._stop_playback_stream()
- if not self.timeline.clips:
+ clips_at_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec]
+ if not clips_at_time:
return
+ clip = sorted(clips_at_time, key=lambda c: c.track_index, reverse=True)[0]
- w, h, fr = self.project_width, self.project_height, self.project_fps
- total_dur = self.timeline.get_total_duration()
-
- video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr}:d={total_dur}', f='lavfi')
-
- all_video_clips = sorted(
- [c for c in self.timeline.clips if c.track_type == 'video'],
- key=lambda c: c.track_index
- )
-
- input_nodes = {}
- for clip in all_video_clips:
- if clip.source_path not in input_nodes:
- if clip.media_type == 'image':
- input_nodes[clip.source_path] = ffmpeg.input(clip.source_path, loop=1, framerate=fr)
- else:
- input_nodes[clip.source_path] = ffmpeg.input(clip.source_path)
-
- clip_source_node = input_nodes[clip.source_path]
-
- if clip.media_type == 'image':
- segment_stream = (
- clip_source_node.video
- .trim(duration=clip.duration_sec)
- .setpts('PTS-STARTPTS')
- )
- else:
- segment_stream = (
- clip_source_node.video
- .trim(start=clip.clip_start_sec, duration=clip.duration_sec)
- .setpts('PTS-STARTPTS')
- )
-
- processed_segment = (
- segment_stream
- .filter('scale', w, h, force_original_aspect_ratio='decrease')
- .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
- )
-
- video_stream = ffmpeg.overlay(
- video_stream,
- processed_segment,
- enable=f'between(t,{clip.timeline_start_sec},{clip.timeline_end_sec})'
- )
-
+ self.playback_clip = clip
+ clip_time = time_sec - clip.timeline_start_sec + clip.clip_start_sec
+ w, h = self.project_width, self.project_height
try:
- args = (
- video_stream
- .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=fr, ss=time_sec)
- .compile()
- )
+ args = (ffmpeg.input(self.playback_clip.source_path, ss=clip_time)
+ .filter('scale', w, h, force_original_aspect_ratio='decrease')
+ .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
+ .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile())
self.playback_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
except Exception as e:
- print(f"Failed to start playback stream: {e}")
- self._stop_playback_stream()
+ print(f"Failed to start playback stream: {e}"); self._stop_playback_stream()
def _stop_playback_stream(self):
if self.playback_process:
@@ -1611,6 +1570,7 @@ def _stop_playback_stream(self):
try: self.playback_process.wait(timeout=0.5)
except subprocess.TimeoutExpired: self.playback_process.kill(); self.playback_process.wait()
self.playback_process = None
+ self.playback_clip = None
def _get_media_properties(self, file_path):
"""Probes a file to get its media properties. Returns a dict or None."""
@@ -1666,33 +1626,26 @@ def _probe_for_drag(self, file_path):
return self._get_media_properties(file_path)
def get_frame_data_at_time(self, time_sec):
- clips_at_time = [
- c for c in self.timeline.clips
- if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec
- ]
-
+ clips_at_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec]
if not clips_at_time:
return (None, 0, 0)
-
- clips_at_time.sort(key=lambda c: c.track_index, reverse=True)
- clip_to_render = clips_at_time[0]
-
+ clip_at_time = sorted(clips_at_time, key=lambda c: c.track_index, reverse=True)[0]
try:
w, h = self.project_width, self.project_height
- if clip_to_render.media_type == 'image':
+ if clip_at_time.media_type == 'image':
out, _ = (
ffmpeg
- .input(clip_to_render.source_path)
+ .input(clip_at_time.source_path)
.filter('scale', w, h, force_original_aspect_ratio='decrease')
.filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
.output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
.run(capture_stdout=True, quiet=True)
)
else:
- clip_time = time_sec - clip_to_render.timeline_start_sec + clip_to_render.clip_start_sec
+ clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec
out, _ = (
ffmpeg
- .input(clip_to_render.source_path, ss=clip_time)
+ .input(clip_at_time.source_path, ss=clip_time)
.filter('scale', w, h, force_original_aspect_ratio='decrease')
.filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
.output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
@@ -1704,33 +1657,24 @@ def get_frame_data_at_time(self, time_sec):
return (None, 0, 0)
def get_frame_at_time(self, time_sec):
- clips_at_time = [
- c for c in self.timeline.clips
- if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec
- ]
-
- black_pixmap = QPixmap(self.project_width, self.project_height)
- black_pixmap.fill(QColor("black"))
-
+ clips_at_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec]
+ black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black"))
if not clips_at_time:
return black_pixmap
-
- clips_at_time.sort(key=lambda c: c.track_index, reverse=True)
- clip_to_render = clips_at_time[0]
-
+ clip_at_time = sorted(clips_at_time, key=lambda c: c.track_index, reverse=True)[0]
try:
w, h = self.project_width, self.project_height
- if clip_to_render.media_type == 'image':
+ if clip_at_time.media_type == 'image':
out, _ = (
- ffmpeg.input(clip_to_render.source_path)
+ ffmpeg.input(clip_at_time.source_path)
.filter('scale', w, h, force_original_aspect_ratio='decrease')
.filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
.output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
.run(capture_stdout=True, quiet=True)
)
else:
- clip_time = time_sec - clip_to_render.timeline_start_sec + clip_to_render.clip_start_sec
- out, _ = (ffmpeg.input(clip_to_render.source_path, ss=clip_time)
+ clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec
+ out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time)
.filter('scale', w, h, force_original_aspect_ratio='decrease')
.filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
.output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
@@ -1738,16 +1682,10 @@ def get_frame_at_time(self, time_sec):
image = QImage(out, self.project_width, self.project_height, QImage.Format.Format_RGB888)
return QPixmap.fromImage(image)
- except ffmpeg.Error as e:
- print(f"Error extracting frame: {e.stderr}")
- return black_pixmap
+ except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return black_pixmap
def seek_preview(self, time_sec):
- if self.playback_timer.isActive():
- self.playback_timer.stop()
- self._stop_playback_stream()
- self.play_pause_button.setText("Play")
-
+ self._stop_playback_stream()
self.timeline_widget.playhead_pos_sec = time_sec
self.timeline_widget.update()
frame_pixmap = self.get_frame_at_time(time_sec)
@@ -1756,35 +1694,16 @@ def seek_preview(self, time_sec):
self.preview_widget.setPixmap(scaled_pixmap)
def toggle_playback(self):
- if self.playback_timer.isActive():
- self.playback_timer.stop()
- self._stop_playback_stream()
- self.play_pause_button.setText("Play")
+ if self.playback_timer.isActive(): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play")
else:
if not self.timeline.clips: return
-
- playhead_time = self.timeline_widget.playhead_pos_sec
- if playhead_time >= self.timeline.get_total_duration():
- playhead_time = 0.0
- self.timeline_widget.playhead_pos_sec = 0.0
-
- self._start_playback_stream_at(playhead_time)
-
- if self.playback_process:
- self.playback_timer.start(int(1000 / self.project_fps))
- self.play_pause_button.setText("Pause")
-
- def stop_playback(self):
- self.playback_timer.stop()
- self._stop_playback_stream()
- self.play_pause_button.setText("Play")
- self.seek_preview(0.0)
+ if self.timeline_widget.playhead_pos_sec >= self.timeline.get_total_duration(): self.timeline_widget.playhead_pos_sec = 0.0
+ self.playback_timer.start(int(1000 / self.project_fps)); self.play_pause_button.setText("Pause")
+ def stop_playback(self): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play"); self.seek_preview(0.0)
def step_frame(self, direction):
if not self.timeline.clips: return
- self.playback_timer.stop()
- self._stop_playback_stream()
- self.play_pause_button.setText("Play")
+ self.playback_timer.stop(); self.play_pause_button.setText("Play"); self._stop_playback_stream()
frame_duration = 1.0 / self.project_fps
new_time = self.timeline_widget.playhead_pos_sec + (direction * frame_duration)
self.seek_preview(max(0, min(new_time, self.timeline.get_total_duration())))
@@ -1792,13 +1711,36 @@ def step_frame(self, direction):
def advance_playback_frame(self):
frame_duration = 1.0 / self.project_fps
new_time = self.timeline_widget.playhead_pos_sec + frame_duration
- if new_time > self.timeline.get_total_duration():
- self.stop_playback()
- return
+ if new_time > self.timeline.get_total_duration(): self.stop_playback(); return
self.timeline_widget.playhead_pos_sec = new_time
self.timeline_widget.update()
+ clips_at_new_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= new_time < c.timeline_end_sec]
+ clip_at_new_time = None
+ if clips_at_new_time:
+ clip_at_new_time = sorted(clips_at_new_time, key=lambda c: c.track_index, reverse=True)[0]
+
+ if not clip_at_new_time:
+ self._stop_playback_stream()
+ black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black"))
+ scaled_pixmap = black_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
+ self.preview_widget.setPixmap(scaled_pixmap)
+ return
+
+ if clip_at_new_time.media_type == 'image':
+ if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id:
+ self._stop_playback_stream()
+ self.playback_clip = clip_at_new_time
+ frame_pixmap = self.get_frame_at_time(new_time)
+ if frame_pixmap:
+ scaled_pixmap = frame_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
+ self.preview_widget.setPixmap(scaled_pixmap)
+ return
+
+ if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id:
+ self._start_playback_stream_at(new_time)
+
if self.playback_process:
frame_size = self.project_width * self.project_height * 3
frame_bytes = self.playback_process.stdout.read(frame_size)
@@ -1809,8 +1751,6 @@ def advance_playback_frame(self):
self.preview_widget.setPixmap(scaled_pixmap)
else:
self._stop_playback_stream()
- if self.playback_timer.isActive():
- self.stop_playback()
def _load_settings(self):
self.settings_file_was_loaded = False
From de9ba22e1df8f140a0bacce13bfdcdeb5b2599da Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 08:59:41 +1100
Subject: [PATCH 116/155] speed up seeking
---
videoeditor/main.py | 31 ++++++++++++++++---------------
1 file changed, 16 insertions(+), 15 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 49f97c8cc..996e8ba2a 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -1544,12 +1544,19 @@ def _create_menu_bar(self):
data['action'] = action
self.windows_menu.addAction(action)
+ def _get_topmost_video_clip_at(self, time_sec):
+ """Finds the video clip on the highest track at a specific time."""
+ top_clip = None
+ for c in self.timeline.clips:
+ if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec:
+ if top_clip is None or c.track_index > top_clip.track_index:
+ top_clip = c
+ return top_clip
+
def _start_playback_stream_at(self, time_sec):
self._stop_playback_stream()
- clips_at_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec]
- if not clips_at_time:
- return
- clip = sorted(clips_at_time, key=lambda c: c.track_index, reverse=True)[0]
+ clip = self._get_topmost_video_clip_at(time_sec)
+ if not clip: return
self.playback_clip = clip
clip_time = time_sec - clip.timeline_start_sec + clip.clip_start_sec
@@ -1626,10 +1633,9 @@ def _probe_for_drag(self, file_path):
return self._get_media_properties(file_path)
def get_frame_data_at_time(self, time_sec):
- clips_at_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec]
- if not clips_at_time:
+ clip_at_time = self._get_topmost_video_clip_at(time_sec)
+ if not clip_at_time:
return (None, 0, 0)
- clip_at_time = sorted(clips_at_time, key=lambda c: c.track_index, reverse=True)[0]
try:
w, h = self.project_width, self.project_height
if clip_at_time.media_type == 'image':
@@ -1657,11 +1663,9 @@ def get_frame_data_at_time(self, time_sec):
return (None, 0, 0)
def get_frame_at_time(self, time_sec):
- clips_at_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec]
black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black"))
- if not clips_at_time:
- return black_pixmap
- clip_at_time = sorted(clips_at_time, key=lambda c: c.track_index, reverse=True)[0]
+ clip_at_time = self._get_topmost_video_clip_at(time_sec)
+ if not clip_at_time: return black_pixmap
try:
w, h = self.project_width, self.project_height
if clip_at_time.media_type == 'image':
@@ -1716,10 +1720,7 @@ def advance_playback_frame(self):
self.timeline_widget.playhead_pos_sec = new_time
self.timeline_widget.update()
- clips_at_new_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= new_time < c.timeline_end_sec]
- clip_at_new_time = None
- if clips_at_new_time:
- clip_at_new_time = sorted(clips_at_new_time, key=lambda c: c.track_index, reverse=True)[0]
+ clip_at_new_time = self._get_topmost_video_clip_at(new_time)
if not clip_at_new_time:
self._stop_playback_stream()
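The `_get_topmost_video_clip_at` helper added in patch 116 is a plain linear scan. An equivalent sketch using `max()` with a default, assuming clip objects expose `track_type`, `track_index`, `timeline_start_sec` and `timeline_end_sec`:

```python
# Return the video clip on the highest track index at time_sec, or None.
def topmost_video_clip_at(clips, time_sec):
    candidates = [
        c for c in clips
        if c.track_type == 'video'
        and c.timeline_start_sec <= time_sec < c.timeline_end_sec
    ]
    return max(candidates, key=lambda c: c.track_index, default=None)
```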
From 153c858d4ff60fa61ebfb437792eda6b6c9fbc90 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 09:40:51 +1100
Subject: [PATCH 117/155] add clip selection, Delete key and left/right arrow
 keys
---
videoeditor/main.py | 82 ++++++++++++++++++++++++++++++++++++---------
1 file changed, 66 insertions(+), 16 deletions(-)
diff --git a/videoeditor/main.py b/videoeditor/main.py
index 996e8ba2a..4bd3662c6 100644
--- a/videoeditor/main.py
+++ b/videoeditor/main.py
@@ -11,7 +11,7 @@
QScrollArea, QFrame, QProgressBar, QDialog,
QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget, QListWidget, QListWidgetItem, QMessageBox)
from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction,
- QPixmap, QImage, QDrag, QCursor)
+ QPixmap, QImage, QDrag, QCursor, QKeyEvent)
from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread,
pyqtSignal, QTimer, QByteArray, QMimeData)
@@ -87,6 +87,7 @@ class TimelineWidget(QWidget):
split_requested = pyqtSignal(object)
delete_clip_requested = pyqtSignal(object)
+ delete_clips_requested = pyqtSignal(list)
playhead_moved = pyqtSignal(float)
split_region_requested = pyqtSignal(list)
split_all_regions_requested = pyqtSignal(list)
@@ -114,7 +115,9 @@ def __init__(self, timeline_model, settings, project_fps, parent=None):
self.setMinimumHeight(300)
self.setMouseTracking(True)
self.setAcceptDrops(True)
+ self.setFocusPolicy(Qt.FocusPolicy.StrongFocus)
self.selection_regions = []
+ self.selected_clips = set()
self.dragging_clip = None
self.dragging_linked_clip = None
self.dragging_playhead = False
@@ -404,6 +407,12 @@ def draw_tracks_and_clips(self, painter):
color = QColor("#5A9") if self.dragging_clip and self.dragging_clip.id == clip.id else base_color
painter.fillRect(clip_rect, color)
+
+ if clip.id in self.selected_clips:
+ pen = QPen(QColor(255, 255, 0, 220), 2)
+ painter.setPen(pen)
+ painter.drawRect(clip_rect)
+
painter.setPen(QPen(QColor("#FFF"), 1))
font = QFont("Arial", 10)
painter.setFont(font)
@@ -502,11 +511,13 @@ def mousePressEvent(self, event: QMouseEvent):
return
if event.button() == Qt.MouseButton.LeftButton:
+ self.setFocus()
+
self.dragging_clip = None
self.dragging_linked_clip = None
self.dragging_playhead = False
- self.dragging_selection_region = None
self.creating_selection_region = False
+ self.dragging_selection_region = None
self.resizing_clip = None
self.resize_edge = None
self.drag_original_clip_states.clear()
@@ -528,22 +539,36 @@ def mousePressEvent(self, event: QMouseEvent):
self.update()
return
+ clicked_clip = None
for clip in reversed(self.timeline.clips):
- clip_rect = self.get_clip_rect(clip)
- if clip_rect.contains(QPointF(event.pos())):
- self.dragging_clip = clip
+ if self.get_clip_rect(clip).contains(QPointF(event.pos())):
+ clicked_clip = clip
+ break
+
+ if clicked_clip:
+ is_ctrl_pressed = bool(event.modifiers() & Qt.KeyboardModifier.ControlModifier)
+
+ if clicked_clip.id in self.selected_clips:
+ if is_ctrl_pressed:
+ self.selected_clips.remove(clicked_clip.id)
+ else:
+ if not is_ctrl_pressed:
+ self.selected_clips.clear()
+ self.selected_clips.add(clicked_clip.id)
+
+ if clicked_clip.id in self.selected_clips:
+ self.dragging_clip = clicked_clip
self.drag_start_state = self.window()._get_current_timeline_state()
- self.drag_original_clip_states[clip.id] = (clip.timeline_start_sec, clip.track_index)
+ self.drag_original_clip_states[clicked_clip.id] = (clicked_clip.timeline_start_sec, clicked_clip.track_index)
- self.dragging_linked_clip = next((c for c in self.timeline.clips if c.group_id == clip.group_id and c.id != clip.id), None)
+ self.dragging_linked_clip = next((c for c in self.timeline.clips if c.group_id == clicked_clip.group_id and c.id != clicked_clip.id), None)
if self.dragging_linked_clip:
self.drag_original_clip_states[self.dragging_linked_clip.id] = \
(self.dragging_linked_clip.timeline_start_sec, self.dragging_linked_clip.track_index)
-
self.drag_start_pos = event.pos()
- break
-
- if not self.dragging_clip:
+
+ else:
+ self.selected_clips.clear()
region_to_drag = self.get_region_at_pos(event.pos())
if region_to_drag:
self.dragging_selection_region = region_to_drag
@@ -1108,6 +1133,19 @@ def clear_all_regions(self):
self.selection_regions.clear()
self.update()
+ def keyPressEvent(self, event: QKeyEvent):
+ if event.key() == Qt.Key.Key_Delete or event.key() == Qt.Key.Key_Backspace:
+ if self.selected_clips:
+ clips_to_delete = [c for c in self.timeline.clips if c.id in self.selected_clips]
+ if clips_to_delete:
+ self.delete_clips_requested.emit(clips_to_delete)
+ elif event.key() == Qt.Key.Key_Left:
+ self.window().step_frame(-1)
+ elif event.key() == Qt.Key.Key_Right:
+ self.window().step_frame(1)
+ else:
+ super().keyPressEvent(event)
+
class SettingsDialog(QDialog):
def __init__(self, parent_settings, parent=None):
@@ -1352,6 +1390,7 @@ def _connect_signals(self):
self.timeline_widget.split_requested.connect(self.split_clip_at_playhead)
self.timeline_widget.delete_clip_requested.connect(self.delete_clip)
+ self.timeline_widget.delete_clips_requested.connect(self.delete_clips)
self.timeline_widget.playhead_moved.connect(self.seek_preview)
self.timeline_widget.split_region_requested.connect(self.on_split_region)
self.timeline_widget.split_all_regions_requested.connect(self.on_split_all_regions)
@@ -2066,18 +2105,29 @@ def split_clip_at_playhead(self, clip_to_split=None):
def delete_clip(self, clip_to_delete):
+ self.delete_clips([clip_to_delete])
+
+ def delete_clips(self, clips_to_delete):
+ if not clips_to_delete: return
+
old_state = self._get_current_timeline_state()
- linked_clip = next((c for c in self.timeline.clips if c.group_id == clip_to_delete.group_id and c.id != clip_to_delete.id), None)
+ ids_to_remove = set()
+ for clip in clips_to_delete:
+ ids_to_remove.add(clip.id)
+ linked_clips = [c for c in self.timeline.clips if c.group_id == clip.group_id and c.id != clip.id]
+ for lc in linked_clips:
+ ids_to_remove.add(lc.id)
+
+ self.timeline.clips = [c for c in self.timeline.clips if c.id not in ids_to_remove]
+ self.timeline_widget.selected_clips.clear()
- if clip_to_delete in self.timeline.clips: self.timeline.clips.remove(clip_to_delete)
- if linked_clip and linked_clip in self.timeline.clips: self.timeline.clips.remove(linked_clip)
-
new_state = self._get_current_timeline_state()
- command = TimelineStateChangeCommand("Delete Clip", self.timeline, *old_state, *new_state)
+ command = TimelineStateChangeCommand(f"Delete {len(clips_to_delete)} Clip(s)", self.timeline, *old_state, *new_state)
command.undo()
self.undo_stack.push(command)
self.prune_empty_tracks()
+ self.timeline_widget.update()
def unlink_clip_pair(self, clip_to_unlink):
old_state = self._get_current_timeline_state()
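The new `delete_clips` path in patch 117 expands the selection to linked clips via `group_id` before filtering the timeline in a single pass. A minimal sketch of that expansion, assuming clip objects expose `id` and `group_id`:

```python
# Remove the selected clips plus anything linked to them by group_id.
def delete_clips(timeline_clips, clips_to_delete):
    ids_to_remove = set()
    for clip in clips_to_delete:
        ids_to_remove.add(clip.id)
        ids_to_remove.update(
            c.id for c in timeline_clips
            if c.group_id == clip.group_id and c.id != clip.id
        )
    return [c for c in timeline_clips if c.id not in ids_to_remove]
```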
From 69bb8221f7da8edf8e8c92c43ebdcc63e6bd4e1e Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 21:13:25 +1100
Subject: [PATCH 118/155] finally moved from server/client to a native plugin
---
README.md | 287 +---
videoeditor/plugins.py => plugins.py | 13 +-
main.py => plugins/wan2gp/main.py | 1311 +++++++------------
videoeditor/undo.py => undo.py | 0
videoeditor/main.py => videoeditor.py | 3 +-
videoeditor/plugins/ai_frame_joiner/main.py | 512 --------
6 files changed, 518 insertions(+), 1608 deletions(-)
rename videoeditor/plugins.py => plugins.py (96%)
rename main.py => plugins/wan2gp/main.py (69%)
rename videoeditor/undo.py => undo.py (100%)
rename videoeditor/main.py => videoeditor.py (99%)
delete mode 100644 videoeditor/plugins/ai_frame_joiner/main.py
diff --git a/README.md b/README.md
index 6464a18f0..3726bd212 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. It provi
## Features
-- **Inline AI Video Generation With WAN2GP**: Select a region to join and it will bring up a desktop port of WAN2GP for you to generate a video inline using the start and end frames in the selected region.
+- **Inline AI Video Generation With WAN2GP**: Select a region to join and the editor brings up a desktop port of WAN2GP so you can generate a video inline from the start and end frames of the selected region. You can also create new frames within the selected region.
- **Multi-Track Timeline**: Arrange video and audio clips on separate tracks.
- **Project Management**: Create, save, and load projects in a `.json` format.
- **Clip Operations**:
@@ -12,6 +12,7 @@ A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. It provi
- Split clips at the playhead.
- Create selection regions for advanced operations.
- Join/remove content within selected regions across all tracks.
+ - Link/Unlink audio tracks from video
- **Real-time Preview**: A video preview window with playback controls (Play, Pause, Stop, Frame-by-frame stepping).
- **Dynamic Track Management**: Add or remove video and audio tracks as needed.
- **FFmpeg Integration**:
@@ -23,286 +24,30 @@ A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. It provi
## Installation
-**Follow the standard installation steps for WAN2GP below**
-
-**Run the server:**
-```bash
-python main.py
-```
-
-**Run the video editor:**
```bash
-cd videoeditor
-python main.py
-```
-
-## Screenshots
-
-
-
-
-
-
-
-
-
-# What follows is the standard WanGP Readme
-
------
-
-WanGP by DeepBeepMeep : The best Open Source Video Generative Models Accessible to the GPU Poor
-
-
-WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models with:
-- Low VRAM requirements (as low as 6 GB of VRAM is sufficient for certain models)
-- Support for old GPUs (RTX 10XX, 20xx, ...)
-- Very Fast on the latest GPUs
-- Easy to use Full Web based interface
-- Auto download of the required model adapted to your specific architecture
-- Tools integrated to facilitate Video Generation : Mask Editor, Prompt Enhancer, Temporal and Spatial Generation, MMAudio, Video Browser, Pose / Depth / Flow extractor
-- Loras Support to customize each model
-- Queuing system : make your shopping list of videos to generate and come back later
-
-**Discord Server to get Help from Other Users and show your Best Videos:** https://discord.gg/g7efUW9jGV
-
-**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
-
-## 🔥 Latest Updates :
-### October 6 2025: WanGP v8.994 - A few last things before the Big Unknown ...
-
-This new version hasn't any new model...
-
-...but temptation to upgrade will be high as it contains a few Loras related features that may change your Life:
-- **Ready to use Loras Accelerators Profiles** per type of model that you can apply on your current *Generation Settings*. Next time I will recommend a *Lora Accelerator*, it will be only one click away. And best of all of the required Loras will be downloaded automatically. When you apply an *Accelerator Profile*, input fields like the *Number of Denoising Steps* *Activated Loras*, *Loras Multipliers* (such as "1;0 0;1" ...) will be automatically filled. However your video specific fields will be preserved, so it will be easy to switch between Profiles to experiment. With *WanGP 8.993*, the *Accelerator Loras* are now merged with *Non Accelerator Loras". Things are getting too easy...
-
-- **Embedded Loras URL** : WanGP will now try to remember every Lora URLs it sees. For instance if someone sends you some settings that contain Loras URLs or you extract the Settings of Video generated by a friend with Loras URLs, these URLs will be automatically added to *WanGP URL Cache*. Conversely everything you will share (Videos, Settings, Lset files) will contain the download URLs if they are known. You can also download directly a Lora in WanGP by using the *Download Lora* button a the bottom. The Lora will be immediatly available and added to WanGP lora URL cache. This will work with *Hugging Face* as a repository. Support for CivitAi will come as soon as someone will nice enough to post a GitHub PR ...
-
-- **.lset file** supports embedded Loras URLs. It has never been easier to share a Lora with a friend. As a reminder, a .lset file can be created directly from the *WanGP Web Interface* and it contains a list of Loras and their multipliers, a Prompt and instructions on how to use these Loras (like the Lora's *Trigger*). So with embedded Loras URLs, you can send an .lset file by email or share it on Discord: it is just a tiny 1 KB text file, but with it other people will be able to use gigabytes of Loras, as these will be downloaded automatically.
-
-I have created the new Discord Channel **share-your-settings** where you can post your *Settings* or *Lset files*. I will be pleased to add new Loras Accelerators to the list of WanGP *Accelerator Profiles* if you post some good ones there.
-
-*With the 8.993 update*, I have added support for the **Scaled FP8 format**. As a sample case, I have created finetunes for **Wan 2.2 PalinGenesis**, which has been quite popular recently. You will find it in 3 flavors : *t2v*, *i2v* and *Lightning Accelerated for t2v*.
-
-The *Scaled FP8 format* is widely used as it is the format used by ... *ComfyUI*. So I expect a flood of Finetunes in the *share-your-finetune* channel. If not, it means this feature was useless and I will remove it 😈😈😈
-
-Not enough space left on your SSD to download more models? Would you like to reuse Scaled FP8 files in your ComfyUI folder without duplicating them? Here comes *WanGP 8.994* **Multiple Checkpoints Folders** : you just need to move the files into different folders / hard drives or reuse existing folders, let WanGP know about it in the *Config Tab*, and WanGP will be able to put all the parts together.
-
-Last but not least the Lora's documentation has been updated.
-
-*update 8.991*: full power of *Vace Lynx* unleashed with new combinations such as Landscape + Face / Clothes + Face / Injected Frame (Start/End frames/...) + Face
-*update 8.992*: optimized generation with Loras, should be 10% faster when many Loras are used
-*update 8.993*: Support for *Scaled FP8* format and samples *Paligenesis* finetunes, merged Loras Accelerators and Non Accelerators
-*update 8.994*: Added custom checkpoints folders
-
-### September 30 2025: WanGP v8.9 - Combinatorics
-
-This new version of WanGP introduces **Wan 2.1 Lynx**, the best Control Net so far to transfer *Facial Identity*. You will be amazed to recognize your friends even with a completely different hair style. Congrats to the *Byte Dance team* for this achievement. Lynx works quite well with *Fusionix t2v* at 10 steps.
-
-*WanGP 8.9* also illustrates how existing WanGP features can be easily combined with new models. For instance with *Lynx* you will get *Video to Video* and *Image/Text to Image* out of the box.
-
-Another fun combination is *Vace* + *Lynx*, which works much better than *Vace StandIn*. I have added sliders to change the weight of Vace & Lynx to allow you to tune the effects.
-
-
-### September 28 2025: WanGP v8.76 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release
-
-So in ~~today's~~ this release you will find two wannabe Vaces that each cover only a subset of Vace features but offer some interesting advantages:
-- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion* transfers. It does that very well. You can use this model to either *Replace* a person in a Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be, under the hood, a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusivity, you will find support for *Outpainting*.
-
-In order to use Wan 2.2 Animate you will first need to stop by the embedded *Mat Anyone* tool to extract the *Video Mask* of the person whose motion you want to extract.
-
-With version WanGP 8.74, there is an extra option that allows you to apply *Relighting* when Replacing a person. Also, you can now Animate a person without providing a Video Mask to target the source of the motion (with the risk that it will be less precise).
-
-For those of you who get a mask halo effect when Animating a character, I recommend trying *SDPA attention* and using the *FusioniX i2v* lora. If this issue persists (this will depend on the control video), you now have a choice of two *Animate Mask Options* in *WanGP 8.76*. The old masking option, which was a WanGP exclusive, has been renamed *See Through Mask* because the background behind the animated character was preserved, but this sometimes creates visual artifacts. The new option, which has the shorter name, is what you may find elsewhere online. As it internally uses a much larger mask, there is no halo. However the immediate background behind the character is not preserved and may end up completely different.
-
-- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in changing clothes) and voila! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.
-
-Also because I wanted to spoil you:
-- **Qwen Edit Plus**: also known as the *Qwen Edit 25th September Update*, which is specialized in combining multiple Objects / People. There is also new support for *Pose transfer* & *Recolorisation*. All of this is made easy to use in WanGP. For now you will find only the quantized version, since HF crashes when uploading the unquantized version.
-
-- **T2V Video 2 Video Masking**: ever wanted to apply a Lora, a process (for instance Upsampling) or a Text Prompt to only a (moving) part of a Source Video? Look no further, I have added *Masked Video 2 Video* (which also works in image2image) to the *Text 2 Video* models. As usual you just need to use *Matanyone* to create the mask.
-
-
-*Update 8.71*: fixed Fast Lucy Edit that didn't contain the lora
-*Update 8.72*: shadow drop of Qwen Edit Plus
-*Update 8.73*: Qwen Preview & InfiniteTalk Start image
-*Update 8.74*: Animate Relighting / Nomask mode , t2v Masked Video to Video
-*Update 8.75*: REDACTED
-*Update 8.76*: Alternate Animate masking that fixes the mask halo effect that some users have
-
-### September 15 2025: WanGP v8.6 - Attack of the Clones
-
-- The long awaited **Vace for Wan 2.2** is at last here or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fan Cocktail**)
-
-- **First Frame / Last Frame for Vace** : Vace models are so powerful that they could do *First frame / Last frame* since day one using the *Injected Frames* feature. However this required computing the locations of each end frame by hand, since this feature expects frame positions. I made it easier to compute these locations by using the "L" alias :
-
-For a video Gen from scratch, *"1 L L L"* means the 4 Injected Frames will be injected like this: frame no 1 at the first position, the next frame at the end of the first window, then the following frame at the end of the next window, and so on ....
-If you *Continue a Video*, you just need *"L L L"* since the first frame is the last frame of the *Source Video*. In any case remember that numeric frame positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab / Control Video, Injected Frames alignment*.
-
-- **Qwen Edit Inpainting** now exists in two versions: the original version of the previous release and a Lora based version. Each version has its pros and cons. For instance the Lora version also supports **Outpainting**! However it tends to slightly change the original image even outside the outpainted area.
-
-- **Better Lipsync with all the Audio to Video models**: you probably noticed that *Multitalk*, *InfiniteTalk* or *Hunyuan Avatar* had so-so lipsync when the audio provided contained some background music. The problem should be solved now thanks to automated background music removal, all done by AI. Don't worry, you will still hear the music as it is added back into the generated Video.
-
-### September 11 2025: WanGP v8.5/8.55 - Wanna be a Cropper or a Painter ?
-
-I have done some intensive internal refactoring of the generation pipeline to make it easier to support existing models or add new ones. Nothing really visible, but this makes WanGP a little more future proof.
-
-Otherwise in the news:
-- **Cropped Input Image Prompts**: quite often the *Image Prompts* provided (*Start Image, Input Video, Reference Image, Control Video, ...*) don't match your requested *Output Resolution*. In that case I used the resolution you gave either as a *Pixels Budget* or as an *Outer Canvas* for the Generated Video. However on some occasions you really want the requested Output Resolution and nothing else. Besides, some models deliver much better Generations if you stick to one of their supported resolutions. In order to address this need I have added a new Output Resolution choice in the *Configuration Tab*: **Dimensions Correspond to the Output Width & Height as the Prompt Images will be Cropped to fit Exactly these dimensions**. In short, if needed the *Input Prompt Images* will be cropped (center cropped for the moment). You will see this can make quite a difference for some models.
-
-- *Qwen Edit* now has a new sub Tab called **Inpainting**, which lets you target with a brush which part of the *Image Prompt* you want to modify. This is quite convenient if you find that Qwen Edit usually modifies too many things. Of course, as there are more constraints for Qwen Edit, don't be surprised if it sometimes returns the original image unchanged. A piece of advice: describe in your *Text Prompt* where the parts you want to modify are located (for instance *left of the man*, *top*, ...).
-
-The mask inpainting is fully compatible with the *Matanyone Mask generator*: first generate an *Image Mask* with Matanyone, transfer it to the current Image Generator and modify the mask with the *Paint Brush*. Talking about Matanyone, I have fixed a bug that caused mask degradation with long videos (now WanGP Matanyone is as good as the original app and still requires 3 times less VRAM).
-
-- This **Inpainting Mask Editor** has also been added to *Vace Image Mode*. Vace is probably still one of the best Image Editors today. Here is a very simple & efficient workflow that does marvels with Vace:
-Select *Vace Cocktail > Control Image Process = Perform Inpainting & Area Processed = Masked Area > Upload a Control Image, then draw your mask directly on top of the image & enter a text Prompt that describes the expected change > Generate > Below the Video Gallery click 'To Control Image' > Keep on doing more changes*.
-
-For more sophisticated things the Vace Image Editor works very well too: try Image Outpainting, Pose transfer, ...
-
-For the best quality I recommend setting the option "*Generate a 9 Frames Long video...*" in the *Quality Tab*.
-
-**update 8.55**: Flux Festival
-- **Inpainting Mode** also added for *Flux Kontext*
-- **Flux SRPO** : new finetune with x3 better quality vs Flux Dev according to its authors. I have also created a *Flux SRPO USO* finetune which is certainly the best open source *Style Transfer* tool available
-- **Flux UMO**: model specialized in combining multiple reference objects / people together. Works quite well at 768x768
-
-Good luck with finding your way through all the Flux models names !
-
-### September 5 2025: WanGP v8.4 - Take me to Outer Space
-You have probably seen these short AI generated movies created using *Nano Banana* and the *First Frame - Last Frame* feature of *Kling 2.0*. The idea is to generate an image, modify a part of it with Nano Banana and give these two images to Kling, which will generate the Video between them; then use the previous Last Frame as the new First Frame, rinse and repeat, and you get a full movie.
-
-I have made it easier to do just that with *Qwen Edit* and *Wan*:
-- **End Frames can now be combined with Continue a Video** (and not just a Start Frame)
-- **Multiple End Frames can be input**, each End Frame will be used for a different Sliding Window
-
-You can plan all your shots in advance (one shot = one Sliding Window): I recommend using Wan 2.2 Image to Image with multiple End Frames (one for each shot / Sliding Window), and a different Text Prompt for each shot / Sliding Window (remember to enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*)
-
-The results can be quite impressive. However, Wan 2.1 & 2.2 Image 2 Image are restricted to a single overlap frame when using Sliding Windows, which means only one frame is reused for the motion. This may be insufficient if you are trying to connect two shots with fast movement.
-
-This is where *InfiniteTalk* comes into play. Besides being one of the best models to generate animated audio driven avatars, InfiniteTalk internally uses more than one motion frame. It is quite good at maintaining the motion between two shots. I have tweaked InfiniteTalk so that **its motion engine can be used even if no audio is provided**.
-So here is how to use InfiniteTalk: enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*, and if you continue an existing Video, *Misc/Override Frames per Second* should be set to *Source Video*. Each Reference Frame input will play the same role as the End Frame, except it won't be exactly an End Frame (it will correspond more to a middle frame; the actual End Frame will differ but will be close).
-
-
-You will find below a 33s movie I have created using these two methods. Quality could be much better as I haven't tuned the settings at all (I couldn't be bothered; I used 10 step generation without Loras Accelerators for most of the gens).
-
-### September 2 2025: WanGP v8.31 - At last the pain stops
-
-- This single new feature should give you the strength to face all the potential bugs of this new release:
-**Images Management (multiple additions or deletions, reordering) for Start Images / End Images / Images References.**
-
-- Unofficial **Video to Video (Non Sparse this time) for InfiniteTalk**. Use the Strength Noise slider to decide how much motion of the original window you want to keep. I have also *greatly reduced the VRAM requirements for Multitalk / InfiniteTalk* (especially the multispeaker version & when generating at 1080p).
-
-- **Experimental Sage 3 Attention support**: you will need to earn this one: first you need a Blackwell GPU (RTX 50xx) and to request access to the Sage 3 GitHub repo, then you will have to compile Sage 3, install it and cross your fingers ...
-
-
-*update 8.31: one shouldn't talk about bugs if one doesn't want to attract bugs*
-
-
-See full changelog: **[Changelog](docs/CHANGELOG.md)**
-
-## 📋 Table of Contents
-
-- [🚀 Quick Start](#-quick-start)
-- [📦 Installation](#-installation)
-- [🎯 Usage](#-usage)
-- [📚 Documentation](#-documentation)
-- [🔗 Related Projects](#-related-projects)
-
-## 🚀 Quick Start
-
-**One-click installation:**
-- Get started instantly with [Pinokio App](https://pinokio.computer/)
-- Use Redtash1 [One Click Install with Sage](https://github.com/Redtash1/Wan2GP-Windows-One-Click-Install-With-Sage)
-
-**Manual installation:**
-```bash
-git clone https://github.com/deepbeepmeep/Wan2GP.git
+git clone https://github.com/Tophness/Wan2GP.git
cd Wan2GP
+git checkout video_editor
conda create -n wan2gp python=3.10.9
conda activate wan2gp
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
pip install -r requirements.txt
```
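A quick optional sanity check (not part of the repository; the exact version string is an expectation based on the pip command above) that the CUDA 12.8 build of PyTorch landed in the `wan2gp` environment:

```python
# Run inside the activated wan2gp environment.
import torch

print(torch.__version__)          # expected to start with 2.7.0 for the cu128 wheel installed above
print(torch.cuda.is_available())  # True when an NVIDIA GPU and driver are visible
```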
-**Run the application:**
-```bash
-python wgp.py
-```
-
-**Update the application:**
-If you are using Pinokio, use Pinokio to update. Otherwise,
-go to the directory where WanGP is installed and run:
-```bash
-git pull
-pip install -r requirements.txt
-```
-
-## 🐳 Docker:
-
-**For Debian-based systems (Ubuntu, Debian, etc.):**
+## Usage
+**Run the video editor:**
```bash
-./run-docker-cuda-deb.sh
+python videoeditor.py
```
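For orientation, this is the layout assumed by this patch (inferred from the file renames later in this diff; treating `videoeditor.py` as the entry point follows from the run command above):

```
Wan2GP/
├── videoeditor.py        # editor entry point invoked above
├── plugins.py            # VideoEditorPlugin base class and PluginManager
└── plugins/
    └── wan2gp/
        └── main.py       # the WanGP generator plugin introduced by this patch
```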
-This automated script will:
-
-- Detect your GPU model and VRAM automatically
-- Select optimal CUDA architecture for your GPU
-- Install NVIDIA Docker runtime if needed
-- Build a Docker image with all dependencies
-- Run WanGP with optimal settings for your hardware
-
-**Docker environment includes:**
-
-- NVIDIA CUDA 12.4.1 with cuDNN support
-- PyTorch 2.6.0 with CUDA 12.4 support
-- SageAttention compiled for your specific GPU architecture
-- Optimized environment variables for performance (TF32, threading, etc.)
-- Automatic cache directory mounting for faster subsequent runs
-- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally
-
-**Supported GPUs:** RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more.
-
-## 📦 Installation
-
-For detailed installation instructions for different GPU generations:
-- **[Installation Guide](docs/INSTALLATION.md)** - Complete setup instructions for RTX 10XX to RTX 50XX
-
-## 🎯 Usage
-
-### Basic Usage
-- **[Getting Started Guide](docs/GETTING_STARTED.md)** - First steps and basic usage
-- **[Models Overview](docs/MODELS.md)** - Available models and their capabilities
-
-### Advanced Features
-- **[Loras Guide](docs/LORAS.md)** - Using and managing Loras for customization
-- **[Finetunes](docs/FINETUNES.md)** - Manually add new models to WanGP
-- **[VACE ControlNet](docs/VACE.md)** - Advanced video control and manipulation
-- **[Command Line Reference](docs/CLI.md)** - All available command line options
-
-## 📚 Documentation
-
-- **[Changelog](docs/CHANGELOG.md)** - Latest updates and version history
-- **[Troubleshooting](docs/TROUBLESHOOTING.md)** - Common issues and solutions
-
-## 📚 Video Guides
-- A nice video that explains how to use Vace:\
-https://www.youtube.com/watch?v=FMo9oN2EAvE
-- Another Vace guide:\
-https://www.youtube.com/watch?v=T5jNiEhf9xk
-
-## 🔗 Related Projects
-
-### Other Models for the GPU Poor
-- **[HunyuanVideoGP](https://github.com/deepbeepmeep/HunyuanVideoGP)** - One of the best open source Text to Video generators
-- **[Hunyuan3D-2GP](https://github.com/deepbeepmeep/Hunyuan3D-2GP)** - Image to 3D and text to 3D tool
-- **[FluxFillGP](https://github.com/deepbeepmeep/FluxFillGP)** - Inpainting/outpainting tools based on Flux
-- **[Cosmos1GP](https://github.com/deepbeepmeep/Cosmos1GP)** - Text to world generator and image/video to world
-- **[OminiControlGP](https://github.com/deepbeepmeep/OminiControlGP)** - Flux-derived application for object transfer
-- **[YuE GP](https://github.com/deepbeepmeep/YuEGP)** - Song generator with instruments and singer's voice
-
----
+## Screenshots
+
+
+
+
-
-Made with ❤️ by DeepBeepMeep
-
+## Credits
+The AI Video Generator plugin is built from a desktop port of WAN2GP by DeepBeepMeep.
+See WAN2GP for more details.
+https://github.com/deepbeepmeep/Wan2GP
\ No newline at end of file
diff --git a/videoeditor/plugins.py b/plugins.py
similarity index 96%
rename from videoeditor/plugins.py
rename to plugins.py
index 6e50f0963..f40f6adb1 100644
--- a/videoeditor/plugins.py
+++ b/plugins.py
@@ -16,8 +16,9 @@ class VideoEditorPlugin:
Plugins should inherit from this class and be located in 'plugins/plugin_name/main.py'.
The main class in main.py must be named 'Plugin'.
"""
- def __init__(self, app_instance):
+ def __init__(self, app_instance, wgp_module=None):
self.app = app_instance
+ self.wgp = wgp_module
self.name = "Unnamed Plugin"
self.description = "No description provided."
@@ -35,8 +36,9 @@ def disable(self):
class PluginManager:
"""Manages the discovery, loading, and lifecycle of plugins."""
- def __init__(self, main_app):
+ def __init__(self, main_app, wgp_module=None):
self.app = main_app
+ self.wgp = wgp_module
self.plugins_dir = "plugins"
self.plugins = {} # { 'plugin_name': {'instance': plugin_instance, 'enabled': False, 'module_path': path} }
if not os.path.exists(self.plugins_dir):
@@ -58,7 +60,7 @@ def discover_and_load_plugins(self):
if hasattr(module, 'Plugin'):
plugin_class = getattr(module, 'Plugin')
- instance = plugin_class(self.app)
+ instance = plugin_class(self.app, self.wgp)
instance.initialize()
self.plugins[instance.name] = {
'instance': instance,
@@ -70,6 +72,8 @@ def discover_and_load_plugins(self):
print(f"Warning: {main_py_path} does not have a 'Plugin' class.")
except Exception as e:
print(f"Error loading plugin {plugin_name}: {e}")
+ import traceback
+ traceback.print_exc()
def load_enabled_plugins_from_settings(self, enabled_plugins_list):
"""Enables plugins based on the loaded settings."""
@@ -263,4 +267,5 @@ def on_install_finished(self, message, success):
self.status_label.setText(message)
self.install_btn.setEnabled(True)
if success:
- self.url_input.clear()
\ No newline at end of file
+ self.url_input.clear()
+ self.populate_list()
\ No newline at end of file
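A sketch of how the editor side is expected to wire this up, under the assumption that `main_app` is the editor window, `wgp` is the imported module (or None if the import failed), and `enabled` comes from the saved settings:

```python
# Hypothetical startup wiring -- main_app, wgp and enabled are assumptions, not part of this patch.
from plugins import PluginManager

manager = PluginManager(main_app, wgp_module=wgp)    # wgp may be None
manager.discover_and_load_plugins()                  # instantiates Plugin(app, wgp) for each plugin folder
manager.load_enabled_plugins_from_settings(enabled)  # e.g. a list loaded from the settings file
```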
diff --git a/main.py b/plugins/wan2gp/main.py
similarity index 69%
rename from main.py
rename to plugins/wan2gp/main.py
index 773bcc9fa..ca201dcda 100644
--- a/main.py
+++ b/plugins/wan2gp/main.py
@@ -3,8 +3,11 @@
import threading
import time
import json
-import re
+import tempfile
+import shutil
+import uuid
from unittest.mock import MagicMock
+from pathlib import Path
# --- Start of Gradio Hijacking ---
# This block creates a mock Gradio module. When wgp.py is imported,
@@ -16,12 +19,10 @@ class MockGradioComponent(MagicMock):
"""A smarter mock that captures constructor arguments."""
def __init__(self, *args, **kwargs):
super().__init__(name=f"gr.{kwargs.get('elem_id', 'component')}")
- # Store the kwargs so we can inspect them later
self.kwargs = kwargs
self.value = kwargs.get('value')
self.choices = kwargs.get('choices')
- # Mock chaining methods like .click(), .then(), etc.
for method in ['then', 'change', 'click', 'input', 'select', 'upload', 'mount', 'launch', 'on', 'release']:
setattr(self, method, lambda *a, **kw: self)
@@ -34,153 +35,134 @@ def __exit__(self, exc_type, exc_val, exc_tb):
class MockGradioError(Exception):
pass
-class MockGradioModule: # No longer inherits from MagicMock
+class MockGradioModule:
def __getattr__(self, name):
if name == 'Error':
return lambda *args, **kwargs: MockGradioError(*args)
- # Nullify functions that show pop-ups
if name in ['Info', 'Warning']:
return lambda *args, **kwargs: print(f"Intercepted gr.{name}:", *args)
return lambda *args, **kwargs: MockGradioComponent(*args, **kwargs)
sys.modules['gradio'] = MockGradioModule()
-sys.modules['gradio.gallery'] = MockGradioModule() # Also mock any submodules used
+sys.modules['gradio.gallery'] = MockGradioModule()
sys.modules['shared.gradio.gallery'] = MockGradioModule()
# --- End of Gradio Hijacking ---
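As a quick illustration of what the hijack above achieves (assuming the block has already run, e.g. inside this file), any gradio attribute resolves to a harmless stand-in:

```python
import gradio as gr                           # resolves to MockGradioModule via sys.modules

btn = gr.Button(value="Go")                   # MockGradioComponent capturing value/choices kwargs
btn.click(lambda: None).then(lambda: None)    # chaining methods just return the component
gr.Info("hello")                              # printed as "Intercepted gr.Info: hello", no pop-up
try:
    raise gr.Error("boom")                    # becomes a MockGradioError exception
except Exception as e:
    print(type(e).__name__, e)
```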
-# Global placeholder for the wgp module. Will be None if import fails.
-wgp = None
-
-# Load configuration and attempt to import wgp
-MAIN_CONFIG_FILE = 'main_config.json'
-main_config = {}
-
-def load_main_config():
- global main_config
- try:
- with open(MAIN_CONFIG_FILE, 'r') as f:
- main_config = json.load(f)
- except (FileNotFoundError, json.JSONDecodeError):
- current_folder = os.path.dirname(os.path.abspath(sys.argv[0]))
- main_config = {'wgp_path': current_folder}
-
-def save_main_config():
- global main_config
- try:
- with open(MAIN_CONFIG_FILE, 'w') as f:
- json.dump(main_config, f, indent=4)
- except Exception as e:
- print(f"Error saving main_config.json: {e}")
-
-def setup_and_import_wgp():
- """Adds configured path to sys.path and tries to import wgp."""
- global wgp
- wgp_path = main_config.get('wgp_path')
- if wgp_path and os.path.isdir(wgp_path) and os.path.isfile(os.path.join(wgp_path, 'wgp.py')):
- if wgp_path not in sys.path:
- sys.path.insert(0, wgp_path)
- try:
- import wgp as wgp_module
- wgp = wgp_module
- return True
- except ImportError as e:
- print(f"Error: Failed to import wgp.py from the configured path '{wgp_path}'.\nDetails: {e}")
- wgp = None
- return False
- else:
- print("Info: WAN2GP folder path not set or invalid. Please configure it in File > Settings.")
- wgp = None
- return False
-
-# Load config and attempt import at script start
-load_main_config()
-wgp_loaded = setup_and_import_wgp()
-
-# Now import PyQt6 components
+import ffmpeg
+
from PyQt6.QtWidgets import (
- QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QTabWidget,
+ QWidget, QVBoxLayout, QHBoxLayout, QTabWidget,
QPushButton, QLabel, QLineEdit, QTextEdit, QSlider, QCheckBox, QComboBox,
QFileDialog, QGroupBox, QFormLayout, QTableWidget, QTableWidgetItem,
QHeaderView, QProgressBar, QScrollArea, QListWidget, QListWidgetItem,
- QMessageBox, QRadioButton, QDialog
+ QMessageBox, QRadioButton, QSizePolicy
)
-from PyQt6.QtCore import Qt, QThread, QObject, pyqtSignal, QTimer, pyqtSlot
-from PyQt6.QtGui import QPixmap, QDropEvent, QAction
+from PyQt6.QtCore import Qt, QThread, QObject, pyqtSignal, QUrl, QSize, QRectF
+from PyQt6.QtGui import QPixmap, QImage, QDropEvent
+from PyQt6.QtMultimedia import QMediaPlayer
+from PyQt6.QtMultimediaWidgets import QVideoWidget
from PIL.ImageQt import ImageQt
+# Import the base plugin class from the main application's path
+sys.path.append(str(Path(__file__).parent.parent.parent))
+from plugins import VideoEditorPlugin
-class SettingsDialog(QDialog):
- """Dialog to configure application settings like the wgp.py path."""
- def __init__(self, config, parent=None):
+class VideoResultItemWidget(QWidget):
+ """A widget to display a generated video with a hover-to-play preview and insert button."""
+ def __init__(self, video_path, plugin, parent=None):
super().__init__(parent)
- self.config = config
- self.setWindowTitle("Settings")
- self.setMinimumWidth(500)
+ self.video_path = video_path
+ self.plugin = plugin
+ self.app = plugin.app
+ self.duration = 0.0
+ self.has_audio = False
+ self.setMinimumSize(200, 180)
+ self.setMaximumHeight(190)
+
layout = QVBoxLayout(self)
- form_layout = QFormLayout()
-
- path_layout = QHBoxLayout()
- self.wgp_path_edit = QLineEdit(self.config.get('wgp_path', ''))
- self.wgp_path_edit.setPlaceholderText("Path to the folder containing wgp.py")
- browse_btn = QPushButton("Browse...")
- browse_btn.clicked.connect(self.browse_for_wgp_folder)
- path_layout.addWidget(self.wgp_path_edit)
- path_layout.addWidget(browse_btn)
-
- form_layout.addRow("WAN2GP Folder Path:", path_layout)
- layout.addLayout(form_layout)
-
- button_layout = QHBoxLayout()
- save_btn = QPushButton("Save")
- save_restart_btn = QPushButton("Save and Restart")
- cancel_btn = QPushButton("Cancel")
- button_layout.addStretch()
- button_layout.addWidget(save_btn)
- button_layout.addWidget(save_restart_btn)
- button_layout.addWidget(cancel_btn)
- layout.addLayout(button_layout)
-
- save_btn.clicked.connect(self.save_and_close)
- save_restart_btn.clicked.connect(self.save_and_restart)
- cancel_btn.clicked.connect(self.reject)
-
- def browse_for_wgp_folder(self):
- directory = QFileDialog.getExistingDirectory(self, "Select WAN2GP Folder")
- if directory:
- self.wgp_path_edit.setText(directory)
-
- def validate_path(self, path):
- if not path or not os.path.isdir(path) or not os.path.isfile(os.path.join(path, 'wgp.py')):
- QMessageBox.warning(self, "Invalid Path", "The selected folder does not contain 'wgp.py'. Please select the correct WAN2GP folder.")
- return False
- return True
-
- def _save_config(self):
- path = self.wgp_path_edit.text()
- if not self.validate_path(path):
- return False
- self.config['wgp_path'] = path
- save_main_config()
- return True
-
- def save_and_close(self):
- if self._save_config():
- QMessageBox.information(self, "Settings Saved", "Settings have been saved. Please restart the application for changes to take effect.")
- self.accept()
-
- def save_and_restart(self):
- if self._save_config():
- self.parent().close() # Close the main window before restarting
- os.execv(sys.executable, [sys.executable] + sys.argv)
+ layout.setContentsMargins(5, 5, 5, 5)
+ layout.setSpacing(5)
+
+ self.media_player = QMediaPlayer()
+ self.video_widget = QVideoWidget()
+ self.video_widget.setFixedSize(160, 90)
+ self.media_player.setVideoOutput(self.video_widget)
+ self.media_player.setSource(QUrl.fromLocalFile(self.video_path))
+ self.media_player.setLoops(QMediaPlayer.Loops.Infinite)
+
+ self.info_label = QLabel(os.path.basename(video_path))
+ self.info_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
+ self.info_label.setWordWrap(True)
+
+ self.insert_button = QPushButton("Insert into Timeline")
+ self.insert_button.clicked.connect(self.on_insert)
+
+ h_layout = QHBoxLayout()
+ h_layout.addStretch()
+ h_layout.addWidget(self.video_widget)
+ h_layout.addStretch()
+
+ layout.addLayout(h_layout)
+ layout.addWidget(self.info_label)
+ layout.addWidget(self.insert_button)
+ self.probe_video()
+
+ def probe_video(self):
+ try:
+ probe = ffmpeg.probe(self.video_path)
+ self.duration = float(probe['format']['duration'])
+ self.has_audio = any(s['codec_type'] == 'audio' for s in probe.get('streams', []))
+ self.info_label.setText(f"{os.path.basename(self.video_path)}\n({self.duration:.2f}s)")
+ except Exception as e:
+ self.info_label.setText(f"Error probing:\n{os.path.basename(self.video_path)}")
+ print(f"Error probing video {self.video_path}: {e}")
+
+ def enterEvent(self, event):
+ super().enterEvent(event)
+ self.media_player.play()
+ if not self.plugin.active_region or self.duration == 0: return
+ start_sec, _ = self.plugin.active_region
+ timeline = self.app.timeline_widget
+ video_rect, audio_rect = None, None
+ x = timeline.sec_to_x(start_sec)
+ w = int(self.duration * timeline.pixels_per_second)
+ if self.plugin.insert_on_new_track:
+ video_y = timeline.TIMESCALE_HEIGHT
+ video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT)
+ if self.has_audio:
+ audio_y = timeline.audio_tracks_y_start + self.app.timeline.num_audio_tracks * timeline.TRACK_HEIGHT
+ audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT)
+ else:
+ v_track_idx = 1
+ visual_v_idx = self.app.timeline.num_video_tracks - v_track_idx
+ video_y = timeline.video_tracks_y_start + visual_v_idx * timeline.TRACK_HEIGHT
+ video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT)
+ if self.has_audio:
+ a_track_idx = 1
+ visual_a_idx = a_track_idx - 1
+ audio_y = timeline.audio_tracks_y_start + visual_a_idx * timeline.TRACK_HEIGHT
+ audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT)
+ timeline.set_hover_preview_rects(video_rect, audio_rect)
+
+ def leaveEvent(self, event):
+ super().leaveEvent(event)
+ self.media_player.pause()
+ self.media_player.setPosition(0)
+ self.app.timeline_widget.set_hover_preview_rects(None, None)
+
+ def on_insert(self):
+ self.media_player.stop()
+ self.media_player.setSource(QUrl())
+ self.media_player.setVideoOutput(None)
+ self.app.timeline_widget.set_hover_preview_rects(None, None)
+ self.plugin.insert_generated_clip(self.video_path)
class QueueTableWidget(QTableWidget):
- """A QTableWidget with drag-and-drop reordering for rows."""
rowsMoved = pyqtSignal(int, int)
-
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setDragEnabled(True)
@@ -195,14 +177,8 @@ def dropEvent(self, event: QDropEvent):
source_row = self.currentRow()
target_item = self.itemAt(event.position().toPoint())
dest_row = target_item.row() if target_item else self.rowCount()
-
- # Adjust destination row if moving down
- if source_row < dest_row:
- dest_row -=1
-
- if source_row != dest_row:
- self.rowsMoved.emit(source_row, dest_row)
-
+ if source_row < dest_row: dest_row -=1
+ if source_row != dest_row: self.rowsMoved.emit(source_row, dest_row)
event.acceptProposedAction()
else:
super().dropEvent(event)
@@ -215,210 +191,78 @@ class Worker(QObject):
output = pyqtSignal()
finished = pyqtSignal()
error = pyqtSignal(str)
-
- def __init__(self, state):
+ def __init__(self, plugin, state):
super().__init__()
+ self.plugin = plugin
+ self.wgp = plugin.wgp
self.state = state
self._is_running = True
self._last_progress_phase = None
self._last_preview = None
- def send_cmd(self, cmd, data=None):
- if not self._is_running:
- return
-
def run(self):
def generation_target():
try:
- for _ in wgp.process_tasks(self.state):
- if self._is_running:
- self.output.emit()
- else:
- break
+ for _ in self.wgp.process_tasks(self.state):
+ if self._is_running: self.output.emit()
+ else: break
except Exception as e:
import traceback
print("Error in generation thread:")
traceback.print_exc()
- if "gradio.Error" in str(type(e)):
- self.error.emit(str(e))
- else:
- self.error.emit(f"An unexpected error occurred: {e}")
+ if "gradio.Error" in str(type(e)): self.error.emit(str(e))
+ else: self.error.emit(f"An unexpected error occurred: {e}")
finally:
self._is_running = False
-
gen_thread = threading.Thread(target=generation_target, daemon=True)
gen_thread.start()
-
while self._is_running:
gen = self.state.get('gen', {})
-
current_phase = gen.get("progress_phase")
if current_phase and current_phase != self._last_progress_phase:
self._last_progress_phase = current_phase
-
phase_name, step = current_phase
total_steps = gen.get("num_inference_steps", 1)
high_level_status = gen.get("progress_status", "")
-
- status_msg = wgp.merge_status_context(high_level_status, phase_name)
-
+ status_msg = self.wgp.merge_status_context(high_level_status, phase_name)
progress_args = [(step, total_steps), status_msg]
self.progress.emit(progress_args)
-
preview_img = gen.get('preview')
if preview_img is not None and preview_img is not self._last_preview:
self._last_preview = preview_img
self.preview.emit(preview_img)
gen['preview'] = None
-
time.sleep(0.1)
-
gen_thread.join()
self.finished.emit()
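For context, a minimal sketch of how this Worker is typically driven from the widget (assuming `plugin` and `state` are built as elsewhere in this file; the connected slots are placeholders):

```python
# Hypothetical wiring -- the real widget stores these as self.thread / self.worker.
from PyQt6.QtCore import QThread

thread = QThread()
worker = Worker(plugin, state)          # plugin carries .wgp, state is the wgp generation state dict
worker.moveToThread(thread)

thread.started.connect(worker.run)      # run() polls state['gen'] and emits signals
worker.progress.connect(lambda args: print("progress:", args))
worker.error.connect(lambda msg: print("error:", msg))
worker.finished.connect(thread.quit)    # stop the worker thread's event loop when generation ends
worker.finished.connect(worker.deleteLater)
thread.finished.connect(thread.deleteLater)

thread.start()
```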
-class ApiBridge(QObject):
- """
- An object that lives in the main thread to receive signals from the API thread
- and forward them to the MainWindow's slots.
- """
- # --- CHANGE: Signal now includes model_type and duration_sec ---
- generateSignal = pyqtSignal(object, object, object, object, bool)
-class MainWindow(QMainWindow):
- def __init__(self):
+class WgpDesktopPluginWidget(QWidget):
+ def __init__(self, plugin):
super().__init__()
-
- self.api_bridge = ApiBridge()
-
+ self.plugin = plugin
+ self.wgp = plugin.wgp
self.widgets = {}
self.state = {}
self.worker = None
self.thread = None
self.lora_map = {}
self.full_resolution_choices = []
- self.latest_output_path = None
+ self.main_config = {}
+ self.processed_files = set()
- self.setup_menu()
-
- if not wgp:
- self.setWindowTitle("WanGP - Setup Required")
- self.setGeometry(100, 100, 600, 200)
- self.setup_placeholder_ui()
- QTimer.singleShot(100, lambda: self.show_settings_dialog(first_time=True))
- else:
- self.setWindowTitle(f"WanGP v{wgp.WanGP_version} - Qt Interface")
- self.setGeometry(100, 100, 1400, 950)
- self.setup_full_ui()
- self.apply_initial_config()
- self.connect_signals()
- self.init_wgp_state()
-
- self.api_bridge.generateSignal.connect(self._api_generate)
-
- @pyqtSlot(str)
- def _api_set_model(self, model_type):
- """This slot is executed in the main GUI thread."""
- if not model_type or not wgp: return
-
- if model_type not in wgp.models_def:
- print(f"API Error: Model type '{model_type}' is not a valid model.")
- return
-
- if self.state.get('model_type') == model_type:
- print(f"API: Model is already set to {model_type}.")
- return
-
- self.update_model_dropdowns(model_type)
-
- self._on_model_changed()
-
- if self.state.get('model_type') == model_type:
- print(f"API: Successfully set model to {model_type}.")
- else:
- print(f"API Error: Failed to set model to '{model_type}'. The model might be hidden by your current configuration.")
-
- @pyqtSlot(object, object, object, object, bool)
- def _api_generate(self, start_frame, end_frame, duration_sec, model_type, start_generation):
- """This slot is executed in the main GUI thread."""
- if model_type:
- self._api_set_model(model_type)
- if start_frame:
- self.widgets['mode_s'].setChecked(True)
- self.widgets['image_start'].setText(start_frame)
- else:
- self.widgets['mode_t'].setChecked(True)
- self.widgets['image_start'].clear()
-
- if end_frame:
- self.widgets['image_end_checkbox'].setChecked(True)
- self.widgets['image_end'].setText(end_frame)
- else:
- self.widgets['image_end_checkbox'].setChecked(False)
- self.widgets['image_end'].clear()
-
- if duration_sec is not None:
- try:
- duration = float(duration_sec)
- base_fps = 16
- fps_text = self.widgets['force_fps'].itemText(0)
- match = re.search(r'\((\d+)\s*fps\)', fps_text)
- if match:
- base_fps = int(match.group(1))
-
- video_length_frames = int(duration * base_fps)
-
- self.widgets['video_length'].setValue(video_length_frames)
- print(f"API: Calculated video length: {video_length_frames} frames for {duration:.2f}s @ {base_fps} effective FPS.")
-
- except (ValueError, TypeError) as e:
- print(f"API Error: Invalid duration_sec '{duration_sec}': {e}")
-
- if start_generation:
- self.generate_btn.click()
- print("API: Generation started.")
- else:
- print("API: Parameters set without starting generation.")
-
-
- def setup_menu(self):
- menu_bar = self.menuBar()
- file_menu = menu_bar.addMenu("&File")
-
- settings_action = QAction("&Settings", self)
- settings_action.triggered.connect(self.show_settings_dialog)
- file_menu.addAction(settings_action)
-
- def show_settings_dialog(self, first_time=False):
- dialog = SettingsDialog(main_config, self)
- if first_time:
- dialog.setWindowTitle("Initial Setup: Configure WAN2GP Path")
- dialog.exec()
-
- def setup_placeholder_ui(self):
- main_widget = QWidget()
- self.setCentralWidget(main_widget)
- layout = QVBoxLayout(main_widget)
-
- placeholder_label = QLabel(
- "
Welcome to WanGP
"
- "
The path to your WAN2GP installation (the folder containing wgp.py) is not set.
"
- "
Please go to File > Settings to configure the path.
"
- )
- placeholder_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
- placeholder_label.setWordWrap(True)
- layout.addWidget(placeholder_label)
-
- def setup_full_ui(self):
- main_widget = QWidget()
- self.setCentralWidget(main_widget)
- main_layout = QVBoxLayout(main_widget)
+ self.load_main_config()
+ self.setup_ui()
+ self.apply_initial_config()
+ self.connect_signals()
+ self.init_wgp_state()
+ def setup_ui(self):
+ main_layout = QVBoxLayout(self)
self.header_info = QLabel("Header Info")
main_layout.addWidget(self.header_info)
-
self.tabs = QTabWidget()
main_layout.addWidget(self.tabs)
-
self.setup_generator_tab()
self.setup_config_tab()
@@ -431,18 +275,12 @@ def _create_slider_with_label(self, name, min_val, max_val, initial_val, scale=1
container = QWidget()
hbox = QHBoxLayout(container)
hbox.setContentsMargins(0, 0, 0, 0)
-
slider = self.create_widget(QSlider, name, Qt.Orientation.Horizontal)
slider.setRange(min_val, max_val)
slider.setValue(int(initial_val * scale))
-
value_label = self.create_widget(QLabel, f"{name}_label", f"{initial_val:.{precision}f}")
value_label.setMinimumWidth(50)
-
- slider.valueChanged.connect(
- lambda v, lbl=value_label, s=scale, p=precision: lbl.setText(f"{v/s:.{p}f}")
- )
-
+ slider.valueChanged.connect(lambda v, lbl=value_label, s=scale, p=precision: lbl.setText(f"{v/s:.{p}f}"))
hbox.addWidget(slider)
hbox.addWidget(value_label)
return container
@@ -451,30 +289,20 @@ def _create_file_input(self, name, label_text):
container = self.create_widget(QWidget, f"{name}_container")
hbox = QHBoxLayout(container)
hbox.setContentsMargins(0, 0, 0, 0)
-
line_edit = self.create_widget(QLineEdit, name)
- line_edit.setReadOnly(False) # Allow user to paste paths
line_edit.setPlaceholderText("No file selected or path pasted")
-
button = QPushButton("Browse...")
-
def open_dialog():
- # Allow selecting multiple files for reference images
if "refs" in name:
filenames, _ = QFileDialog.getOpenFileNames(self, f"Select {label_text}")
- if filenames:
- line_edit.setText(";".join(filenames))
+ if filenames: line_edit.setText(";".join(filenames))
else:
filename, _ = QFileDialog.getOpenFileName(self, f"Select {label_text}")
- if filename:
- line_edit.setText(filename)
-
+ if filename: line_edit.setText(filename)
button.clicked.connect(open_dialog)
-
clear_button = QPushButton("X")
clear_button.setFixedWidth(30)
clear_button.clicked.connect(lambda: line_edit.clear())
-
hbox.addWidget(QLabel(f"{label_text}:"))
hbox.addWidget(line_edit, 1)
hbox.addWidget(button)
@@ -485,25 +313,18 @@ def setup_generator_tab(self):
gen_tab = QWidget()
self.tabs.addTab(gen_tab, "Video Generator")
gen_layout = QHBoxLayout(gen_tab)
-
left_panel = QWidget()
left_layout = QVBoxLayout(left_panel)
gen_layout.addWidget(left_panel, 1)
-
right_panel = QWidget()
right_layout = QVBoxLayout(right_panel)
gen_layout.addWidget(right_panel, 1)
-
- # Left Panel (Inputs)
scroll_area = QScrollArea()
scroll_area.setWidgetResizable(True)
left_layout.addWidget(scroll_area)
-
options_widget = QWidget()
scroll_area.setWidget(options_widget)
options_layout = QVBoxLayout(options_widget)
-
- # Model Selection
model_layout = QHBoxLayout()
self.widgets['model_family'] = QComboBox()
self.widgets['model_base_type_choice'] = QComboBox()
@@ -513,36 +334,26 @@ def setup_generator_tab(self):
model_layout.addWidget(self.widgets['model_base_type_choice'], 3)
model_layout.addWidget(self.widgets['model_choice'], 3)
options_layout.addLayout(model_layout)
-
- # Prompt
options_layout.addWidget(QLabel("Prompt:"))
self.create_widget(QTextEdit, 'prompt').setMinimumHeight(100)
options_layout.addWidget(self.widgets['prompt'])
-
options_layout.addWidget(QLabel("Negative Prompt:"))
self.create_widget(QTextEdit, 'negative_prompt').setMinimumHeight(60)
options_layout.addWidget(self.widgets['negative_prompt'])
-
- # Basic controls
basic_group = QGroupBox("Basic Options")
basic_layout = QFormLayout(basic_group)
-
res_container = QWidget()
res_hbox = QHBoxLayout(res_container)
res_hbox.setContentsMargins(0, 0, 0, 0)
res_hbox.addWidget(self.create_widget(QComboBox, 'resolution_group'), 2)
res_hbox.addWidget(self.create_widget(QComboBox, 'resolution'), 3)
basic_layout.addRow("Resolution:", res_container)
-
basic_layout.addRow("Video Length:", self._create_slider_with_label('video_length', 1, 737, 81, 1.0, 0))
basic_layout.addRow("Inference Steps:", self._create_slider_with_label('num_inference_steps', 1, 100, 30, 1.0, 0))
basic_layout.addRow("Seed:", self.create_widget(QLineEdit, 'seed', '-1'))
options_layout.addWidget(basic_group)
-
- # Generation Mode and Input Options
mode_options_group = QGroupBox("Generation Mode & Input Options")
mode_options_layout = QVBoxLayout(mode_options_group)
-
mode_hbox = QHBoxLayout()
mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_t', "Text Prompt Only"))
mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_s', "Start with Image"))
@@ -550,15 +361,12 @@ def setup_generator_tab(self):
mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_l', "Continue Last Video"))
self.widgets['mode_t'].setChecked(True)
mode_options_layout.addLayout(mode_hbox)
-
options_hbox = QHBoxLayout()
options_hbox.addWidget(self.create_widget(QCheckBox, 'image_end_checkbox', "Use End Image"))
options_hbox.addWidget(self.create_widget(QCheckBox, 'control_video_checkbox', "Use Control Video"))
options_hbox.addWidget(self.create_widget(QCheckBox, 'ref_image_checkbox', "Use Reference Image(s)"))
mode_options_layout.addLayout(options_hbox)
options_layout.addWidget(mode_options_group)
-
- # Dynamic Inputs
inputs_group = QGroupBox("Inputs")
inputs_layout = QVBoxLayout(inputs_group)
inputs_layout.addWidget(self._create_file_input('image_start', "Start Image"))
@@ -571,16 +379,12 @@ def setup_generator_tab(self):
denoising_row.addRow("Denoising Strength:", self._create_slider_with_label('denoising_strength', 0, 100, 50, 100.0, 2))
inputs_layout.addLayout(denoising_row)
options_layout.addWidget(inputs_group)
-
- # Advanced controls
self.advanced_group = self.create_widget(QGroupBox, 'advanced_group', "Advanced Options")
self.advanced_group.setCheckable(True)
self.advanced_group.setChecked(False)
advanced_layout = QVBoxLayout(self.advanced_group)
-
advanced_tabs = self.create_widget(QTabWidget, 'advanced_tabs')
advanced_layout.addWidget(advanced_tabs)
-
self._setup_adv_tab_general(advanced_tabs)
self._setup_adv_tab_loras(advanced_tabs)
self._setup_adv_tab_speed(advanced_tabs)
@@ -589,10 +393,8 @@ def setup_generator_tab(self):
self._setup_adv_tab_quality(advanced_tabs)
self._setup_adv_tab_sliding_window(advanced_tabs)
self._setup_adv_tab_misc(advanced_tabs)
-
options_layout.addWidget(self.advanced_group)
- # Right Panel (Output & Queue)
btn_layout = QHBoxLayout()
self.generate_btn = self.create_widget(QPushButton, 'generate_btn', "Generate")
self.add_to_queue_btn = self.create_widget(QPushButton, 'add_to_queue_btn', "Add to Queue")
@@ -601,32 +403,34 @@ def setup_generator_tab(self):
btn_layout.addWidget(self.generate_btn)
btn_layout.addWidget(self.add_to_queue_btn)
right_layout.addLayout(btn_layout)
-
self.status_label = self.create_widget(QLabel, 'status_label', "Idle")
right_layout.addWidget(self.status_label)
self.progress_bar = self.create_widget(QProgressBar, 'progress_bar')
right_layout.addWidget(self.progress_bar)
-
preview_group = self.create_widget(QGroupBox, 'preview_group', "Preview")
preview_group.setCheckable(True)
preview_group.setStyleSheet("QGroupBox { border: 1px solid #cccccc; }")
preview_group_layout = QVBoxLayout(preview_group)
-
self.preview_image = self.create_widget(QLabel, 'preview_image', "")
self.preview_image.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.preview_image.setMinimumSize(200, 200)
-
preview_group_layout.addWidget(self.preview_image)
right_layout.addWidget(preview_group)
- right_layout.addWidget(QLabel("Output:"))
- self.output_gallery = self.create_widget(QListWidget, 'output_gallery')
- right_layout.addWidget(self.output_gallery)
+ results_group = QGroupBox("Generated Clips")
+ results_group.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
+ results_layout = QVBoxLayout(results_group)
+ self.results_list = self.create_widget(QListWidget, 'results_list')
+ self.results_list.setFlow(QListWidget.Flow.LeftToRight)
+ self.results_list.setWrapping(True)
+ self.results_list.setResizeMode(QListWidget.ResizeMode.Adjust)
+ self.results_list.setSpacing(10)
+ results_layout.addWidget(self.results_list)
+ right_layout.addWidget(results_group)
right_layout.addWidget(QLabel("Queue:"))
self.queue_table = self.create_widget(QueueTableWidget, 'queue_table')
right_layout.addWidget(self.queue_table)
-
queue_btn_layout = QHBoxLayout()
self.remove_queue_btn = self.create_widget(QPushButton, 'remove_queue_btn', "Remove Selected")
self.clear_queue_btn = self.create_widget(QPushButton, 'clear_queue_btn', "Clear Queue")
@@ -641,14 +445,11 @@ def _setup_adv_tab_general(self, tabs):
tabs.addTab(tab, "General")
layout = QFormLayout(tab)
self.widgets['adv_general_layout'] = layout
-
guidance_group = QGroupBox("Guidance")
guidance_layout = self.create_widget(QFormLayout, 'guidance_layout', guidance_group)
guidance_layout.addRow("Guidance (CFG):", self._create_slider_with_label('guidance_scale', 10, 200, 5.0, 10.0, 1))
-
self.widgets['guidance_phases_row_index'] = guidance_layout.rowCount()
guidance_layout.addRow("Guidance Phases:", self.create_widget(QComboBox, 'guidance_phases'))
-
self.widgets['guidance2_row_index'] = guidance_layout.rowCount()
guidance_layout.addRow("Guidance 2:", self._create_slider_with_label('guidance2_scale', 10, 200, 5.0, 10.0, 1))
self.widgets['guidance3_row_index'] = guidance_layout.rowCount()
@@ -656,30 +457,24 @@ def _setup_adv_tab_general(self, tabs):
self.widgets['switch_thresh_row_index'] = guidance_layout.rowCount()
guidance_layout.addRow("Switch Threshold:", self._create_slider_with_label('switch_threshold', 0, 1000, 0, 1.0, 0))
layout.addRow(guidance_group)
-
nag_group = self.create_widget(QGroupBox, 'nag_group', "NAG (Negative Adversarial Guidance)")
nag_layout = QFormLayout(nag_group)
nag_layout.addRow("NAG Scale:", self._create_slider_with_label('NAG_scale', 10, 200, 1.0, 10.0, 1))
nag_layout.addRow("NAG Tau:", self._create_slider_with_label('NAG_tau', 10, 50, 3.5, 10.0, 1))
nag_layout.addRow("NAG Alpha:", self._create_slider_with_label('NAG_alpha', 0, 20, 0.5, 10.0, 1))
layout.addRow(nag_group)
-
self.widgets['solver_row_container'] = QWidget()
solver_hbox = QHBoxLayout(self.widgets['solver_row_container'])
solver_hbox.setContentsMargins(0,0,0,0)
solver_hbox.addWidget(QLabel("Sampler Solver:"))
solver_hbox.addWidget(self.create_widget(QComboBox, 'sample_solver'))
layout.addRow(self.widgets['solver_row_container'])
-
self.widgets['flow_shift_row_index'] = layout.rowCount()
layout.addRow("Shift Scale:", self._create_slider_with_label('flow_shift', 10, 250, 3.0, 10.0, 1))
-
self.widgets['audio_guidance_row_index'] = layout.rowCount()
layout.addRow("Audio Guidance:", self._create_slider_with_label('audio_guidance_scale', 10, 200, 4.0, 10.0, 1))
-
self.widgets['repeat_generation_row_index'] = layout.rowCount()
layout.addRow("Repeat Generations:", self._create_slider_with_label('repeat_generation', 1, 25, 1, 1.0, 0))
-
combo = self.create_widget(QComboBox, 'multi_images_gen_type')
combo.addItem("Generate all combinations", 0)
combo.addItem("Match images and texts", 1)
@@ -706,7 +501,6 @@ def _setup_adv_tab_speed(self, tabs):
combo.addItem("Tea Cache", "tea")
combo.addItem("Mag Cache", "mag")
layout.addRow("Cache Type:", combo)
-
combo = self.create_widget(QComboBox, 'skip_steps_multiplier')
combo.addItem("x1.5 speed up", 1.5)
combo.addItem("x1.75 speed up", 1.75)
@@ -725,13 +519,11 @@ def _setup_adv_tab_postproc(self, tabs):
combo.addItem("Rife x2 frames/s", "rife2")
combo.addItem("Rife x4 frames/s", "rife4")
layout.addRow("Temporal Upsampling:", combo)
-
combo = self.create_widget(QComboBox, 'spatial_upsampling')
combo.addItem("Disabled", "")
combo.addItem("Lanczos x1.5", "lanczos1.5")
combo.addItem("Lanczos x2.0", "lanczos2")
layout.addRow("Spatial Upsampling:", combo)
-
layout.addRow("Film Grain Intensity:", self._create_slider_with_label('film_grain_intensity', 0, 100, 0, 100.0, 2))
layout.addRow("Film Grain Saturation:", self._create_slider_with_label('film_grain_saturation', 0, 100, 0.5, 100.0, 2))
@@ -751,7 +543,6 @@ def _setup_adv_tab_quality(self, tabs):
tab = QWidget()
tabs.addTab(tab, "Quality")
layout = QVBoxLayout(tab)
-
slg_group = self.create_widget(QGroupBox, 'slg_group', "Skip Layer Guidance")
slg_layout = QFormLayout(slg_group)
slg_combo = self.create_widget(QComboBox, 'slg_switch')
@@ -761,25 +552,20 @@ def _setup_adv_tab_quality(self, tabs):
slg_layout.addRow("Start %:", self._create_slider_with_label('slg_start_perc', 0, 100, 10, 1.0, 0))
slg_layout.addRow("End %:", self._create_slider_with_label('slg_end_perc', 0, 100, 90, 1.0, 0))
layout.addWidget(slg_group)
-
quality_form = QFormLayout()
self.widgets['quality_form_layout'] = quality_form
-
apg_combo = self.create_widget(QComboBox, 'apg_switch')
apg_combo.addItem("OFF", 0)
apg_combo.addItem("ON", 1)
self.widgets['apg_switch_row_index'] = quality_form.rowCount()
quality_form.addRow("Adaptive Projected Guidance:", apg_combo)
-
cfg_star_combo = self.create_widget(QComboBox, 'cfg_star_switch')
cfg_star_combo.addItem("OFF", 0)
cfg_star_combo.addItem("ON", 1)
self.widgets['cfg_star_switch_row_index'] = quality_form.rowCount()
quality_form.addRow("Classifier-Free Guidance Star:", cfg_star_combo)
-
self.widgets['cfg_zero_step_row_index'] = quality_form.rowCount()
quality_form.addRow("CFG Zero below Layer:", self._create_slider_with_label('cfg_zero_step', -1, 39, -1, 1.0, 0))
-
combo = self.create_widget(QComboBox, 'min_frames_if_references')
combo.addItem("Disabled (1 frame)", 1)
combo.addItem("Generate 5 frames", 5)
@@ -795,7 +581,6 @@ def _setup_adv_tab_sliding_window(self, tabs):
self.widgets['sliding_window_tab_index'] = tabs.count()
tabs.addTab(tab, "Sliding Window")
layout = QFormLayout(tab)
-
layout.addRow("Window Size:", self._create_slider_with_label('sliding_window_size', 5, 257, 129, 1.0, 0))
layout.addRow("Overlap:", self._create_slider_with_label('sliding_window_overlap', 1, 97, 5, 1.0, 0))
layout.addRow("Color Correction:", self._create_slider_with_label('sliding_window_color_correction_strength', 0, 100, 0, 100.0, 2))
@@ -807,23 +592,18 @@ def _setup_adv_tab_misc(self, tabs):
tabs.addTab(tab, "Misc")
layout = QFormLayout(tab)
self.widgets['misc_layout'] = layout
-
riflex_combo = self.create_widget(QComboBox, 'RIFLEx_setting')
riflex_combo.addItem("Auto", 0)
riflex_combo.addItem("Always ON", 1)
riflex_combo.addItem("Always OFF", 2)
self.widgets['riflex_row_index'] = layout.rowCount()
layout.addRow("RIFLEx Setting:", riflex_combo)
-
fps_combo = self.create_widget(QComboBox, 'force_fps')
layout.addRow("Force FPS:", fps_combo)
-
profile_combo = self.create_widget(QComboBox, 'override_profile')
profile_combo.addItem("Default Profile", -1)
- for text, val in wgp.memory_profile_choices:
- profile_combo.addItem(text.split(':')[0], val)
+ for text, val in self.wgp.memory_profile_choices: profile_combo.addItem(text.split(':')[0], val)
layout.addRow("Override Memory Profile:", profile_combo)
-
combo = self.create_widget(QComboBox, 'multi_prompts_gen_type')
combo.addItem("Generate new Video per line", 0)
combo.addItem("Use line for new Sliding Window", 1)
@@ -833,19 +613,15 @@ def setup_config_tab(self):
config_tab = QWidget()
self.tabs.addTab(config_tab, "Configuration")
main_layout = QVBoxLayout(config_tab)
-
self.config_status_label = QLabel("Apply changes for them to take effect. Some may require a restart.")
main_layout.addWidget(self.config_status_label)
-
config_tabs = QTabWidget()
main_layout.addWidget(config_tabs)
-
config_tabs.addTab(self._create_general_config_tab(), "General")
config_tabs.addTab(self._create_performance_config_tab(), "Performance")
config_tabs.addTab(self._create_extensions_config_tab(), "Extensions")
config_tabs.addTab(self._create_outputs_config_tab(), "Outputs")
config_tabs.addTab(self._create_notifications_config_tab(), "Notifications")
-
self.apply_config_btn = QPushButton("Apply Changes")
self.apply_config_btn.clicked.connect(self._on_apply_config_changes)
main_layout.addWidget(self.apply_config_btn)
@@ -856,18 +632,15 @@ def _create_scrollable_form_tab(self):
scroll_area.setWidgetResizable(True)
layout = QVBoxLayout(tab_widget)
layout.addWidget(scroll_area)
-
content_widget = QWidget()
form_layout = QFormLayout(content_widget)
scroll_area.setWidget(content_widget)
-
return tab_widget, form_layout
def _create_config_combo(self, form_layout, label, key, choices, default_value):
combo = QComboBox()
- for text, data in choices:
- combo.addItem(text, data)
- index = combo.findData(wgp.server_config.get(key, default_value))
+ for text, data in choices: combo.addItem(text, data)
+ index = combo.findData(self.wgp.server_config.get(key, default_value))
if index != -1: combo.setCurrentIndex(index)
self.widgets[f'config_{key}'] = combo
form_layout.addRow(label, combo)
@@ -876,34 +649,27 @@ def _create_config_slider(self, form_layout, label, key, min_val, max_val, defau
container = QWidget()
hbox = QHBoxLayout(container)
hbox.setContentsMargins(0,0,0,0)
-
slider = QSlider(Qt.Orientation.Horizontal)
slider.setRange(min_val, max_val)
slider.setSingleStep(step)
- slider.setValue(wgp.server_config.get(key, default_value))
-
+ slider.setValue(self.wgp.server_config.get(key, default_value))
value_label = QLabel(str(slider.value()))
value_label.setMinimumWidth(40)
slider.valueChanged.connect(lambda v, lbl=value_label: lbl.setText(str(v)))
-
hbox.addWidget(slider)
hbox.addWidget(value_label)
-
self.widgets[f'config_{key}'] = slider
form_layout.addRow(label, container)
def _create_config_checklist(self, form_layout, label, key, choices, default_value):
list_widget = QListWidget()
list_widget.setMinimumHeight(100)
- current_values = wgp.server_config.get(key, default_value)
+ current_values = self.wgp.server_config.get(key, default_value)
for text, data in choices:
item = QListWidgetItem(text)
item.setData(Qt.ItemDataRole.UserRole, data)
item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable)
- if data in current_values:
- item.setCheckState(Qt.CheckState.Checked)
- else:
- item.setCheckState(Qt.CheckState.Unchecked)
+ item.setCheckState(Qt.CheckState.Checked if data in current_values else Qt.CheckState.Unchecked)
list_widget.addItem(item)
self.widgets[f'config_{key}'] = list_widget
form_layout.addRow(label, list_widget)
@@ -919,48 +685,24 @@ def _create_config_textbox(self, form_layout, label, key, default_value, multi_l
def _create_general_config_tab(self):
tab, form = self._create_scrollable_form_tab()
-
- _, _, dropdown_choices = wgp.get_sorted_dropdown(wgp.displayed_model_types, None, None, False)
- self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, wgp.transformer_types)
-
- self._create_config_combo(form, "Model Hierarchy:", "model_hierarchy_type", [
- ("Two Levels (Family > Model)", 0),
- ("Three Levels (Family > Base > Finetune)", 1)
- ], 1)
-
- self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [
- ("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1),
- ("Dimensions are Output Width/Height (Cropped)", 2)], 0)
-
- self._create_config_combo(form, "Attention Type:", "attention_mode", [
- ("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"),
- ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto")
-
- self._create_config_combo(form, "Metadata Handling:", "metadata_type", [
- ("Embed in file (Exif/Comment)", "metadata"), ("Export separate JSON", "json"), ("None", "none")], "metadata")
-
- self._create_config_checklist(form, "RAM Loading Policy:", "preload_model_policy", [
- ("Preload on App Launch", "P"), ("Preload on Model Switch", "S"), ("Unload when Queue is Done", "U")], [])
-
- self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [
- ("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10),
- ("Keep last 20", 20), ("Keep last 30", 30)], 5)
-
+ _, _, dropdown_choices = self.wgp.get_sorted_dropdown(self.wgp.displayed_model_types, None, None, False)
+ self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, self.wgp.transformer_types)
+ self._create_config_combo(form, "Model Hierarchy:", "model_hierarchy_type", [("Two Levels (Family > Model)", 0), ("Three Levels (Family > Base > Finetune)", 1)], 1)
+ self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1), ("Dimensions are Output Width/Height (Cropped)", 2)], 0)
+ self._create_config_combo(form, "Attention Type:", "attention_mode", [("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"), ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto")
+ self._create_config_combo(form, "Metadata Handling:", "metadata_type", [("Embed in file (Exif/Comment)", "metadata"), ("Export separate JSON", "json"), ("None", "none")], "metadata")
+ self._create_config_checklist(form, "RAM Loading Policy:", "preload_model_policy", [("Preload on App Launch", "P"), ("Preload on Model Switch", "S"), ("Unload when Queue is Done", "U")], [])
+ self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10), ("Keep last 20", 20), ("Keep last 30", 30)], 5)
self._create_config_combo(form, "Display RAM/VRAM Stats:", "display_stats", [("Disabled", 0), ("Enabled", 1)], 0)
-
- self._create_config_combo(form, "Max Frames Multiplier:", "max_frames_multiplier",
- [(f"x{i}", i) for i in range(1, 8)], 1)
-
- checkpoints_paths_text = "\n".join(wgp.server_config.get("checkpoints_paths", wgp.fl.default_checkpoints_paths))
+ self._create_config_combo(form, "Max Frames Multiplier:", "max_frames_multiplier", [(f"x{i}", i) for i in range(1, 8)], 1)
+ checkpoints_paths_text = "\n".join(self.wgp.server_config.get("checkpoints_paths", self.wgp.fl.default_checkpoints_paths))
checkpoints_textbox = QTextEdit()
checkpoints_textbox.setPlainText(checkpoints_paths_text)
checkpoints_textbox.setAcceptRichText(False)
checkpoints_textbox.setMinimumHeight(60)
self.widgets['config_checkpoints_paths'] = checkpoints_textbox
form.addRow("Checkpoints Paths:", checkpoints_textbox)
-
self._create_config_combo(form, "UI Theme (requires restart):", "UI_theme", [("Blue Sky", "default"), ("Classic Gradio", "gradio")], "default")
-
return tab
def _create_performance_config_tab(self):
@@ -974,13 +716,11 @@ def _create_performance_config_tab(self):
self._create_config_combo(form, "DepthAnything v2 Variant:", "depth_anything_v2_variant", [("Large (more precise)", "vitl"), ("Big (faster)", "vitb")], "vitl")
self._create_config_combo(form, "VAE Tiling:", "vae_config", [("Auto", 0), ("Disabled", 1), ("256x256 (~8GB VRAM)", 2), ("128x128 (~6GB VRAM)", 3)], 0)
self._create_config_combo(form, "Boost:", "boost", [("On", 1), ("Off", 2)], 1)
- self._create_config_combo(form, "Memory Profile:", "profile", wgp.memory_profile_choices, wgp.profile_type.LowRAM_LowVRAM)
+ self._create_config_combo(form, "Memory Profile:", "profile", self.wgp.memory_profile_choices, self.wgp.profile_type.LowRAM_LowVRAM)
self._create_config_slider(form, "Preload in VRAM (MB):", "preload_in_VRAM", 0, 40000, 0, 100)
-
release_ram_btn = QPushButton("Force Release Models from RAM")
release_ram_btn.clicked.connect(self._on_release_ram)
form.addRow(release_ram_btn)
-
return tab
def _create_extensions_config_tab(self):
@@ -1005,12 +745,12 @@ def _create_notifications_config_tab(self):
return tab
def init_wgp_state(self):
+ wgp = self.wgp
initial_model = wgp.server_config.get("last_model_type", wgp.transformer_type)
- all_models, _, _ = wgp.get_sorted_dropdown(wgp.displayed_model_types, None, None, False)
+ dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types
+ _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False)
all_model_ids = [m[1] for m in all_models]
- if initial_model not in all_model_ids:
- initial_model = wgp.transformer_type
-
+ if initial_model not in all_model_ids: initial_model = wgp.transformer_type
state_dict = {}
state_dict["model_filename"] = wgp.get_model_filename(initial_model, wgp.transformer_quantization, wgp.transformer_dtype_policy)
state_dict["model_type"] = initial_model
@@ -1019,49 +759,36 @@ def init_wgp_state(self):
state_dict["last_model_per_type"] = wgp.server_config.get("last_model_per_type", {})
state_dict["last_resolution_per_group"] = wgp.server_config.get("last_resolution_per_group", {})
state_dict["gen"] = {"queue": []}
-
self.state = state_dict
self.advanced_group.setChecked(wgp.advanced)
-
self.update_model_dropdowns(initial_model)
self.refresh_ui_from_model_change(initial_model)
- self._update_input_visibility() # Set initial visibility
+ self._update_input_visibility()
def update_model_dropdowns(self, current_model_type):
+ wgp = self.wgp
family_mock, base_type_mock, choice_mock = wgp.generate_dropdown_model_list(current_model_type)
-
- self.widgets['model_family'].blockSignals(True)
- self.widgets['model_base_type_choice'].blockSignals(True)
- self.widgets['model_choice'].blockSignals(True)
-
- self.widgets['model_family'].clear()
- if family_mock.choices:
- for display_name, internal_key in family_mock.choices:
- self.widgets['model_family'].addItem(display_name, internal_key)
- index = self.widgets['model_family'].findData(family_mock.value)
- if index != -1: self.widgets['model_family'].setCurrentIndex(index)
-
- self.widgets['model_base_type_choice'].clear()
- if base_type_mock.choices:
- for label, value in base_type_mock.choices:
- self.widgets['model_base_type_choice'].addItem(label, value)
- index = self.widgets['model_base_type_choice'].findData(base_type_mock.value)
- if index != -1: self.widgets['model_base_type_choice'].setCurrentIndex(index)
- self.widgets['model_base_type_choice'].setVisible(base_type_mock.kwargs.get('visible', True))
-
- self.widgets['model_choice'].clear()
- if choice_mock.choices:
- for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value)
- index = self.widgets['model_choice'].findData(choice_mock.value)
- if index != -1: self.widgets['model_choice'].setCurrentIndex(index)
- self.widgets['model_choice'].setVisible(choice_mock.kwargs.get('visible', True))
-
- self.widgets['model_family'].blockSignals(False)
- self.widgets['model_base_type_choice'].blockSignals(False)
- self.widgets['model_choice'].blockSignals(False)
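+ # Repopulate the family/base/model combos from their mocks in one loop, blocking signals to avoid re-entrant change handlers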
+ for combo_name, mock in [('model_family', family_mock), ('model_base_type_choice', base_type_mock), ('model_choice', choice_mock)]:
+ combo = self.widgets[combo_name]
+ combo.blockSignals(True)
+ combo.clear()
+ if mock.choices:
+ for display_name, internal_key in mock.choices: combo.addItem(display_name, internal_key)
+ index = combo.findData(mock.value)
+ if index != -1: combo.setCurrentIndex(index)
+
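+ # The mock may expose visibility via a kwargs dict or a plain 'visible' attribute; default to visible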
+ is_visible = True
+ if hasattr(mock, 'kwargs') and isinstance(mock.kwargs, dict):
+ is_visible = mock.kwargs.get('visible', True)
+ elif hasattr(mock, 'visible'):
+ is_visible = mock.visible
+ combo.setVisible(is_visible)
+
+ combo.blockSignals(False)
def refresh_ui_from_model_change(self, model_type):
"""Update UI controls with default settings when the model is changed."""
+ wgp = self.wgp
self.header_info.setText(wgp.generate_header(model_type, wgp.compile, wgp.attention_mode))
ui_defaults = wgp.get_default_settings(model_type)
wgp.set_model_settings(self.state, model_type, ui_defaults)
@@ -1128,10 +855,8 @@ def refresh_ui_from_model_change(self, model_type):
self.widgets['resolution_group'].blockSignals(False)
self.widgets['resolution'].blockSignals(False)
-
- for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'audio_source']:
- if name in self.widgets:
- self.widgets[name].clear()
+ for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'image_refs', 'audio_source']:
+ if name in self.widgets: self.widgets[name].clear()
guidance_layout = self.widgets['guidance_layout']
guidance_max = model_def.get("guidance_max_phases", 1)
@@ -1155,7 +880,6 @@ def refresh_ui_from_model_change(self, model_type):
misc_layout = self.widgets['misc_layout']
misc_layout.setRowVisible(self.widgets['riflex_row_index'], not (recammaster or ltxv or diffusion_forcing))
-
index = self.widgets['multi_images_gen_type'].findData(ui_defaults.get('multi_images_gen_type', 0))
if index != -1: self.widgets['multi_images_gen_type'].setCurrentIndex(index)
@@ -1170,12 +894,11 @@ def refresh_ui_from_model_change(self, model_type):
guidance3_val = ui_defaults.get("guidance3_scale", 5.0)
self.widgets['guidance3_scale'].setValue(int(guidance3_val * 10))
self.widgets['guidance3_scale_label'].setText(f"{guidance3_val:.1f}")
- self.widgets['guidance_phases'].clear()
+ self.widgets['guidance_phases'].clear()
if guidance_max >= 1: self.widgets['guidance_phases'].addItem("One Phase", 1)
if guidance_max >= 2: self.widgets['guidance_phases'].addItem("Two Phases", 2)
if guidance_max >= 3: self.widgets['guidance_phases'].addItem("Three Phases", 3)
-
index = self.widgets['guidance_phases'].findData(ui_defaults.get("guidance_phases", 1))
if index != -1: self.widgets['guidance_phases'].setCurrentIndex(index)
@@ -1227,9 +950,7 @@ def refresh_ui_from_model_change(self, model_type):
selected_loras = ui_defaults.get('activated_loras', [])
for i in range(lora_list_widget.count()):
item = lora_list_widget.item(i)
- is_selected = any(item.text() == os.path.basename(p) for p in selected_loras)
- if is_selected:
- item.setSelected(True)
+ if any(item.text() == os.path.basename(p) for p in selected_loras): item.setSelected(True)
self.widgets['loras_multipliers'].setText(ui_defaults.get('loras_multipliers', ''))
skip_cache_val = ui_defaults.get('skip_steps_cache_type', "")
@@ -1329,11 +1050,11 @@ def refresh_ui_from_model_change(self, model_type):
for widget in self.widgets.values():
if hasattr(widget, 'blockSignals'): widget.blockSignals(False)
+
self._update_dynamic_ui()
self._update_input_visibility()
def _update_dynamic_ui(self):
- """Update UI visibility based on current selections."""
phases = self.widgets['guidance_phases'].currentData() or 1
guidance_layout = self.widgets['guidance_layout']
guidance_layout.setRowVisible(self.widgets['guidance2_row_index'], phases >= 2)
@@ -1341,83 +1062,45 @@ def _update_dynamic_ui(self):
guidance_layout.setRowVisible(self.widgets['switch_thresh_row_index'], phases >= 2)
def _update_generation_mode_visibility(self, model_def):
- """Shows/hides the main generation mode options based on the selected model."""
allowed = model_def.get("image_prompt_types_allowed", "")
-
choices = []
- if "T" in allowed or not allowed:
- choices.append(("Text Prompt Only" if "S" in allowed else "New Video", "T"))
- if "S" in allowed:
- choices.append(("Start Video with Image", "S"))
- if "V" in allowed:
- choices.append(("Continue Video", "V"))
- if "L" in allowed:
- choices.append(("Continue Last Video", "L"))
-
- button_map = {
- "T": self.widgets['mode_t'],
- "S": self.widgets['mode_s'],
- "V": self.widgets['mode_v'],
- "L": self.widgets['mode_l'],
- }
-
- for btn in button_map.values():
- btn.setVisible(False)
-
+ if "T" in allowed or not allowed: choices.append(("Text Prompt Only" if "S" in allowed else "New Video", "T"))
+ if "S" in allowed: choices.append(("Start Video with Image", "S"))
+ if "V" in allowed: choices.append(("Continue Video", "V"))
+ if "L" in allowed: choices.append(("Continue Last Video", "L"))
+ button_map = { "T": self.widgets['mode_t'], "S": self.widgets['mode_s'], "V": self.widgets['mode_v'], "L": self.widgets['mode_l'] }
+ for btn in button_map.values(): btn.setVisible(False)
allowed_values = [c[1] for c in choices]
for label, value in choices:
if value in button_map:
btn = button_map[value]
btn.setText(label)
btn.setVisible(True)
-
- current_checked_value = None
- for value, btn in button_map.items():
- if btn.isChecked():
- current_checked_value = value
- break
-
- # If the currently selected mode is now hidden, reset to a visible default
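+ # If no mode is checked, or the checked mode is now hidden, fall back to the first allowed option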
+ current_checked_value = next((value for value, btn in button_map.items() if btn.isChecked()), None)
if current_checked_value is None or not button_map[current_checked_value].isVisible():
- if allowed_values:
- button_map[allowed_values[0]].setChecked(True)
-
-
+ if allowed_values: button_map[allowed_values[0]].setChecked(True)
end_image_visible = "E" in allowed
self.widgets['image_end_checkbox'].setVisible(end_image_visible)
- if not end_image_visible:
- self.widgets['image_end_checkbox'].setChecked(False)
-
- # Control Video Checkbox (Based on model_def.get("guide_preprocessing"))
+ if not end_image_visible: self.widgets['image_end_checkbox'].setChecked(False)
control_video_visible = model_def.get("guide_preprocessing") is not None
self.widgets['control_video_checkbox'].setVisible(control_video_visible)
- if not control_video_visible:
- self.widgets['control_video_checkbox'].setChecked(False)
-
- # Reference Image Checkbox (Based on model_def.get("image_ref_choices"))
+ if not control_video_visible: self.widgets['control_video_checkbox'].setChecked(False)
ref_image_visible = model_def.get("image_ref_choices") is not None
self.widgets['ref_image_checkbox'].setVisible(ref_image_visible)
- if not ref_image_visible:
- self.widgets['ref_image_checkbox'].setChecked(False)
-
+ if not ref_image_visible: self.widgets['ref_image_checkbox'].setChecked(False)
def _update_input_visibility(self):
- """Shows/hides input fields based on the selected generation mode."""
is_s_mode = self.widgets['mode_s'].isChecked()
is_v_mode = self.widgets['mode_v'].isChecked()
is_l_mode = self.widgets['mode_l'].isChecked()
-
use_end = self.widgets['image_end_checkbox'].isChecked() and self.widgets['image_end_checkbox'].isVisible()
use_control = self.widgets['control_video_checkbox'].isChecked() and self.widgets['control_video_checkbox'].isVisible()
use_ref = self.widgets['ref_image_checkbox'].isChecked() and self.widgets['ref_image_checkbox'].isVisible()
-
self.widgets['image_start_container'].setVisible(is_s_mode)
self.widgets['video_source_container'].setVisible(is_v_mode)
-
end_checkbox_enabled = is_s_mode or is_v_mode or is_l_mode
self.widgets['image_end_checkbox'].setEnabled(end_checkbox_enabled)
self.widgets['image_end_container'].setVisible(use_end and end_checkbox_enabled)
-
self.widgets['video_guide_container'].setVisible(use_control)
self.widgets['video_mask_container'].setVisible(use_control)
self.widgets['image_refs_container'].setVisible(use_ref)
@@ -1428,17 +1111,14 @@ def connect_signals(self):
self.widgets['model_choice'].currentIndexChanged.connect(self._on_model_changed)
self.widgets['resolution_group'].currentIndexChanged.connect(self._on_resolution_group_changed)
self.widgets['guidance_phases'].currentIndexChanged.connect(self._update_dynamic_ui)
-
self.widgets['mode_t'].toggled.connect(self._update_input_visibility)
self.widgets['mode_s'].toggled.connect(self._update_input_visibility)
self.widgets['mode_v'].toggled.connect(self._update_input_visibility)
self.widgets['mode_l'].toggled.connect(self._update_input_visibility)
-
self.widgets['image_end_checkbox'].toggled.connect(self._update_input_visibility)
self.widgets['control_video_checkbox'].toggled.connect(self._update_input_visibility)
self.widgets['ref_image_checkbox'].toggled.connect(self._update_input_visibility)
self.widgets['preview_group'].toggled.connect(self._on_preview_toggled)
-
self.generate_btn.clicked.connect(self._on_generate)
self.add_to_queue_btn.clicked.connect(self._on_add_to_queue)
self.remove_queue_btn.clicked.connect(self._on_remove_selected_from_queue)
@@ -1446,214 +1126,169 @@ def connect_signals(self):
self.abort_btn.clicked.connect(self._on_abort)
self.queue_table.rowsMoved.connect(self._on_queue_rows_moved)
+ def load_main_config(self):
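+ # Load persisted UI preferences from main_config.json; fall back to defaults if the file is missing or unreadable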
+ try:
+ with open('main_config.json', 'r') as f: self.main_config = json.load(f)
+ except (FileNotFoundError, json.JSONDecodeError): self.main_config = {'preview_visible': False}
+
+ def save_main_config(self):
+ try:
+ with open('main_config.json', 'w') as f: json.dump(self.main_config, f, indent=4)
+ except Exception as e: print(f"Error saving main_config.json: {e}")
+
def apply_initial_config(self):
- is_visible = main_config.get('preview_visible', True)
+ is_visible = self.main_config.get('preview_visible', True)
self.widgets['preview_group'].setChecked(is_visible)
self.widgets['preview_image'].setVisible(is_visible)
def _on_preview_toggled(self, checked):
self.widgets['preview_image'].setVisible(checked)
- main_config['preview_visible'] = checked
- save_main_config()
+ self.main_config['preview_visible'] = checked
+ self.save_main_config()
- def _on_family_changed(self, index):
+ def _on_family_changed(self):
family = self.widgets['model_family'].currentData()
if not family or not self.state: return
-
- base_type_mock, choice_mock = wgp.change_model_family(self.state, family)
+ base_type_mock, choice_mock = self.wgp.change_model_family(self.state, family)
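+ # Determine visibility the same way as update_model_dropdowns: kwargs dict first, then a 'visible' attribute, else visible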
+ if hasattr(base_type_mock, 'kwargs') and isinstance(base_type_mock.kwargs, dict):
+ is_visible_base = base_type_mock.kwargs.get('visible', True)
+ elif hasattr(base_type_mock, 'visible'):
+ is_visible_base = base_type_mock.visible
+ else:
+ is_visible_base = True
+
self.widgets['model_base_type_choice'].blockSignals(True)
self.widgets['model_base_type_choice'].clear()
if base_type_mock.choices:
- for label, value in base_type_mock.choices:
- self.widgets['model_base_type_choice'].addItem(label, value)
- index = self.widgets['model_base_type_choice'].findData(base_type_mock.value)
- if index != -1:
- self.widgets['model_base_type_choice'].setCurrentIndex(index)
- self.widgets['model_base_type_choice'].setVisible(base_type_mock.kwargs.get('visible', True))
+ for label, value in base_type_mock.choices: self.widgets['model_base_type_choice'].addItem(label, value)
+ self.widgets['model_base_type_choice'].setCurrentIndex(self.widgets['model_base_type_choice'].findData(base_type_mock.value))
+ self.widgets['model_base_type_choice'].setVisible(is_visible_base)
self.widgets['model_base_type_choice'].blockSignals(False)
-
+
+ if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict):
+ is_visible_choice = choice_mock.kwargs.get('visible', True)
+ elif hasattr(choice_mock, 'visible'):
+ is_visible_choice = choice_mock.visible
+ else:
+ is_visible_choice = True
+
self.widgets['model_choice'].blockSignals(True)
self.widgets['model_choice'].clear()
if choice_mock.choices:
- for label, value in choice_mock.choices:
- self.widgets['model_choice'].addItem(label, value)
- index = self.widgets['model_choice'].findData(choice_mock.value)
- if index != -1:
- self.widgets['model_choice'].setCurrentIndex(index)
- self.widgets['model_choice'].setVisible(choice_mock.kwargs.get('visible', True))
+ for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value)
+ self.widgets['model_choice'].setCurrentIndex(self.widgets['model_choice'].findData(choice_mock.value))
+ self.widgets['model_choice'].setVisible(is_visible_choice)
self.widgets['model_choice'].blockSignals(False)
-
+
self._on_model_changed()
- def _on_base_type_changed(self, index):
+ def _on_base_type_changed(self):
family = self.widgets['model_family'].currentData()
base_type = self.widgets['model_base_type_choice'].currentData()
if not family or not base_type or not self.state: return
-
- base_type_mock, choice_mock = wgp.change_model_base_types(self.state, family, base_type)
-
+ base_type_mock, choice_mock = self.wgp.change_model_base_types(self.state, family, base_type)
+
+ if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict):
+ is_visible_choice = choice_mock.kwargs.get('visible', True)
+ elif hasattr(choice_mock, 'visible'):
+ is_visible_choice = choice_mock.visible
+ else:
+ is_visible_choice = True
+
self.widgets['model_choice'].blockSignals(True)
self.widgets['model_choice'].clear()
if choice_mock.choices:
- for label, value in choice_mock.choices:
- self.widgets['model_choice'].addItem(label, value)
- index = self.widgets['model_choice'].findData(choice_mock.value)
- if index != -1:
- self.widgets['model_choice'].setCurrentIndex(index)
- self.widgets['model_choice'].setVisible(choice_mock.kwargs.get('visible', True))
+ for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value)
+ self.widgets['model_choice'].setCurrentIndex(self.widgets['model_choice'].findData(choice_mock.value))
+ self.widgets['model_choice'].setVisible(is_visible_choice)
self.widgets['model_choice'].blockSignals(False)
-
self._on_model_changed()
def _on_model_changed(self):
model_type = self.widgets['model_choice'].currentData()
- if not model_type or model_type == self.state['model_type']: return
- wgp.change_model(self.state, model_type)
+ if not model_type or model_type == self.state.get('model_type'): return
+ self.wgp.change_model(self.state, model_type)
self.refresh_ui_from_model_change(model_type)
def _on_resolution_group_changed(self):
selected_group = self.widgets['resolution_group'].currentText()
- if not selected_group or not hasattr(self, 'full_resolution_choices'):
- return
-
+ if not selected_group or not hasattr(self, 'full_resolution_choices'): return
model_type = self.state['model_type']
- model_def = wgp.get_model_def(model_type)
+ model_def = self.wgp.get_model_def(model_type)
model_resolutions = model_def.get("resolutions", None)
-
group_resolution_choices = []
if model_resolutions is None:
- group_resolution_choices = [res for res in self.full_resolution_choices if wgp.categorize_resolution(res[1]) == selected_group]
- else:
- return
-
- last_resolution_per_group = self.state.get("last_resolution_per_group", {})
- last_resolution = last_resolution_per_group.get(selected_group, "")
-
- is_last_res_valid = any(last_resolution == res[1] for res in group_resolution_choices)
- if not is_last_res_valid and group_resolution_choices:
+ group_resolution_choices = [res for res in self.full_resolution_choices if self.wgp.categorize_resolution(res[1]) == selected_group]
+ else: return
+ last_resolution = self.state.get("last_resolution_per_group", {}).get(selected_group, "")
+ if not any(last_resolution == res[1] for res in group_resolution_choices) and group_resolution_choices:
last_resolution = group_resolution_choices[0][1]
-
self.widgets['resolution'].blockSignals(True)
self.widgets['resolution'].clear()
- for label, value in group_resolution_choices:
- self.widgets['resolution'].addItem(label, value)
-
- index = self.widgets['resolution'].findData(last_resolution)
- if index != -1:
- self.widgets['resolution'].setCurrentIndex(index)
+ for label, value in group_resolution_choices: self.widgets['resolution'].addItem(label, value)
+ self.widgets['resolution'].setCurrentIndex(self.widgets['resolution'].findData(last_resolution))
self.widgets['resolution'].blockSignals(False)
def collect_inputs(self):
- """Gather all settings from UI widgets into a dictionary."""
- # Start with all possible defaults. This dictionary will be modified and returned.
- full_inputs = wgp.get_current_model_settings(self.state).copy()
-
- # Add dummy/default values for UI elements present in Gradio but not yet in PyQt.
- # These are expected by the backend logic.
+ full_inputs = self.wgp.get_current_model_settings(self.state).copy()
full_inputs['lset_name'] = ""
- # The PyQt UI is focused on generating videos, so image_mode is 0.
full_inputs['image_mode'] = 0
-
- # Defensively initialize keys that are accessed directly in wgp.py but may not
- # be in the saved model settings or fully implemented in the UI yet.
- # This prevents both KeyErrors and TypeErrors for missing arguments.
- expected_keys = {
- "audio_guide": None, "audio_guide2": None, "image_guide": None,
- "image_mask": None, "speakers_locations": "", "frames_positions": "",
- "keep_frames_video_guide": "", "keep_frames_video_source": "",
- "video_guide_outpainting": "", "switch_threshold2": 0,
- "model_switch_phase": 1, "batch_size": 1,
- "control_net_weight_alt": 1.0,
- "image_refs_relative_size": 50,
- }
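+ # Defensive defaults for keys that wgp reads directly but which may be missing from saved model settings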
+ expected_keys = { "audio_guide": None, "audio_guide2": None, "image_guide": None, "image_mask": None, "speakers_locations": "", "frames_positions": "", "keep_frames_video_guide": "", "keep_frames_video_source": "", "video_guide_outpainting": "", "switch_threshold2": 0, "model_switch_phase": 1, "batch_size": 1, "control_net_weight_alt": 1.0, "image_refs_relative_size": 50, }
for key, default_value in expected_keys.items():
- if key not in full_inputs:
- full_inputs[key] = default_value
-
- # Overwrite defaults with values from the PyQt UI widgets
+ if key not in full_inputs: full_inputs[key] = default_value
full_inputs['prompt'] = self.widgets['prompt'].toPlainText()
full_inputs['negative_prompt'] = self.widgets['negative_prompt'].toPlainText()
full_inputs['resolution'] = self.widgets['resolution'].currentData()
full_inputs['video_length'] = self.widgets['video_length'].value()
full_inputs['num_inference_steps'] = self.widgets['num_inference_steps'].value()
full_inputs['seed'] = int(self.widgets['seed'].text())
-
- # Build prompt_type strings based on mode selections
image_prompt_type = ""
video_prompt_type = ""
-
- if self.widgets['mode_s'].isChecked():
- image_prompt_type = 'S'
- elif self.widgets['mode_v'].isChecked():
- image_prompt_type = 'V'
- elif self.widgets['mode_l'].isChecked():
- image_prompt_type = 'L'
- else: # mode_t is checked
- image_prompt_type = ''
-
- if self.widgets['image_end_checkbox'].isVisible() and self.widgets['image_end_checkbox'].isChecked():
- image_prompt_type += 'E'
-
- if self.widgets['control_video_checkbox'].isVisible() and self.widgets['control_video_checkbox'].isChecked():
- video_prompt_type += 'V' # This 'V' is for Control Video (V2V)
- if self.widgets['ref_image_checkbox'].isVisible() and self.widgets['ref_image_checkbox'].isChecked():
- video_prompt_type += 'I' # 'I' for Reference Image
-
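+ # Derive the image/video prompt-type flags from the selected generation mode and the optional end/control/reference toggles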
+ if self.widgets['mode_s'].isChecked(): image_prompt_type = 'S'
+ elif self.widgets['mode_v'].isChecked(): image_prompt_type = 'V'
+ elif self.widgets['mode_l'].isChecked(): image_prompt_type = 'L'
+ if self.widgets['image_end_checkbox'].isVisible() and self.widgets['image_end_checkbox'].isChecked(): image_prompt_type += 'E'
+ if self.widgets['control_video_checkbox'].isVisible() and self.widgets['control_video_checkbox'].isChecked(): video_prompt_type += 'V'
+ if self.widgets['ref_image_checkbox'].isVisible() and self.widgets['ref_image_checkbox'].isChecked(): video_prompt_type += 'I'
full_inputs['image_prompt_type'] = image_prompt_type
full_inputs['video_prompt_type'] = video_prompt_type
-
- # File Inputs
for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'audio_source']:
- if name in self.widgets:
- path = self.widgets[name].text()
- full_inputs[name] = path if path else None
-
+ if name in self.widgets: full_inputs[name] = self.widgets[name].text() or None
paths = self.widgets['image_refs'].text().split(';')
full_inputs['image_refs'] = [p.strip() for p in paths if p.strip()] if paths and paths[0] else None
-
full_inputs['denoising_strength'] = self.widgets['denoising_strength'].value() / 100.0
-
if self.advanced_group.isChecked():
full_inputs['guidance_scale'] = self.widgets['guidance_scale'].value() / 10.0
full_inputs['guidance_phases'] = self.widgets['guidance_phases'].currentData()
full_inputs['guidance2_scale'] = self.widgets['guidance2_scale'].value() / 10.0
full_inputs['guidance3_scale'] = self.widgets['guidance3_scale'].value() / 10.0
full_inputs['switch_threshold'] = self.widgets['switch_threshold'].value()
-
full_inputs['NAG_scale'] = self.widgets['NAG_scale'].value() / 10.0
full_inputs['NAG_tau'] = self.widgets['NAG_tau'].value() / 10.0
full_inputs['NAG_alpha'] = self.widgets['NAG_alpha'].value() / 10.0
-
full_inputs['sample_solver'] = self.widgets['sample_solver'].currentData()
full_inputs['flow_shift'] = self.widgets['flow_shift'].value() / 10.0
full_inputs['audio_guidance_scale'] = self.widgets['audio_guidance_scale'].value() / 10.0
full_inputs['repeat_generation'] = self.widgets['repeat_generation'].value()
full_inputs['multi_images_gen_type'] = self.widgets['multi_images_gen_type'].currentData()
-
- lora_list_widget = self.widgets['activated_loras']
- selected_items = lora_list_widget.selectedItems()
+ selected_items = self.widgets['activated_loras'].selectedItems()
full_inputs['activated_loras'] = [self.lora_map[item.text()] for item in selected_items if item.text() in self.lora_map]
full_inputs['loras_multipliers'] = self.widgets['loras_multipliers'].toPlainText()
-
full_inputs['skip_steps_cache_type'] = self.widgets['skip_steps_cache_type'].currentData()
full_inputs['skip_steps_multiplier'] = self.widgets['skip_steps_multiplier'].currentData()
full_inputs['skip_steps_start_step_perc'] = self.widgets['skip_steps_start_step_perc'].value()
-
full_inputs['temporal_upsampling'] = self.widgets['temporal_upsampling'].currentData()
full_inputs['spatial_upsampling'] = self.widgets['spatial_upsampling'].currentData()
full_inputs['film_grain_intensity'] = self.widgets['film_grain_intensity'].value() / 100.0
full_inputs['film_grain_saturation'] = self.widgets['film_grain_saturation'].value() / 100.0
-
full_inputs['MMAudio_setting'] = self.widgets['MMAudio_setting'].currentData()
full_inputs['MMAudio_prompt'] = self.widgets['MMAudio_prompt'].text()
full_inputs['MMAudio_neg_prompt'] = self.widgets['MMAudio_neg_prompt'].text()
-
full_inputs['RIFLEx_setting'] = self.widgets['RIFLEx_setting'].currentData()
full_inputs['force_fps'] = self.widgets['force_fps'].currentData()
full_inputs['override_profile'] = self.widgets['override_profile'].currentData()
full_inputs['multi_prompts_gen_type'] = self.widgets['multi_prompts_gen_type'].currentData()
-
full_inputs['slg_switch'] = self.widgets['slg_switch'].currentData()
full_inputs['slg_start_perc'] = self.widgets['slg_start_perc'].value()
full_inputs['slg_end_perc'] = self.widgets['slg_end_perc'].value()
@@ -1661,13 +1296,11 @@ def collect_inputs(self):
full_inputs['cfg_star_switch'] = self.widgets['cfg_star_switch'].currentData()
full_inputs['cfg_zero_step'] = self.widgets['cfg_zero_step'].value()
full_inputs['min_frames_if_references'] = self.widgets['min_frames_if_references'].currentData()
-
full_inputs['sliding_window_size'] = self.widgets['sliding_window_size'].value()
full_inputs['sliding_window_overlap'] = self.widgets['sliding_window_overlap'].value()
full_inputs['sliding_window_color_correction_strength'] = self.widgets['sliding_window_color_correction_strength'].value() / 100.0
full_inputs['sliding_window_overlap_noise'] = self.widgets['sliding_window_overlap_noise'].value()
full_inputs['sliding_window_discard_last_frames'] = self.widgets['sliding_window_discard_last_frames'].value()
-
return full_inputs
def _prepare_state_for_generation(self):
@@ -1678,62 +1311,44 @@ def _prepare_state_for_generation(self):
def _on_generate(self):
try:
is_running = self.thread and self.thread.isRunning()
- self._add_task_to_queue_and_update_ui()
- if not is_running:
- self.start_generation()
- except Exception:
- import traceback
- traceback.print_exc()
-
+ self._add_task_to_queue()
+ if not is_running: self.start_generation()
+ except Exception:
+ import traceback; traceback.print_exc()
def _on_add_to_queue(self):
try:
- self._add_task_to_queue_and_update_ui()
- except Exception:
- import traceback
- traceback.print_exc()
-
- def _add_task_to_queue_and_update_ui(self):
- self._add_task_to_queue()
- self.update_queue_table()
-
+ self._add_task_to_queue()
+ except Exception:
+ import traceback; traceback.print_exc()
+
def _add_task_to_queue(self):
- queue_size_before = len(self.state["gen"]["queue"])
all_inputs = self.collect_inputs()
- keys_to_remove = ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality', 'base_model_type']
- for key in keys_to_remove:
- all_inputs.pop(key, None)
-
+ for key in ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality', 'base_model_type']: all_inputs.pop(key, None)
all_inputs['state'] = self.state
- wgp.set_model_settings(self.state, self.state['model_type'], all_inputs)
-
+ self.wgp.set_model_settings(self.state, self.state['model_type'], all_inputs)
self.state["validate_success"] = 1
- wgp.process_prompt_and_add_tasks(self.state, self.state['model_type'])
-
+ self.wgp.process_prompt_and_add_tasks(self.state, self.state['model_type'])
+ self.update_queue_table()
def start_generation(self):
- if not self.state['gen']['queue']:
- return
+ if not self.state['gen']['queue']: return
self._prepare_state_for_generation()
self.generate_btn.setEnabled(False)
self.add_to_queue_btn.setEnabled(True)
-
self.thread = QThread()
- self.worker = Worker(self.state)
+ self.worker = Worker(self.plugin, self.state)
self.worker.moveToThread(self.thread)
-
self.thread.started.connect(self.worker.run)
self.worker.finished.connect(self.thread.quit)
self.worker.finished.connect(self.worker.deleteLater)
self.thread.finished.connect(self.thread.deleteLater)
self.thread.finished.connect(self.on_generation_finished)
-
self.worker.status.connect(self.status_label.setText)
self.worker.progress.connect(self.update_progress)
self.worker.preview.connect(self.update_preview)
- self.worker.output.connect(self.update_queue_and_gallery)
+ self.worker.output.connect(self.update_queue_and_results)
self.worker.error.connect(self.on_generation_error)
-
self.thread.start()
self.update_queue_table()
@@ -1743,17 +1358,11 @@ def on_generation_finished(self):
self.progress_bar.setValue(0)
self.generate_btn.setEnabled(True)
self.add_to_queue_btn.setEnabled(False)
- self.thread = None
- self.worker = None
+ self.thread = None; self.worker = None
self.update_queue_table()
def on_generation_error(self, err_msg):
- msg_box = QMessageBox()
- msg_box.setIcon(QMessageBox.Icon.Critical)
- msg_box.setText("Generation Error")
- msg_box.setInformativeText(str(err_msg))
- msg_box.setWindowTitle("Error")
- msg_box.exec()
+ QMessageBox.critical(self, "Generation Error", str(err_msg))
self.on_generation_finished()
def update_progress(self, data):
@@ -1762,221 +1371,283 @@ def update_progress(self, data):
self.progress_bar.setMaximum(total)
self.progress_bar.setValue(step)
self.status_label.setText(str(data[1]))
- if step <= 1:
- self.update_queue_table()
- elif len(data) > 1:
- self.status_label.setText(str(data[1]))
+ if step <= 1: self.update_queue_table()
+ elif len(data) > 1: self.status_label.setText(str(data[1]))
def update_preview(self, pil_image):
- if pil_image:
+ if pil_image and self.widgets['preview_group'].isChecked():
q_image = ImageQt(pil_image)
pixmap = QPixmap.fromImage(q_image)
- self.preview_image.setPixmap(pixmap.scaled(
- self.preview_image.size(),
- Qt.AspectRatioMode.KeepAspectRatio,
- Qt.TransformationMode.SmoothTransformation
- ))
+ self.preview_image.setPixmap(pixmap.scaled(self.preview_image.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation))
- def update_queue_and_gallery(self):
+ def update_queue_and_results(self):
self.update_queue_table()
-
file_list = self.state.get('gen', {}).get('file_list', [])
- self.output_gallery.clear()
- self.output_gallery.addItems(file_list)
- if file_list:
- self.output_gallery.setCurrentRow(len(file_list) - 1)
- self.latest_output_path = file_list[-1] # <-- ADDED for API access
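+ # Only add files not seen before so finished results are not duplicated on refresh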
+ for file_path in file_list:
+ if file_path not in self.processed_files:
+ self.add_result_item(file_path)
+ self.processed_files.add(file_path)
+
+ def add_result_item(self, video_path):
+ item_widget = VideoResultItemWidget(video_path, self.plugin)
+ list_item = QListWidgetItem(self.results_list)
+ list_item.setSizeHint(item_widget.sizeHint())
+ self.results_list.addItem(list_item)
+ self.results_list.setItemWidget(list_item, item_widget)
def update_queue_table(self):
- with wgp.lock:
+ with self.wgp.lock:
queue = self.state.get('gen', {}).get('queue', [])
is_running = self.thread and self.thread.isRunning()
queue_to_display = queue if is_running else [None] + queue
-
- table_data = wgp.get_queue_table(queue_to_display)
-
+ table_data = self.wgp.get_queue_table(queue_to_display)
self.queue_table.setRowCount(0)
self.queue_table.setRowCount(len(table_data))
self.queue_table.setColumnCount(4)
self.queue_table.setHorizontalHeaderLabels(["Qty", "Prompt", "Length", "Steps"])
-
for row_idx, row_data in enumerate(table_data):
- prompt_html = row_data[1]
- try:
- prompt_text = prompt_html.split('>')[1].split('<')[0]
- except IndexError:
- prompt_text = str(row_data[1])
-
- self.queue_table.setItem(row_idx, 0, QTableWidgetItem(str(row_data[0])))
- self.queue_table.setItem(row_idx, 1, QTableWidgetItem(prompt_text))
- self.queue_table.setItem(row_idx, 2, QTableWidgetItem(str(row_data[2])))
- self.queue_table.setItem(row_idx, 3, QTableWidgetItem(str(row_data[3])))
-
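+ # get_queue_table wraps prompts in HTML; strip the tags for plain-text display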
+ prompt_text = str(row_data[1]).split('>')[1].split('<')[0] if '>' in str(row_data[1]) else str(row_data[1])
+ for col_idx, cell_data in enumerate([row_data[0], prompt_text, row_data[2], row_data[3]]):
+ self.queue_table.setItem(row_idx, col_idx, QTableWidgetItem(str(cell_data)))
self.queue_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
self.queue_table.resizeColumnsToContents()
def _on_remove_selected_from_queue(self):
selected_row = self.queue_table.currentRow()
- if selected_row < 0:
- return
-
- with wgp.lock:
+ if selected_row < 0: return
+ with self.wgp.lock:
is_running = self.thread and self.thread.isRunning()
offset = 1 if is_running else 0
-
queue = self.state.get('gen', {}).get('queue', [])
- if len(queue) > selected_row + offset:
- queue.pop(selected_row + offset)
+ if len(queue) > selected_row + offset: queue.pop(selected_row + offset)
self.update_queue_table()
def _on_queue_rows_moved(self, source_row, dest_row):
- with wgp.lock:
+ with self.wgp.lock:
queue = self.state.get('gen', {}).get('queue', [])
is_running = self.thread and self.thread.isRunning()
offset = 1 if is_running else 0
-
real_source_idx = source_row + offset
real_dest_idx = dest_row + offset
-
moved_item = queue.pop(real_source_idx)
queue.insert(real_dest_idx, moved_item)
self.update_queue_table()
def _on_clear_queue(self):
- wgp.clear_queue_action(self.state)
+ self.wgp.clear_queue_action(self.state)
self.update_queue_table()
def _on_abort(self):
if self.worker:
- wgp.abort_generation(self.state)
+ self.wgp.abort_generation(self.state)
self.status_label.setText("Aborting...")
self.worker._is_running = False
def _on_release_ram(self):
- wgp.release_RAM()
+ self.wgp.release_RAM()
QMessageBox.information(self, "RAM Released", "Models stored in RAM have been released.")
def _on_apply_config_changes(self):
changes = {}
-
list_widget = self.widgets['config_transformer_types']
- checked_items = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked]
- changes['transformer_types_choices'] = checked_items
-
+ changes['transformer_types_choices'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked]
list_widget = self.widgets['config_preload_model_policy']
- checked_items = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked]
- changes['preload_model_policy_choice'] = checked_items
-
+ changes['preload_model_policy_choice'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked]
changes['model_hierarchy_type_choice'] = self.widgets['config_model_hierarchy_type'].currentData()
changes['checkpoints_paths'] = self.widgets['config_checkpoints_paths'].toPlainText()
-
for key in ["fit_canvas", "attention_mode", "metadata_type", "clear_file_list", "display_stats", "max_frames_multiplier", "UI_theme"]:
changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData()
-
for key in ["transformer_quantization", "transformer_dtype_policy", "mixed_precision", "text_encoder_quantization", "vae_precision", "compile", "depth_anything_v2_variant", "vae_config", "boost", "profile"]:
changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData()
changes['preload_in_VRAM_choice'] = self.widgets['config_preload_in_VRAM'].value()
-
for key in ["enhancer_enabled", "enhancer_mode", "mmaudio_enabled"]:
changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData()
-
for key in ["video_output_codec", "image_output_codec", "save_path", "image_save_path"]:
widget = self.widgets[f'config_{key}']
changes[f'{key}_choice'] = widget.currentData() if isinstance(widget, QComboBox) else widget.text()
-
changes['notification_sound_enabled_choice'] = self.widgets['config_notification_sound_enabled'].currentData()
changes['notification_sound_volume_choice'] = self.widgets['config_notification_sound_volume'].value()
-
changes['last_resolution_choice'] = self.widgets['resolution'].currentData()
-
try:
- msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = wgp.apply_changes(self.state, **changes)
+ msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = self.wgp.apply_changes(self.state, **changes)
self.config_status_label.setText("Changes applied successfully. Some settings may require a restart.")
-
- self.header_info.setText(wgp.generate_header(self.state['model_type'], wgp.compile, wgp.attention_mode))
-
+ self.header_info.setText(self.wgp.generate_header(self.state['model_type'], self.wgp.compile, self.wgp.attention_mode))
if family_mock.choices is not None or choice_mock.choices is not None:
- self.update_model_dropdowns(wgp.transformer_type)
- self.refresh_ui_from_model_change(wgp.transformer_type)
-
+ self.update_model_dropdowns(self.wgp.transformer_type)
+ self.refresh_ui_from_model_change(self.wgp.transformer_type)
except Exception as e:
self.config_status_label.setText(f"Error applying changes: {e}")
- import traceback
- traceback.print_exc()
-
- def closeEvent(self, event):
- if wgp:
- wgp.autosave_queue()
- save_main_config()
- event.accept()
-
-# =====================================================================
-# --- START OF API SERVER ADDITION ---
-# =====================================================================
-try:
- from flask import Flask, request, jsonify
- FLASK_AVAILABLE = True
-except ImportError:
- FLASK_AVAILABLE = False
- print("Flask not installed. API server will not be available. Please run: pip install Flask")
-
-# Global reference to the main window, to be populated after instantiation.
-main_window_instance = None
-api_server = Flask(__name__) if FLASK_AVAILABLE else None
-
-def run_api_server():
- """Function to run the Flask server."""
- if api_server:
- print("Starting API server on http://127.0.0.1:5100")
- api_server.run(port=5100, host='127.0.0.1', debug=False)
-
-if FLASK_AVAILABLE:
-
- @api_server.route('/api/generate', methods=['POST'])
- def generate():
- if not main_window_instance:
- return jsonify({"error": "Application not ready"}), 503
-
- data = request.json
- start_frame = data.get('start_frame')
- end_frame = data.get('end_frame')
- # --- CHANGE: Get duration_sec and model_type ---
- duration_sec = data.get('duration_sec')
- model_type = data.get('model_type')
- start_generation = data.get('start_generation', False)
-
- # --- CHANGE: Emit the signal with the new parameters ---
- main_window_instance.api_bridge.generateSignal.emit(
- start_frame, end_frame, duration_sec, model_type, start_generation
- )
-
- if start_generation:
- return jsonify({"message": "Parameters set and generation request sent."})
+ import traceback; traceback.print_exc()
+
+class Plugin(VideoEditorPlugin):
+ def initialize(self):
+ self.name = "WanGP AI Generator"
+ self.description = "Uses the integrated WanGP library to generate video clips."
+ self.client_widget = WgpDesktopPluginWidget(self)
+ self.dock_widget = None
+ self.active_region = None; self.temp_dir = None
+ self.insert_on_new_track = False; self.start_frame_path = None; self.end_frame_path = None
+
+ def enable(self):
+ if not self.dock_widget: self.dock_widget = self.app.add_dock_widget(self, self.client_widget, self.name)
+ self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu)
+ self.app.status_label.setText(f"{self.name}: Enabled.")
+
+ def disable(self):
+ try: self.app.timeline_widget.context_menu_requested.disconnect(self.on_timeline_context_menu)
+ except TypeError: pass
+ self._cleanup_temp_dir()
+ if self.client_widget.worker: self.client_widget._on_abort()
+ self.app.status_label.setText(f"{self.name}: Disabled.")
+
+ def _cleanup_temp_dir(self):
+ if self.temp_dir and os.path.exists(self.temp_dir):
+ shutil.rmtree(self.temp_dir)
+ self.temp_dir = None
+
+ def _reset_state(self):
+ self.active_region = None; self.insert_on_new_track = False
+ self.start_frame_path = None; self.end_frame_path = None
+ self.client_widget.processed_files.clear()
+ self.client_widget.results_list.clear()
+ self.client_widget.widgets['image_start'].clear()
+ self.client_widget.widgets['image_end'].clear()
+ self.client_widget.widgets['video_source'].clear()
+ self._cleanup_temp_dir()
+ self.app.status_label.setText(f"{self.name}: Ready.")
+
+ def on_timeline_context_menu(self, menu, event):
+ region = self.app.timeline_widget.get_region_at_pos(event.pos())
+ if region:
+ menu.addSeparator()
+ start_sec, end_sec = region
+ start_data, _, _ = self.app.get_frame_data_at_time(start_sec)
+ end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
+ if start_data and end_data:
+ join_action = menu.addAction("Join Frames With WanGP")
+ join_action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False))
+ join_action_new_track = menu.addAction("Join Frames With WanGP (New Track)")
+ join_action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True))
+ create_action = menu.addAction("Create Video With WanGP")
+ create_action.triggered.connect(lambda: self.setup_creator_for_region(region))
+
+ def setup_generator_for_region(self, region, on_new_track=False):
+ self._reset_state()
+ self.active_region = region
+ self.insert_on_new_track = on_new_track
+
+ model_to_set = 'i2v_2_2'
+ dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types
+ _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False)
+ if any(model_to_set == m[1] for m in all_models):
+ if self.client_widget.state.get('model_type') != model_to_set:
+ self.client_widget.update_model_dropdowns(model_to_set)
+ self.client_widget._on_model_changed()
else:
- return jsonify({"message": "Parameters set without starting generation."})
-
-
- @api_server.route('/api/outputs', methods=['GET'])
- def get_outputs():
- if not main_window_instance:
- return jsonify({"error": "Application not ready"}), 503
-
- file_list = main_window_instance.state.get('gen', {}).get('file_list', [])
- return jsonify({"outputs": file_list})
-
-# =====================================================================
-# --- END OF API SERVER ADDITION ---
-# =====================================================================
-
-if __name__ == '__main__':
- app = QApplication(sys.argv)
-
- window = MainWindow()
- main_window_instance = window
- window.show()
+ print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.")
+
+ start_sec, end_sec = region
+ start_data, w, h = self.app.get_frame_data_at_time(start_sec)
+ end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
+ if not start_data or not end_data:
+ QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames.")
+ return
+ try:
+ self.temp_dir = tempfile.mkdtemp(prefix="wgp_plugin_")
+ self.start_frame_path = os.path.join(self.temp_dir, "start_frame.png")
+ self.end_frame_path = os.path.join(self.temp_dir, "end_frame.png")
+ QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path)
+ QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path)
+
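+ # Frame count for the clip = region duration * model fps, assuming 16 fps when the model reports none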
+ duration_sec = end_sec - start_sec
+ wgp = self.client_widget.wgp
+ model_type = self.client_widget.state['model_type']
+ fps = wgp.get_model_fps(model_type)
+ video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
+
+ self.client_widget.widgets['video_length'].setValue(video_length_frames)
+ self.client_widget.widgets['mode_s'].setChecked(True)
+ self.client_widget.widgets['image_end_checkbox'].setChecked(True)
+ self.client_widget.widgets['image_start'].setText(self.start_frame_path)
+ self.client_widget.widgets['image_end'].setText(self.end_frame_path)
+
+ self.client_widget._update_input_visibility()
- if FLASK_AVAILABLE:
- api_thread = threading.Thread(target=run_api_server, daemon=True)
- api_thread.start()
+ except Exception as e:
+ QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame images: {e}")
+ self._cleanup_temp_dir()
+ return
+ self.app.status_label.setText(f"WanGP: Ready to join frames from {start_sec:.2f}s to {end_sec:.2f}s.")
+ self.dock_widget.show()
+ self.dock_widget.raise_()
+
+ def setup_creator_for_region(self, region):
+ self._reset_state()
+ self.active_region = region
+ self.insert_on_new_track = True
+
+ model_to_set = 't2v_2_2'
+ dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types
+ _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False)
+ if any(model_to_set == m[1] for m in all_models):
+ if self.client_widget.state.get('model_type') != model_to_set:
+ self.client_widget.update_model_dropdowns(model_to_set)
+ self.client_widget._on_model_changed()
+ else:
+ print(f"Warning: Default model '{model_to_set}' not found for AI Creator. Using current model.")
- sys.exit(app.exec())
\ No newline at end of file
+ start_sec, end_sec = region
+ duration_sec = end_sec - start_sec
+ wgp = self.client_widget.wgp
+ model_type = self.client_widget.state['model_type']
+ fps = wgp.get_model_fps(model_type)
+ video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
+
+ self.client_widget.widgets['video_length'].setValue(video_length_frames)
+ self.client_widget.widgets['mode_t'].setChecked(True)
+
+ self.app.status_label.setText(f"WanGP: Ready to create video from {start_sec:.2f}s to {end_sec:.2f}s.")
+ self.dock_widget.show()
+ self.dock_widget.raise_()
+
+ def insert_generated_clip(self, video_path):
+ from videoeditor import TimelineClip
+ if not self.active_region:
+ self.app.status_label.setText("WanGP Error: No active region to insert into."); return
+ if not os.path.exists(video_path):
+ self.app.status_label.setText(f"WanGP Error: Output file not found: {video_path}"); return
+ start_sec, end_sec = self.active_region
+ def complex_insertion_action():
+ self.app._add_media_files_to_project([video_path])
+ media_info = self.app.media_properties.get(video_path)
+ if not media_info: raise ValueError("Could not probe inserted clip.")
+ actual_duration, has_audio = media_info['duration'], media_info['has_audio']
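+ # Either place the clip on a new track, or clear the selected region on track 1 by splitting and removing overlapping clips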
+ if self.insert_on_new_track:
+ self.app.timeline.num_video_tracks += 1
+ video_track_index = self.app.timeline.num_video_tracks
+ audio_track_index = self.app.timeline.num_audio_tracks + 1 if has_audio else None
+ else:
+ for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec)
+ for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec)
+ clips_to_remove = [c for c in self.app.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec]
+ for clip in clips_to_remove:
+ if clip in self.app.timeline.clips: self.app.timeline.clips.remove(clip)
+ video_track_index, audio_track_index = 1, 1 if has_audio else None
+ group_id = str(uuid.uuid4())
+ new_clip = TimelineClip(video_path, start_sec, 0, actual_duration, video_track_index, 'video', 'video', group_id)
+ self.app.timeline.add_clip(new_clip)
+ if audio_track_index:
+ if audio_track_index > self.app.timeline.num_audio_tracks: self.app.timeline.num_audio_tracks = audio_track_index
+ audio_clip = TimelineClip(video_path, start_sec, 0, actual_duration, audio_track_index, 'audio', 'video', group_id)
+ self.app.timeline.add_clip(audio_clip)
+ try:
+ self.app._perform_complex_timeline_change("Insert AI Clip", complex_insertion_action)
+ self.app.prune_empty_tracks()
+ self.app.status_label.setText("AI clip inserted successfully.")
+ for i in range(self.client_widget.results_list.count()):
+ widget = self.client_widget.results_list.itemWidget(self.client_widget.results_list.item(i))
+ if widget and widget.video_path == video_path:
+ self.client_widget.results_list.takeItem(i); break
+ except Exception as e:
+ import traceback; traceback.print_exc()
+ self.app.status_label.setText(f"WanGP Error during clip insertion: {e}")
\ No newline at end of file
diff --git a/videoeditor/undo.py b/undo.py
similarity index 100%
rename from videoeditor/undo.py
rename to undo.py
diff --git a/videoeditor/main.py b/videoeditor.py
similarity index 99%
rename from videoeditor/main.py
rename to videoeditor.py
index 4bd3662c6..1c7466957 100644
--- a/videoeditor/main.py
+++ b/videoeditor.py
@@ -1,3 +1,4 @@
+import wgp
import sys
import os
import uuid
@@ -1288,7 +1289,7 @@ def __init__(self, project_to_load=None):
self.is_shutting_down = False
self._load_settings()
- self.plugin_manager = PluginManager(self)
+ self.plugin_manager = PluginManager(self, wgp)
self.plugin_manager.discover_and_load_plugins()
self.project_fps = 25.0
diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py
deleted file mode 100644
index 410333d28..000000000
--- a/videoeditor/plugins/ai_frame_joiner/main.py
+++ /dev/null
@@ -1,512 +0,0 @@
-import sys
-import os
-import tempfile
-import shutil
-import requests
-from pathlib import Path
-import ffmpeg
-import uuid
-
-from PyQt6.QtWidgets import (
- QWidget, QVBoxLayout, QHBoxLayout, QFormLayout,
- QLineEdit, QPushButton, QLabel, QMessageBox, QCheckBox, QListWidget,
- QListWidgetItem, QGroupBox
-)
-from PyQt6.QtGui import QImage, QPixmap
-from PyQt6.QtCore import pyqtSignal, Qt, QSize, QRectF, QUrl, QTimer
-from PyQt6.QtMultimedia import QMediaPlayer
-from PyQt6.QtMultimediaWidgets import QVideoWidget
-
-sys.path.append(str(Path(__file__).parent.parent.parent))
-from plugins import VideoEditorPlugin
-
-API_BASE_URL = "http://127.0.0.1:5100"
-
-class VideoResultItemWidget(QWidget):
- def __init__(self, video_path, plugin, parent=None):
- super().__init__(parent)
- self.video_path = video_path
- self.plugin = plugin
- self.app = plugin.app
- self.duration = 0.0
- self.has_audio = False
-
- self.setMinimumSize(200, 180)
- self.setMaximumHeight(190)
-
- layout = QVBoxLayout(self)
- layout.setContentsMargins(5, 5, 5, 5)
- layout.setSpacing(5)
-
- self.media_player = QMediaPlayer()
- self.video_widget = QVideoWidget()
- self.video_widget.setFixedSize(160, 90)
- self.media_player.setVideoOutput(self.video_widget)
- self.media_player.setSource(QUrl.fromLocalFile(self.video_path))
- self.media_player.setLoops(QMediaPlayer.Loops.Infinite)
-
- self.info_label = QLabel(os.path.basename(video_path))
- self.info_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
- self.info_label.setWordWrap(True)
-
- self.insert_button = QPushButton("Insert into Timeline")
- self.insert_button.clicked.connect(self.on_insert)
-
- h_layout = QHBoxLayout()
- h_layout.addStretch()
- h_layout.addWidget(self.video_widget)
- h_layout.addStretch()
-
- layout.addLayout(h_layout)
- layout.addWidget(self.info_label)
- layout.addWidget(self.insert_button)
-
- self.probe_video()
-
- def probe_video(self):
- try:
- probe = ffmpeg.probe(self.video_path)
- self.duration = float(probe['format']['duration'])
- self.has_audio = any(s['codec_type'] == 'audio' for s in probe.get('streams', []))
-
- self.info_label.setText(f"{os.path.basename(self.video_path)}\n({self.duration:.2f}s)")
-
- except Exception as e:
- self.info_label.setText(f"Error probing:\n{os.path.basename(self.video_path)}")
- print(f"Error probing video {self.video_path}: {e}")
-
- def enterEvent(self, event):
- super().enterEvent(event)
- self.media_player.play()
-
- if not self.plugin.active_region or self.duration == 0:
- return
-
- start_sec, _ = self.plugin.active_region
- timeline = self.app.timeline_widget
-
- video_rect, audio_rect = None, None
-
- x = timeline.sec_to_x(start_sec)
- w = int(self.duration * timeline.pixels_per_second)
-
- if self.plugin.insert_on_new_track:
- video_y = timeline.TIMESCALE_HEIGHT
- video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT)
- if self.has_audio:
- audio_y = timeline.audio_tracks_y_start + self.app.timeline.num_audio_tracks * timeline.TRACK_HEIGHT
- audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT)
- else:
- v_track_idx = 1
- visual_v_idx = self.app.timeline.num_video_tracks - v_track_idx
- video_y = timeline.video_tracks_y_start + visual_v_idx * timeline.TRACK_HEIGHT
- video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT)
- if self.has_audio:
- a_track_idx = 1
- visual_a_idx = a_track_idx - 1
- audio_y = timeline.audio_tracks_y_start + visual_a_idx * timeline.TRACK_HEIGHT
- audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT)
-
- timeline.set_hover_preview_rects(video_rect, audio_rect)
-
- def leaveEvent(self, event):
- super().leaveEvent(event)
- self.media_player.pause()
- self.media_player.setPosition(0)
- self.app.timeline_widget.set_hover_preview_rects(None, None)
-
- def on_insert(self):
- self.media_player.stop()
- self.media_player.setSource(QUrl())
- self.media_player.setVideoOutput(None)
- self.app.timeline_widget.set_hover_preview_rects(None, None)
- self.plugin.insert_generated_clip(self.video_path)
-
-class WgpClientWidget(QWidget):
- status_updated = pyqtSignal(str)
-
- def __init__(self, plugin):
- super().__init__()
- self.plugin = plugin
- self.processed_files = set()
-
- layout = QVBoxLayout(self)
- form_layout = QFormLayout()
-
- self.previews_widget = QWidget()
- previews_layout = QHBoxLayout(self.previews_widget)
- previews_layout.setContentsMargins(0, 0, 0, 0)
-
- start_preview_layout = QVBoxLayout()
- start_preview_layout.addWidget(QLabel("Start Frame"))
- self.start_frame_preview = QLabel("N/A")
- self.start_frame_preview.setAlignment(Qt.AlignmentFlag.AlignCenter)
- self.start_frame_preview.setFixedSize(160, 90)
- self.start_frame_preview.setStyleSheet("background-color: #222; color: #888;")
- start_preview_layout.addWidget(self.start_frame_preview)
- previews_layout.addLayout(start_preview_layout)
-
- end_preview_layout = QVBoxLayout()
- end_preview_layout.addWidget(QLabel("End Frame"))
- self.end_frame_preview = QLabel("N/A")
- self.end_frame_preview.setAlignment(Qt.AlignmentFlag.AlignCenter)
- self.end_frame_preview.setFixedSize(160, 90)
- self.end_frame_preview.setStyleSheet("background-color: #222; color: #888;")
- end_preview_layout.addWidget(self.end_frame_preview)
- previews_layout.addLayout(end_preview_layout)
-
- self.start_frame_input = QLineEdit()
- self.end_frame_input = QLineEdit()
- self.duration_input = QLineEdit()
-
- self.autostart_checkbox = QCheckBox("Start Generation Immediately")
- self.autostart_checkbox.setChecked(True)
-
- self.generate_button = QPushButton("Generate")
-
- layout.addWidget(self.previews_widget)
- form_layout.addRow(self.autostart_checkbox)
- form_layout.addRow(self.generate_button)
-
- layout.addLayout(form_layout)
-
- # --- Results Area ---
- results_group = QGroupBox("Generated Clips (Hover to play, Click button to insert)")
- results_layout = QVBoxLayout()
- self.results_list = QListWidget()
- self.results_list.setFlow(QListWidget.Flow.LeftToRight)
- self.results_list.setWrapping(True)
- self.results_list.setResizeMode(QListWidget.ResizeMode.Adjust)
- self.results_list.setSpacing(10)
- results_layout.addWidget(self.results_list)
- results_group.setLayout(results_layout)
- layout.addWidget(results_group)
-
- layout.addStretch()
-
- self.generate_button.clicked.connect(self.generate)
-
- self.poll_timer = QTimer(self)
- self.poll_timer.setInterval(3000)
- self.poll_timer.timeout.connect(self.poll_for_output)
-
- def set_previews(self, start_pixmap, end_pixmap):
- if start_pixmap:
- self.start_frame_preview.setPixmap(start_pixmap)
- else:
- self.start_frame_preview.setText("N/A")
- self.start_frame_preview.setPixmap(QPixmap())
-
- if end_pixmap:
- self.end_frame_preview.setPixmap(end_pixmap)
- else:
- self.end_frame_preview.setText("N/A")
- self.end_frame_preview.setPixmap(QPixmap())
-
- def start_polling(self):
- self.poll_timer.start()
-
- def stop_polling(self):
- self.poll_timer.stop()
-
- def handle_api_error(self, response, action="performing action"):
- try:
- error_msg = response.json().get("error", "Unknown error")
- except requests.exceptions.JSONDecodeError:
- error_msg = response.text
- self.status_updated.emit(f"AI Joiner Error: {error_msg}")
- QMessageBox.warning(self, "API Error", f"Failed while {action}.\n\nServer response:\n{error_msg}")
-
- def check_server_status(self):
- try:
- requests.get(f"{API_BASE_URL}/api/outputs", timeout=1)
- self.status_updated.emit("AI Joiner: Connected to WanGP server.")
- return True
- except requests.exceptions.ConnectionError:
- self.status_updated.emit("AI Joiner Error: Cannot connect to WanGP server.")
- QMessageBox.critical(self, "Connection Error", f"Could not connect to the WanGP API at {API_BASE_URL}.\n\nPlease ensure wgptool.py is running.")
- return False
-
- def generate(self):
- payload = {}
-
- model_type = self.model_type_to_use
- if model_type:
- payload['model_type'] = model_type
-
- if self.start_frame_input.text():
- payload['start_frame'] = self.start_frame_input.text()
- if self.end_frame_input.text():
- payload['end_frame'] = self.end_frame_input.text()
-
- if self.duration_input.text():
- payload['duration_sec'] = self.duration_input.text()
-
- payload['start_generation'] = self.autostart_checkbox.isChecked()
-
- self.status_updated.emit("AI Joiner: Sending parameters...")
- try:
- response = requests.post(f"{API_BASE_URL}/api/generate", json=payload)
- if response.status_code == 200:
- if payload['start_generation']:
- self.status_updated.emit("AI Joiner: Generation sent to server. Polling for output...")
- else:
- self.status_updated.emit("AI Joiner: Parameters set. Polling for manually started generation...")
- self.start_polling()
- else:
- self.handle_api_error(response, "setting parameters")
- except requests.exceptions.RequestException as e:
- self.status_updated.emit(f"AI Joiner Error: Connection error: {e}")
-
- def poll_for_output(self):
- try:
- response = requests.get(f"{API_BASE_URL}/api/outputs")
- if response.status_code == 200:
- data = response.json()
- output_files = data.get("outputs", [])
-
- new_files = set(output_files) - self.processed_files
- if new_files:
- self.status_updated.emit(f"AI Joiner: Received {len(new_files)} new clip(s).")
- for file_path in sorted(list(new_files)):
- self.add_result_item(file_path)
- self.processed_files.add(file_path)
- else:
- self.status_updated.emit("AI Joiner: Polling for output...")
- else:
- self.status_updated.emit("AI Joiner: Polling for output...")
- except requests.exceptions.RequestException:
- self.status_updated.emit("AI Joiner: Polling... (Connection issue)")
-
- def add_result_item(self, video_path):
- item_widget = VideoResultItemWidget(video_path, self.plugin)
- list_item = QListWidgetItem(self.results_list)
- list_item.setSizeHint(item_widget.sizeHint())
- self.results_list.addItem(list_item)
- self.results_list.setItemWidget(list_item, item_widget)
-
- def clear_results(self):
- self.results_list.clear()
- self.processed_files.clear()
-
-class Plugin(VideoEditorPlugin):
- def initialize(self):
- self.name = "AI Frame Joiner"
- self.description = "Uses a local AI server to generate a video between two frames."
- self.client_widget = WgpClientWidget(self)
- self.dock_widget = None
- self.active_region = None
- self.temp_dir = None
- self.insert_on_new_track = False
- self.client_widget.status_updated.connect(self.update_main_status)
-
- def enable(self):
- if not self.dock_widget:
- self.dock_widget = self.app.add_dock_widget(self, self.client_widget, "AI Frame Joiner", show_on_creation=False)
-
- self.dock_widget.hide()
-
- self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu)
-
- def disable(self):
- try:
- self.app.timeline_widget.context_menu_requested.disconnect(self.on_timeline_context_menu)
- except TypeError:
- pass
- self._cleanup_temp_dir()
- self.client_widget.stop_polling()
-
- def update_main_status(self, message):
- self.app.status_label.setText(message)
-
- def _cleanup_temp_dir(self):
- if self.temp_dir and os.path.exists(self.temp_dir):
- shutil.rmtree(self.temp_dir)
- self.temp_dir = None
-
- def _reset_state(self):
- self.active_region = None
- self.insert_on_new_track = False
- self.client_widget.stop_polling()
- self.client_widget.clear_results()
- self._cleanup_temp_dir()
- self.update_main_status("AI Joiner: Idle")
- self.client_widget.set_previews(None, None)
- self.client_widget.model_type_to_use = None
-
- def on_timeline_context_menu(self, menu, event):
- region = self.app.timeline_widget.get_region_at_pos(event.pos())
- if region:
- menu.addSeparator()
-
- start_sec, end_sec = region
- start_data, _, _ = self.app.get_frame_data_at_time(start_sec)
- end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
-
- if start_data and end_data:
- join_action = menu.addAction("Join Frames With AI")
- join_action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False))
-
- join_action_new_track = menu.addAction("Join Frames With AI (New Track)")
- join_action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True))
-
- create_action = menu.addAction("Create Frames With AI")
- create_action.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=False))
-
- create_action_new_track = menu.addAction("Create Frames With AI (New Track)")
- create_action_new_track.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=True))
-
- def setup_generator_for_region(self, region, on_new_track=False):
- self._reset_state()
- self.client_widget.model_type_to_use = "i2v_2_2"
- self.active_region = region
- self.insert_on_new_track = on_new_track
- self.client_widget.previews_widget.setVisible(True)
-
- start_sec, end_sec = region
- if not self.client_widget.check_server_status():
- return
-
- start_data, w, h = self.app.get_frame_data_at_time(start_sec)
- end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
-
- if not start_data or not end_data:
- QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames for the selected region.")
- return
-
- preview_size = QSize(160, 90)
- start_pixmap, end_pixmap = None, None
-
- try:
- start_img = QImage(start_data, w, h, QImage.Format.Format_RGB888)
- start_pixmap = QPixmap.fromImage(start_img).scaled(preview_size, Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
-
- end_img = QImage(end_data, w, h, QImage.Format.Format_RGB888)
- end_pixmap = QPixmap.fromImage(end_img).scaled(preview_size, Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
- except Exception as e:
- QMessageBox.critical(self.app, "Image Error", f"Could not create preview images: {e}")
-
- self.client_widget.set_previews(start_pixmap, end_pixmap)
-
- try:
- self.temp_dir = tempfile.mkdtemp(prefix="ai_joiner_")
- start_img_path = os.path.join(self.temp_dir, "start_frame.png")
- end_img_path = os.path.join(self.temp_dir, "end_frame.png")
-
- QImage(start_data, w, h, QImage.Format.Format_RGB888).save(start_img_path)
- QImage(end_data, w, h, QImage.Format.Format_RGB888).save(end_img_path)
- except Exception as e:
- QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame images: {e}")
- self._cleanup_temp_dir()
- return
-
- duration_sec = end_sec - start_sec
- self.client_widget.duration_input.setText(str(duration_sec))
-
- self.client_widget.start_frame_input.setText(start_img_path)
- self.client_widget.end_frame_input.setText(end_img_path)
- self.update_main_status(f"AI Joiner: Ready for region {start_sec:.2f}s - {end_sec:.2f}s")
-
- self.dock_widget.show()
- self.dock_widget.raise_()
-
- def setup_creator_for_region(self, region, on_new_track=False):
- self._reset_state()
- self.client_widget.model_type_to_use = "t2v_2_2"
- self.active_region = region
- self.insert_on_new_track = on_new_track
- self.client_widget.previews_widget.setVisible(False)
-
- start_sec, end_sec = region
- if not self.client_widget.check_server_status():
- return
-
- self.client_widget.set_previews(None, None)
-
- duration_sec = end_sec - start_sec
- self.client_widget.duration_input.setText(str(duration_sec))
-
- self.client_widget.start_frame_input.clear()
- self.client_widget.end_frame_input.clear()
-
- self.update_main_status(f"AI Creator: Ready for region {start_sec:.2f}s - {end_sec:.2f}s")
-
- self.dock_widget.show()
- self.dock_widget.raise_()
-
- def insert_generated_clip(self, video_path):
- from main import TimelineClip
-
- if not self.active_region:
- self.update_main_status("AI Joiner Error: No active region to insert into.")
- return
-
- if not os.path.exists(video_path):
- self.update_main_status(f"AI Joiner Error: Output file not found: {video_path}")
- return
-
- start_sec, end_sec = self.active_region
- is_new_track_mode = self.insert_on_new_track
-
- self.update_main_status(f"AI Joiner: Inserting clip {os.path.basename(video_path)}")
-
- def complex_insertion_action():
- probe = ffmpeg.probe(video_path)
- actual_duration = float(probe['format']['duration'])
-
- if is_new_track_mode:
- self.app.timeline.num_video_tracks += 1
- new_track_index = self.app.timeline.num_video_tracks
-
- new_clip = TimelineClip(
- source_path=video_path,
- timeline_start_sec=start_sec,
- clip_start_sec=0,
- duration_sec=actual_duration,
- track_index=new_track_index,
- track_type='video',
- media_type='video',
- group_id=str(uuid.uuid4())
- )
- self.app.timeline.add_clip(new_clip)
- else:
- for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec)
- for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec)
-
- clips_to_remove = [
- c for c in self.app.timeline.clips
- if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec
- ]
- for clip in clips_to_remove:
- if clip in self.app.timeline.clips:
- self.app.timeline.clips.remove(clip)
-
- new_clip = TimelineClip(
- source_path=video_path,
- timeline_start_sec=start_sec,
- clip_start_sec=0,
- duration_sec=actual_duration,
- track_index=1,
- track_type='video',
- media_type='video',
- group_id=str(uuid.uuid4())
- )
- self.app.timeline.add_clip(new_clip)
-
- try:
- self.app._perform_complex_timeline_change("Insert AI Clip", complex_insertion_action)
-
- self.app.prune_empty_tracks()
- self.update_main_status("AI clip inserted successfully.")
-
- for i in range(self.client_widget.results_list.count()):
- item = self.client_widget.results_list.item(i)
- widget = self.client_widget.results_list.itemWidget(item)
- if widget and widget.video_path == video_path:
- self.client_widget.results_list.takeItem(i)
- break
-
- except Exception as e:
- error_message = f"AI Joiner Error during clip insertion/probing: {e}"
- self.update_main_status(error_message)
- print(error_message)
\ No newline at end of file
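The removed plugin's WanGP integration reduces to a poll-and-diff loop against the local HTTP API. Below is a minimal standalone sketch of that pattern: the `API_BASE_URL`, the `/api/outputs` route, the `outputs` JSON key and the 3-second interval are taken from the deleted code above, while `OutputPoller` and its callback are illustrative names only, not part of the repository.

```python
import requests
from PyQt6.QtCore import QTimer

API_BASE_URL = "http://127.0.0.1:5100"  # local WanGP endpoint used by the removed plugin


class OutputPoller:
    """Poll the WanGP server for newly generated clips and report each one once."""

    def __init__(self, on_new_file):
        self.on_new_file = on_new_file   # callback, e.g. add a result widget for the clip
        self.processed = set()
        self.timer = QTimer()
        self.timer.setInterval(3000)     # 3 s, matching the plugin's poll_timer
        self.timer.timeout.connect(self.poll)

    def start(self):
        self.timer.start()

    def stop(self):
        self.timer.stop()

    def poll(self):
        try:
            resp = requests.get(f"{API_BASE_URL}/api/outputs", timeout=2)
            if resp.status_code != 200:
                return
            # Report only files we have not seen before.
            for path in sorted(set(resp.json().get("outputs", [])) - self.processed):
                self.processed.add(path)
                self.on_new_file(path)
        except requests.exceptions.RequestException:
            pass                         # server offline or busy; try again on the next tick
```

As in the plugin, the timer only fires inside a running Qt event loop, so `start()` should be called after the application has been created.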
From 4f9c03bf4d1546bfc0837f43bdf2dd4e321931e6 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 21:19:09 +1100
Subject: [PATCH 119/155] update screenshots
---
README.md | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index 3726bd212..8f30a9927 100644
--- a/README.md
+++ b/README.md
@@ -42,10 +42,9 @@ python videoeditor.py
```
## Screenshots
-
-
-
-
+
+
+
## Credits
The AI Video Generator plugin is built from a desktop port of WAN2GP by DeepBeepMeep.
From 1622655827b1f904ccaf104807013a68def53739 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 21:27:19 +1100
Subject: [PATCH 120/155] update install instructions
---
README.md | 20 ++++++++++++++++++--
1 file changed, 18 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 8f30a9927..15764d5d3 100644
--- a/README.md
+++ b/README.md
@@ -24,16 +24,32 @@ A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. It provi
## Installation
+
+#### **Windows**
+
```bash
git clone https://github.com/Tophness/Wan2GP.git
cd Wan2GP
git checkout video_editor
-conda create -n wan2gp python=3.10.9
-conda activate wan2gp
+python -m venv venv
+venv\Scripts\activate
+pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
+pip install -r requirements.txt
+```
+
+#### **Linux / macOS**
+
+```bash
+git clone https://github.com/Tophness/Wan2GP.git
+cd Wan2GP
+git checkout video_editor
+python3 -m venv venv
+source venv/bin/activate
pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
pip install -r requirements.txt
```
+
## Usage
**Run the video editor:**
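Whichever platform block above is used, a quick sanity check confirms that the CUDA 12.8 test wheel of torch 2.7.0 actually landed in the venv. The snippet below is an illustrative addition, not part of the README diff:

```python
import torch

print(torch.__version__)           # expect a 2.7.0 build pulled from the cu128 test index
print(torch.cuda.is_available())   # True when the wheel matches an installed NVIDIA driver
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```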
From 4aded7a158fab544fc66f607247e3211d468c333 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 22:15:37 +1100
Subject: [PATCH 121/155] Update README.md
---
README.md | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index 15764d5d3..712fbfba4 100644
--- a/README.md
+++ b/README.md
@@ -59,10 +59,11 @@ python videoeditor.py
## Screenshots
-
-
+
+
+
## Credits
The AI Video Generator plugin is built from a desktop port of WAN2GP by DeepBeepMeep.
See WAN2GP for more details.
-https://github.com/deepbeepmeep/Wan2GP
\ No newline at end of file
+https://github.com/deepbeepmeep/Wan2GP
From 8c9286cbc4b80ae9a167b1317ff30e69a339d6f4 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 22:19:14 +1100
Subject: [PATCH 122/155] Update README.md
---
README.md | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 712fbfba4..1fc0c3473 100644
--- a/README.md
+++ b/README.md
@@ -59,8 +59,9 @@ python videoeditor.py
## Screenshots
-
-
+
+
+
## Credits
From f7d62b539a9d0aa6f69dfc81ab72dbc246064647 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 22:22:45 +1100
Subject: [PATCH 123/155] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 1fc0c3473..eff58e558 100644
--- a/README.md
+++ b/README.md
@@ -58,7 +58,7 @@ python videoeditor.py
```
## Screenshots
-
+
From 0953b56c0fc26a5aa737054139b517f2f591c568 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 23:31:21 +1100
Subject: [PATCH 124/155] Add back the 'create to new track' option
---
plugins/wan2gp/main.py | 3417 +++++++++++++++++++++-------------------
1 file changed, 1764 insertions(+), 1653 deletions(-)
diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py
index ca201dcda..0b925227c 100644
--- a/plugins/wan2gp/main.py
+++ b/plugins/wan2gp/main.py
@@ -1,1653 +1,1764 @@
-import sys
-import os
-import threading
-import time
-import json
-import tempfile
-import shutil
-import uuid
-from unittest.mock import MagicMock
-from pathlib import Path
-
-# --- Start of Gradio Hijacking ---
-# This block creates a mock Gradio module. When wgp.py is imported,
-# all calls to `gr.*` will be intercepted by these mock objects,
-# preventing any UI from being built and allowing us to use the
-# backend logic directly.
-
-class MockGradioComponent(MagicMock):
- """A smarter mock that captures constructor arguments."""
- def __init__(self, *args, **kwargs):
- super().__init__(name=f"gr.{kwargs.get('elem_id', 'component')}")
- self.kwargs = kwargs
- self.value = kwargs.get('value')
- self.choices = kwargs.get('choices')
-
- for method in ['then', 'change', 'click', 'input', 'select', 'upload', 'mount', 'launch', 'on', 'release']:
- setattr(self, method, lambda *a, **kw: self)
-
- def __enter__(self):
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- pass
-
-class MockGradioError(Exception):
- pass
-
-class MockGradioModule:
- def __getattr__(self, name):
- if name == 'Error':
- return lambda *args, **kwargs: MockGradioError(*args)
-
- if name in ['Info', 'Warning']:
- return lambda *args, **kwargs: print(f"Intercepted gr.{name}:", *args)
-
- return lambda *args, **kwargs: MockGradioComponent(*args, **kwargs)
-
-sys.modules['gradio'] = MockGradioModule()
-sys.modules['gradio.gallery'] = MockGradioModule()
-sys.modules['shared.gradio.gallery'] = MockGradioModule()
-# --- End of Gradio Hijacking ---
-
-import ffmpeg
-
-from PyQt6.QtWidgets import (
- QWidget, QVBoxLayout, QHBoxLayout, QTabWidget,
- QPushButton, QLabel, QLineEdit, QTextEdit, QSlider, QCheckBox, QComboBox,
- QFileDialog, QGroupBox, QFormLayout, QTableWidget, QTableWidgetItem,
- QHeaderView, QProgressBar, QScrollArea, QListWidget, QListWidgetItem,
- QMessageBox, QRadioButton, QSizePolicy
-)
-from PyQt6.QtCore import Qt, QThread, QObject, pyqtSignal, QUrl, QSize, QRectF
-from PyQt6.QtGui import QPixmap, QImage, QDropEvent
-from PyQt6.QtMultimedia import QMediaPlayer
-from PyQt6.QtMultimediaWidgets import QVideoWidget
-from PIL.ImageQt import ImageQt
-
-# Import the base plugin class from the main application's path
-sys.path.append(str(Path(__file__).parent.parent.parent))
-from plugins import VideoEditorPlugin
-
-class VideoResultItemWidget(QWidget):
- """A widget to display a generated video with a hover-to-play preview and insert button."""
- def __init__(self, video_path, plugin, parent=None):
- super().__init__(parent)
- self.video_path = video_path
- self.plugin = plugin
- self.app = plugin.app
- self.duration = 0.0
- self.has_audio = False
-
- self.setMinimumSize(200, 180)
- self.setMaximumHeight(190)
-
- layout = QVBoxLayout(self)
- layout.setContentsMargins(5, 5, 5, 5)
- layout.setSpacing(5)
-
- self.media_player = QMediaPlayer()
- self.video_widget = QVideoWidget()
- self.video_widget.setFixedSize(160, 90)
- self.media_player.setVideoOutput(self.video_widget)
- self.media_player.setSource(QUrl.fromLocalFile(self.video_path))
- self.media_player.setLoops(QMediaPlayer.Loops.Infinite)
-
- self.info_label = QLabel(os.path.basename(video_path))
- self.info_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
- self.info_label.setWordWrap(True)
-
- self.insert_button = QPushButton("Insert into Timeline")
- self.insert_button.clicked.connect(self.on_insert)
-
- h_layout = QHBoxLayout()
- h_layout.addStretch()
- h_layout.addWidget(self.video_widget)
- h_layout.addStretch()
-
- layout.addLayout(h_layout)
- layout.addWidget(self.info_label)
- layout.addWidget(self.insert_button)
- self.probe_video()
-
- def probe_video(self):
- try:
- probe = ffmpeg.probe(self.video_path)
- self.duration = float(probe['format']['duration'])
- self.has_audio = any(s['codec_type'] == 'audio' for s in probe.get('streams', []))
- self.info_label.setText(f"{os.path.basename(self.video_path)}\n({self.duration:.2f}s)")
- except Exception as e:
- self.info_label.setText(f"Error probing:\n{os.path.basename(self.video_path)}")
- print(f"Error probing video {self.video_path}: {e}")
-
- def enterEvent(self, event):
- super().enterEvent(event)
- self.media_player.play()
- if not self.plugin.active_region or self.duration == 0: return
- start_sec, _ = self.plugin.active_region
- timeline = self.app.timeline_widget
- video_rect, audio_rect = None, None
- x = timeline.sec_to_x(start_sec)
- w = int(self.duration * timeline.pixels_per_second)
- if self.plugin.insert_on_new_track:
- video_y = timeline.TIMESCALE_HEIGHT
- video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT)
- if self.has_audio:
- audio_y = timeline.audio_tracks_y_start + self.app.timeline.num_audio_tracks * timeline.TRACK_HEIGHT
- audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT)
- else:
- v_track_idx = 1
- visual_v_idx = self.app.timeline.num_video_tracks - v_track_idx
- video_y = timeline.video_tracks_y_start + visual_v_idx * timeline.TRACK_HEIGHT
- video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT)
- if self.has_audio:
- a_track_idx = 1
- visual_a_idx = a_track_idx - 1
- audio_y = timeline.audio_tracks_y_start + visual_a_idx * timeline.TRACK_HEIGHT
- audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT)
- timeline.set_hover_preview_rects(video_rect, audio_rect)
-
- def leaveEvent(self, event):
- super().leaveEvent(event)
- self.media_player.pause()
- self.media_player.setPosition(0)
- self.app.timeline_widget.set_hover_preview_rects(None, None)
-
- def on_insert(self):
- self.media_player.stop()
- self.media_player.setSource(QUrl())
- self.media_player.setVideoOutput(None)
- self.app.timeline_widget.set_hover_preview_rects(None, None)
- self.plugin.insert_generated_clip(self.video_path)
-
-
-class QueueTableWidget(QTableWidget):
- rowsMoved = pyqtSignal(int, int)
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.setDragEnabled(True)
- self.setAcceptDrops(True)
- self.setDropIndicatorShown(True)
- self.setDragDropMode(self.DragDropMode.InternalMove)
- self.setSelectionBehavior(self.SelectionBehavior.SelectRows)
- self.setSelectionMode(self.SelectionMode.SingleSelection)
-
- def dropEvent(self, event: QDropEvent):
- if event.source() == self and event.dropAction() == Qt.DropAction.MoveAction:
- source_row = self.currentRow()
- target_item = self.itemAt(event.position().toPoint())
- dest_row = target_item.row() if target_item else self.rowCount()
- if source_row < dest_row: dest_row -=1
- if source_row != dest_row: self.rowsMoved.emit(source_row, dest_row)
- event.acceptProposedAction()
- else:
- super().dropEvent(event)
-
-
-class Worker(QObject):
- progress = pyqtSignal(list)
- status = pyqtSignal(str)
- preview = pyqtSignal(object)
- output = pyqtSignal()
- finished = pyqtSignal()
- error = pyqtSignal(str)
- def __init__(self, plugin, state):
- super().__init__()
- self.plugin = plugin
- self.wgp = plugin.wgp
- self.state = state
- self._is_running = True
- self._last_progress_phase = None
- self._last_preview = None
-
- def run(self):
- def generation_target():
- try:
- for _ in self.wgp.process_tasks(self.state):
- if self._is_running: self.output.emit()
- else: break
- except Exception as e:
- import traceback
- print("Error in generation thread:")
- traceback.print_exc()
- if "gradio.Error" in str(type(e)): self.error.emit(str(e))
- else: self.error.emit(f"An unexpected error occurred: {e}")
- finally:
- self._is_running = False
- gen_thread = threading.Thread(target=generation_target, daemon=True)
- gen_thread.start()
- while self._is_running:
- gen = self.state.get('gen', {})
- current_phase = gen.get("progress_phase")
- if current_phase and current_phase != self._last_progress_phase:
- self._last_progress_phase = current_phase
- phase_name, step = current_phase
- total_steps = gen.get("num_inference_steps", 1)
- high_level_status = gen.get("progress_status", "")
- status_msg = self.wgp.merge_status_context(high_level_status, phase_name)
- progress_args = [(step, total_steps), status_msg]
- self.progress.emit(progress_args)
- preview_img = gen.get('preview')
- if preview_img is not None and preview_img is not self._last_preview:
- self._last_preview = preview_img
- self.preview.emit(preview_img)
- gen['preview'] = None
- time.sleep(0.1)
- gen_thread.join()
- self.finished.emit()
-
-
-class WgpDesktopPluginWidget(QWidget):
- def __init__(self, plugin):
- super().__init__()
- self.plugin = plugin
- self.wgp = plugin.wgp
- self.widgets = {}
- self.state = {}
- self.worker = None
- self.thread = None
- self.lora_map = {}
- self.full_resolution_choices = []
- self.main_config = {}
- self.processed_files = set()
-
- self.load_main_config()
- self.setup_ui()
- self.apply_initial_config()
- self.connect_signals()
- self.init_wgp_state()
-
- def setup_ui(self):
- main_layout = QVBoxLayout(self)
- self.header_info = QLabel("Header Info")
- main_layout.addWidget(self.header_info)
- self.tabs = QTabWidget()
- main_layout.addWidget(self.tabs)
- self.setup_generator_tab()
- self.setup_config_tab()
-
- def create_widget(self, widget_class, name, *args, **kwargs):
- widget = widget_class(*args, **kwargs)
- self.widgets[name] = widget
- return widget
-
- def _create_slider_with_label(self, name, min_val, max_val, initial_val, scale=1.0, precision=1):
- container = QWidget()
- hbox = QHBoxLayout(container)
- hbox.setContentsMargins(0, 0, 0, 0)
- slider = self.create_widget(QSlider, name, Qt.Orientation.Horizontal)
- slider.setRange(min_val, max_val)
- slider.setValue(int(initial_val * scale))
- value_label = self.create_widget(QLabel, f"{name}_label", f"{initial_val:.{precision}f}")
- value_label.setMinimumWidth(50)
- slider.valueChanged.connect(lambda v, lbl=value_label, s=scale, p=precision: lbl.setText(f"{v/s:.{p}f}"))
- hbox.addWidget(slider)
- hbox.addWidget(value_label)
- return container
-
- def _create_file_input(self, name, label_text):
- container = self.create_widget(QWidget, f"{name}_container")
- hbox = QHBoxLayout(container)
- hbox.setContentsMargins(0, 0, 0, 0)
- line_edit = self.create_widget(QLineEdit, name)
- line_edit.setPlaceholderText("No file selected or path pasted")
- button = QPushButton("Browse...")
- def open_dialog():
- if "refs" in name:
- filenames, _ = QFileDialog.getOpenFileNames(self, f"Select {label_text}")
- if filenames: line_edit.setText(";".join(filenames))
- else:
- filename, _ = QFileDialog.getOpenFileName(self, f"Select {label_text}")
- if filename: line_edit.setText(filename)
- button.clicked.connect(open_dialog)
- clear_button = QPushButton("X")
- clear_button.setFixedWidth(30)
- clear_button.clicked.connect(lambda: line_edit.clear())
- hbox.addWidget(QLabel(f"{label_text}:"))
- hbox.addWidget(line_edit, 1)
- hbox.addWidget(button)
- hbox.addWidget(clear_button)
- return container
-
- def setup_generator_tab(self):
- gen_tab = QWidget()
- self.tabs.addTab(gen_tab, "Video Generator")
- gen_layout = QHBoxLayout(gen_tab)
- left_panel = QWidget()
- left_layout = QVBoxLayout(left_panel)
- gen_layout.addWidget(left_panel, 1)
- right_panel = QWidget()
- right_layout = QVBoxLayout(right_panel)
- gen_layout.addWidget(right_panel, 1)
- scroll_area = QScrollArea()
- scroll_area.setWidgetResizable(True)
- left_layout.addWidget(scroll_area)
- options_widget = QWidget()
- scroll_area.setWidget(options_widget)
- options_layout = QVBoxLayout(options_widget)
- model_layout = QHBoxLayout()
- self.widgets['model_family'] = QComboBox()
- self.widgets['model_base_type_choice'] = QComboBox()
- self.widgets['model_choice'] = QComboBox()
- model_layout.addWidget(QLabel("Model:"))
- model_layout.addWidget(self.widgets['model_family'], 2)
- model_layout.addWidget(self.widgets['model_base_type_choice'], 3)
- model_layout.addWidget(self.widgets['model_choice'], 3)
- options_layout.addLayout(model_layout)
- options_layout.addWidget(QLabel("Prompt:"))
- self.create_widget(QTextEdit, 'prompt').setMinimumHeight(100)
- options_layout.addWidget(self.widgets['prompt'])
- options_layout.addWidget(QLabel("Negative Prompt:"))
- self.create_widget(QTextEdit, 'negative_prompt').setMinimumHeight(60)
- options_layout.addWidget(self.widgets['negative_prompt'])
- basic_group = QGroupBox("Basic Options")
- basic_layout = QFormLayout(basic_group)
- res_container = QWidget()
- res_hbox = QHBoxLayout(res_container)
- res_hbox.setContentsMargins(0, 0, 0, 0)
- res_hbox.addWidget(self.create_widget(QComboBox, 'resolution_group'), 2)
- res_hbox.addWidget(self.create_widget(QComboBox, 'resolution'), 3)
- basic_layout.addRow("Resolution:", res_container)
- basic_layout.addRow("Video Length:", self._create_slider_with_label('video_length', 1, 737, 81, 1.0, 0))
- basic_layout.addRow("Inference Steps:", self._create_slider_with_label('num_inference_steps', 1, 100, 30, 1.0, 0))
- basic_layout.addRow("Seed:", self.create_widget(QLineEdit, 'seed', '-1'))
- options_layout.addWidget(basic_group)
- mode_options_group = QGroupBox("Generation Mode & Input Options")
- mode_options_layout = QVBoxLayout(mode_options_group)
- mode_hbox = QHBoxLayout()
- mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_t', "Text Prompt Only"))
- mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_s', "Start with Image"))
- mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_v', "Continue Video"))
- mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_l', "Continue Last Video"))
- self.widgets['mode_t'].setChecked(True)
- mode_options_layout.addLayout(mode_hbox)
- options_hbox = QHBoxLayout()
- options_hbox.addWidget(self.create_widget(QCheckBox, 'image_end_checkbox', "Use End Image"))
- options_hbox.addWidget(self.create_widget(QCheckBox, 'control_video_checkbox', "Use Control Video"))
- options_hbox.addWidget(self.create_widget(QCheckBox, 'ref_image_checkbox', "Use Reference Image(s)"))
- mode_options_layout.addLayout(options_hbox)
- options_layout.addWidget(mode_options_group)
- inputs_group = QGroupBox("Inputs")
- inputs_layout = QVBoxLayout(inputs_group)
- inputs_layout.addWidget(self._create_file_input('image_start', "Start Image"))
- inputs_layout.addWidget(self._create_file_input('image_end', "End Image"))
- inputs_layout.addWidget(self._create_file_input('video_source', "Source Video"))
- inputs_layout.addWidget(self._create_file_input('video_guide', "Control Video"))
- inputs_layout.addWidget(self._create_file_input('video_mask', "Video Mask"))
- inputs_layout.addWidget(self._create_file_input('image_refs', "Reference Image(s)"))
- denoising_row = QFormLayout()
- denoising_row.addRow("Denoising Strength:", self._create_slider_with_label('denoising_strength', 0, 100, 50, 100.0, 2))
- inputs_layout.addLayout(denoising_row)
- options_layout.addWidget(inputs_group)
- self.advanced_group = self.create_widget(QGroupBox, 'advanced_group', "Advanced Options")
- self.advanced_group.setCheckable(True)
- self.advanced_group.setChecked(False)
- advanced_layout = QVBoxLayout(self.advanced_group)
- advanced_tabs = self.create_widget(QTabWidget, 'advanced_tabs')
- advanced_layout.addWidget(advanced_tabs)
- self._setup_adv_tab_general(advanced_tabs)
- self._setup_adv_tab_loras(advanced_tabs)
- self._setup_adv_tab_speed(advanced_tabs)
- self._setup_adv_tab_postproc(advanced_tabs)
- self._setup_adv_tab_audio(advanced_tabs)
- self._setup_adv_tab_quality(advanced_tabs)
- self._setup_adv_tab_sliding_window(advanced_tabs)
- self._setup_adv_tab_misc(advanced_tabs)
- options_layout.addWidget(self.advanced_group)
-
- btn_layout = QHBoxLayout()
- self.generate_btn = self.create_widget(QPushButton, 'generate_btn', "Generate")
- self.add_to_queue_btn = self.create_widget(QPushButton, 'add_to_queue_btn', "Add to Queue")
- self.generate_btn.setEnabled(True)
- self.add_to_queue_btn.setEnabled(False)
- btn_layout.addWidget(self.generate_btn)
- btn_layout.addWidget(self.add_to_queue_btn)
- right_layout.addLayout(btn_layout)
- self.status_label = self.create_widget(QLabel, 'status_label', "Idle")
- right_layout.addWidget(self.status_label)
- self.progress_bar = self.create_widget(QProgressBar, 'progress_bar')
- right_layout.addWidget(self.progress_bar)
- preview_group = self.create_widget(QGroupBox, 'preview_group', "Preview")
- preview_group.setCheckable(True)
- preview_group.setStyleSheet("QGroupBox { border: 1px solid #cccccc; }")
- preview_group_layout = QVBoxLayout(preview_group)
- self.preview_image = self.create_widget(QLabel, 'preview_image', "")
- self.preview_image.setAlignment(Qt.AlignmentFlag.AlignCenter)
- self.preview_image.setMinimumSize(200, 200)
- preview_group_layout.addWidget(self.preview_image)
- right_layout.addWidget(preview_group)
-
- results_group = QGroupBox("Generated Clips")
- results_group.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
- results_layout = QVBoxLayout(results_group)
- self.results_list = self.create_widget(QListWidget, 'results_list')
- self.results_list.setFlow(QListWidget.Flow.LeftToRight)
- self.results_list.setWrapping(True)
- self.results_list.setResizeMode(QListWidget.ResizeMode.Adjust)
- self.results_list.setSpacing(10)
- results_layout.addWidget(self.results_list)
- right_layout.addWidget(results_group)
-
- right_layout.addWidget(QLabel("Queue:"))
- self.queue_table = self.create_widget(QueueTableWidget, 'queue_table')
- right_layout.addWidget(self.queue_table)
- queue_btn_layout = QHBoxLayout()
- self.remove_queue_btn = self.create_widget(QPushButton, 'remove_queue_btn', "Remove Selected")
- self.clear_queue_btn = self.create_widget(QPushButton, 'clear_queue_btn', "Clear Queue")
- self.abort_btn = self.create_widget(QPushButton, 'abort_btn', "Abort")
- queue_btn_layout.addWidget(self.remove_queue_btn)
- queue_btn_layout.addWidget(self.clear_queue_btn)
- queue_btn_layout.addWidget(self.abort_btn)
- right_layout.addLayout(queue_btn_layout)
-
- def _setup_adv_tab_general(self, tabs):
- tab = QWidget()
- tabs.addTab(tab, "General")
- layout = QFormLayout(tab)
- self.widgets['adv_general_layout'] = layout
- guidance_group = QGroupBox("Guidance")
- guidance_layout = self.create_widget(QFormLayout, 'guidance_layout', guidance_group)
- guidance_layout.addRow("Guidance (CFG):", self._create_slider_with_label('guidance_scale', 10, 200, 5.0, 10.0, 1))
- self.widgets['guidance_phases_row_index'] = guidance_layout.rowCount()
- guidance_layout.addRow("Guidance Phases:", self.create_widget(QComboBox, 'guidance_phases'))
- self.widgets['guidance2_row_index'] = guidance_layout.rowCount()
- guidance_layout.addRow("Guidance 2:", self._create_slider_with_label('guidance2_scale', 10, 200, 5.0, 10.0, 1))
- self.widgets['guidance3_row_index'] = guidance_layout.rowCount()
- guidance_layout.addRow("Guidance 3:", self._create_slider_with_label('guidance3_scale', 10, 200, 5.0, 10.0, 1))
- self.widgets['switch_thresh_row_index'] = guidance_layout.rowCount()
- guidance_layout.addRow("Switch Threshold:", self._create_slider_with_label('switch_threshold', 0, 1000, 0, 1.0, 0))
- layout.addRow(guidance_group)
- nag_group = self.create_widget(QGroupBox, 'nag_group', "NAG (Negative Adversarial Guidance)")
- nag_layout = QFormLayout(nag_group)
- nag_layout.addRow("NAG Scale:", self._create_slider_with_label('NAG_scale', 10, 200, 1.0, 10.0, 1))
- nag_layout.addRow("NAG Tau:", self._create_slider_with_label('NAG_tau', 10, 50, 3.5, 10.0, 1))
- nag_layout.addRow("NAG Alpha:", self._create_slider_with_label('NAG_alpha', 0, 20, 0.5, 10.0, 1))
- layout.addRow(nag_group)
- self.widgets['solver_row_container'] = QWidget()
- solver_hbox = QHBoxLayout(self.widgets['solver_row_container'])
- solver_hbox.setContentsMargins(0,0,0,0)
- solver_hbox.addWidget(QLabel("Sampler Solver:"))
- solver_hbox.addWidget(self.create_widget(QComboBox, 'sample_solver'))
- layout.addRow(self.widgets['solver_row_container'])
- self.widgets['flow_shift_row_index'] = layout.rowCount()
- layout.addRow("Shift Scale:", self._create_slider_with_label('flow_shift', 10, 250, 3.0, 10.0, 1))
- self.widgets['audio_guidance_row_index'] = layout.rowCount()
- layout.addRow("Audio Guidance:", self._create_slider_with_label('audio_guidance_scale', 10, 200, 4.0, 10.0, 1))
- self.widgets['repeat_generation_row_index'] = layout.rowCount()
- layout.addRow("Repeat Generations:", self._create_slider_with_label('repeat_generation', 1, 25, 1, 1.0, 0))
- combo = self.create_widget(QComboBox, 'multi_images_gen_type')
- combo.addItem("Generate all combinations", 0)
- combo.addItem("Match images and texts", 1)
- self.widgets['multi_images_gen_type_row_index'] = layout.rowCount()
- layout.addRow("Multi-Image Mode:", combo)
-
- def _setup_adv_tab_loras(self, tabs):
- tab = QWidget()
- tabs.addTab(tab, "Loras")
- layout = QVBoxLayout(tab)
- layout.addWidget(QLabel("Available Loras (Ctrl+Click to select multiple):"))
- lora_list = self.create_widget(QListWidget, 'activated_loras')
- lora_list.setSelectionMode(QListWidget.SelectionMode.MultiSelection)
- layout.addWidget(lora_list)
- layout.addWidget(QLabel("Loras Multipliers:"))
- layout.addWidget(self.create_widget(QTextEdit, 'loras_multipliers'))
-
- def _setup_adv_tab_speed(self, tabs):
- tab = QWidget()
- tabs.addTab(tab, "Speed")
- layout = QFormLayout(tab)
- combo = self.create_widget(QComboBox, 'skip_steps_cache_type')
- combo.addItem("None", "")
- combo.addItem("Tea Cache", "tea")
- combo.addItem("Mag Cache", "mag")
- layout.addRow("Cache Type:", combo)
- combo = self.create_widget(QComboBox, 'skip_steps_multiplier')
- combo.addItem("x1.5 speed up", 1.5)
- combo.addItem("x1.75 speed up", 1.75)
- combo.addItem("x2.0 speed up", 2.0)
- combo.addItem("x2.25 speed up", 2.25)
- combo.addItem("x2.5 speed up", 2.5)
- layout.addRow("Acceleration:", combo)
- layout.addRow("Start %:", self._create_slider_with_label('skip_steps_start_step_perc', 0, 100, 0, 1.0, 0))
-
- def _setup_adv_tab_postproc(self, tabs):
- tab = QWidget()
- tabs.addTab(tab, "Post-Processing")
- layout = QFormLayout(tab)
- combo = self.create_widget(QComboBox, 'temporal_upsampling')
- combo.addItem("Disabled", "")
- combo.addItem("Rife x2 frames/s", "rife2")
- combo.addItem("Rife x4 frames/s", "rife4")
- layout.addRow("Temporal Upsampling:", combo)
- combo = self.create_widget(QComboBox, 'spatial_upsampling')
- combo.addItem("Disabled", "")
- combo.addItem("Lanczos x1.5", "lanczos1.5")
- combo.addItem("Lanczos x2.0", "lanczos2")
- layout.addRow("Spatial Upsampling:", combo)
- layout.addRow("Film Grain Intensity:", self._create_slider_with_label('film_grain_intensity', 0, 100, 0, 100.0, 2))
- layout.addRow("Film Grain Saturation:", self._create_slider_with_label('film_grain_saturation', 0, 100, 0.5, 100.0, 2))
-
- def _setup_adv_tab_audio(self, tabs):
- tab = QWidget()
- tabs.addTab(tab, "Audio")
- layout = QFormLayout(tab)
- combo = self.create_widget(QComboBox, 'MMAudio_setting')
- combo.addItem("Disabled", 0)
- combo.addItem("Enabled", 1)
- layout.addRow("MMAudio:", combo)
- layout.addWidget(self.create_widget(QLineEdit, 'MMAudio_prompt', placeholderText="MMAudio Prompt"))
- layout.addWidget(self.create_widget(QLineEdit, 'MMAudio_neg_prompt', placeholderText="MMAudio Negative Prompt"))
- layout.addRow(self._create_file_input('audio_source', "Custom Soundtrack"))
-
- def _setup_adv_tab_quality(self, tabs):
- tab = QWidget()
- tabs.addTab(tab, "Quality")
- layout = QVBoxLayout(tab)
- slg_group = self.create_widget(QGroupBox, 'slg_group', "Skip Layer Guidance")
- slg_layout = QFormLayout(slg_group)
- slg_combo = self.create_widget(QComboBox, 'slg_switch')
- slg_combo.addItem("OFF", 0)
- slg_combo.addItem("ON", 1)
- slg_layout.addRow("Enable SLG:", slg_combo)
- slg_layout.addRow("Start %:", self._create_slider_with_label('slg_start_perc', 0, 100, 10, 1.0, 0))
- slg_layout.addRow("End %:", self._create_slider_with_label('slg_end_perc', 0, 100, 90, 1.0, 0))
- layout.addWidget(slg_group)
- quality_form = QFormLayout()
- self.widgets['quality_form_layout'] = quality_form
- apg_combo = self.create_widget(QComboBox, 'apg_switch')
- apg_combo.addItem("OFF", 0)
- apg_combo.addItem("ON", 1)
- self.widgets['apg_switch_row_index'] = quality_form.rowCount()
- quality_form.addRow("Adaptive Projected Guidance:", apg_combo)
- cfg_star_combo = self.create_widget(QComboBox, 'cfg_star_switch')
- cfg_star_combo.addItem("OFF", 0)
- cfg_star_combo.addItem("ON", 1)
- self.widgets['cfg_star_switch_row_index'] = quality_form.rowCount()
- quality_form.addRow("Classifier-Free Guidance Star:", cfg_star_combo)
- self.widgets['cfg_zero_step_row_index'] = quality_form.rowCount()
- quality_form.addRow("CFG Zero below Layer:", self._create_slider_with_label('cfg_zero_step', -1, 39, -1, 1.0, 0))
- combo = self.create_widget(QComboBox, 'min_frames_if_references')
- combo.addItem("Disabled (1 frame)", 1)
- combo.addItem("Generate 5 frames", 5)
- combo.addItem("Generate 9 frames", 9)
- combo.addItem("Generate 13 frames", 13)
- combo.addItem("Generate 17 frames", 17)
- self.widgets['min_frames_if_references_row_index'] = quality_form.rowCount()
- quality_form.addRow("Min Frames for Quality:", combo)
- layout.addLayout(quality_form)
-
- def _setup_adv_tab_sliding_window(self, tabs):
- tab = QWidget()
- self.widgets['sliding_window_tab_index'] = tabs.count()
- tabs.addTab(tab, "Sliding Window")
- layout = QFormLayout(tab)
- layout.addRow("Window Size:", self._create_slider_with_label('sliding_window_size', 5, 257, 129, 1.0, 0))
- layout.addRow("Overlap:", self._create_slider_with_label('sliding_window_overlap', 1, 97, 5, 1.0, 0))
- layout.addRow("Color Correction:", self._create_slider_with_label('sliding_window_color_correction_strength', 0, 100, 0, 100.0, 2))
- layout.addRow("Overlap Noise:", self._create_slider_with_label('sliding_window_overlap_noise', 0, 150, 20, 1.0, 0))
- layout.addRow("Discard Last Frames:", self._create_slider_with_label('sliding_window_discard_last_frames', 0, 20, 0, 1.0, 0))
-
- def _setup_adv_tab_misc(self, tabs):
- tab = QWidget()
- tabs.addTab(tab, "Misc")
- layout = QFormLayout(tab)
- self.widgets['misc_layout'] = layout
- riflex_combo = self.create_widget(QComboBox, 'RIFLEx_setting')
- riflex_combo.addItem("Auto", 0)
- riflex_combo.addItem("Always ON", 1)
- riflex_combo.addItem("Always OFF", 2)
- self.widgets['riflex_row_index'] = layout.rowCount()
- layout.addRow("RIFLEx Setting:", riflex_combo)
- fps_combo = self.create_widget(QComboBox, 'force_fps')
- layout.addRow("Force FPS:", fps_combo)
- profile_combo = self.create_widget(QComboBox, 'override_profile')
- profile_combo.addItem("Default Profile", -1)
- for text, val in self.wgp.memory_profile_choices: profile_combo.addItem(text.split(':')[0], val)
- layout.addRow("Override Memory Profile:", profile_combo)
- combo = self.create_widget(QComboBox, 'multi_prompts_gen_type')
- combo.addItem("Generate new Video per line", 0)
- combo.addItem("Use line for new Sliding Window", 1)
- layout.addRow("Multi-Prompt Mode:", combo)
-
- def setup_config_tab(self):
- config_tab = QWidget()
- self.tabs.addTab(config_tab, "Configuration")
- main_layout = QVBoxLayout(config_tab)
- self.config_status_label = QLabel("Apply changes for them to take effect. Some may require a restart.")
- main_layout.addWidget(self.config_status_label)
- config_tabs = QTabWidget()
- main_layout.addWidget(config_tabs)
- config_tabs.addTab(self._create_general_config_tab(), "General")
- config_tabs.addTab(self._create_performance_config_tab(), "Performance")
- config_tabs.addTab(self._create_extensions_config_tab(), "Extensions")
- config_tabs.addTab(self._create_outputs_config_tab(), "Outputs")
- config_tabs.addTab(self._create_notifications_config_tab(), "Notifications")
- self.apply_config_btn = QPushButton("Apply Changes")
- self.apply_config_btn.clicked.connect(self._on_apply_config_changes)
- main_layout.addWidget(self.apply_config_btn)
-
- def _create_scrollable_form_tab(self):
- tab_widget = QWidget()
- scroll_area = QScrollArea()
- scroll_area.setWidgetResizable(True)
- layout = QVBoxLayout(tab_widget)
- layout.addWidget(scroll_area)
- content_widget = QWidget()
- form_layout = QFormLayout(content_widget)
- scroll_area.setWidget(content_widget)
- return tab_widget, form_layout
-
- def _create_config_combo(self, form_layout, label, key, choices, default_value):
- combo = QComboBox()
- for text, data in choices: combo.addItem(text, data)
- index = combo.findData(self.wgp.server_config.get(key, default_value))
- if index != -1: combo.setCurrentIndex(index)
- self.widgets[f'config_{key}'] = combo
- form_layout.addRow(label, combo)
-
- def _create_config_slider(self, form_layout, label, key, min_val, max_val, default_value, step=1):
- container = QWidget()
- hbox = QHBoxLayout(container)
- hbox.setContentsMargins(0,0,0,0)
- slider = QSlider(Qt.Orientation.Horizontal)
- slider.setRange(min_val, max_val)
- slider.setSingleStep(step)
- slider.setValue(self.wgp.server_config.get(key, default_value))
- value_label = QLabel(str(slider.value()))
- value_label.setMinimumWidth(40)
- slider.valueChanged.connect(lambda v, lbl=value_label: lbl.setText(str(v)))
- hbox.addWidget(slider)
- hbox.addWidget(value_label)
- self.widgets[f'config_{key}'] = slider
- form_layout.addRow(label, container)
-
- def _create_config_checklist(self, form_layout, label, key, choices, default_value):
- list_widget = QListWidget()
- list_widget.setMinimumHeight(100)
- current_values = self.wgp.server_config.get(key, default_value)
- for text, data in choices:
- item = QListWidgetItem(text)
- item.setData(Qt.ItemDataRole.UserRole, data)
- item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable)
- item.setCheckState(Qt.CheckState.Checked if data in current_values else Qt.CheckState.Unchecked)
- list_widget.addItem(item)
- self.widgets[f'config_{key}'] = list_widget
- form_layout.addRow(label, list_widget)
-
- def _create_config_textbox(self, form_layout, label, key, default_value, multi_line=False):
- if multi_line:
- textbox = QTextEdit(default_value)
- textbox.setAcceptRichText(False)
- else:
- textbox = QLineEdit(default_value)
- self.widgets[f'config_{key}'] = textbox
- form_layout.addRow(label, textbox)
-
- def _create_general_config_tab(self):
- tab, form = self._create_scrollable_form_tab()
- _, _, dropdown_choices = self.wgp.get_sorted_dropdown(self.wgp.displayed_model_types, None, None, False)
- self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, self.wgp.transformer_types)
- self._create_config_combo(form, "Model Hierarchy:", "model_hierarchy_type", [("Two Levels (Family > Model)", 0), ("Three Levels (Family > Base > Finetune)", 1)], 1)
- self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1), ("Dimensions are Output Width/Height (Cropped)", 2)], 0)
- self._create_config_combo(form, "Attention Type:", "attention_mode", [("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"), ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto")
- self._create_config_combo(form, "Metadata Handling:", "metadata_type", [("Embed in file (Exif/Comment)", "metadata"), ("Export separate JSON", "json"), ("None", "none")], "metadata")
- self._create_config_checklist(form, "RAM Loading Policy:", "preload_model_policy", [("Preload on App Launch", "P"), ("Preload on Model Switch", "S"), ("Unload when Queue is Done", "U")], [])
- self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10), ("Keep last 20", 20), ("Keep last 30", 30)], 5)
- self._create_config_combo(form, "Display RAM/VRAM Stats:", "display_stats", [("Disabled", 0), ("Enabled", 1)], 0)
- self._create_config_combo(form, "Max Frames Multiplier:", "max_frames_multiplier", [(f"x{i}", i) for i in range(1, 8)], 1)
- checkpoints_paths_text = "\n".join(self.wgp.server_config.get("checkpoints_paths", self.wgp.fl.default_checkpoints_paths))
- checkpoints_textbox = QTextEdit()
- checkpoints_textbox.setPlainText(checkpoints_paths_text)
- checkpoints_textbox.setAcceptRichText(False)
- checkpoints_textbox.setMinimumHeight(60)
- self.widgets['config_checkpoints_paths'] = checkpoints_textbox
- form.addRow("Checkpoints Paths:", checkpoints_textbox)
- self._create_config_combo(form, "UI Theme (requires restart):", "UI_theme", [("Blue Sky", "default"), ("Classic Gradio", "gradio")], "default")
- return tab
-
- def _create_performance_config_tab(self):
- tab, form = self._create_scrollable_form_tab()
- self._create_config_combo(form, "Transformer Quantization:", "transformer_quantization", [("Scaled Int8 (recommended)", "int8"), ("16-bit (no quantization)", "bf16")], "int8")
- self._create_config_combo(form, "Transformer Data Type:", "transformer_dtype_policy", [("Best Supported by Hardware", ""), ("FP16", "fp16"), ("BF16", "bf16")], "")
- self._create_config_combo(form, "Transformer Calculation:", "mixed_precision", [("16-bit only", "0"), ("Mixed 16/32-bit (better quality)", "1")], "0")
- self._create_config_combo(form, "Text Encoder:", "text_encoder_quantization", [("16-bit (more RAM, better quality)", "bf16"), ("8-bit (less RAM)", "int8")], "int8")
- self._create_config_combo(form, "VAE Precision:", "vae_precision", [("16-bit (faster, less VRAM)", "16"), ("32-bit (slower, better quality)", "32")], "16")
- self._create_config_combo(form, "Compile Transformer:", "compile", [("On (requires Triton)", "transformer"), ("Off", "")], "")
- self._create_config_combo(form, "DepthAnything v2 Variant:", "depth_anything_v2_variant", [("Large (more precise)", "vitl"), ("Big (faster)", "vitb")], "vitl")
- self._create_config_combo(form, "VAE Tiling:", "vae_config", [("Auto", 0), ("Disabled", 1), ("256x256 (~8GB VRAM)", 2), ("128x128 (~6GB VRAM)", 3)], 0)
- self._create_config_combo(form, "Boost:", "boost", [("On", 1), ("Off", 2)], 1)
- self._create_config_combo(form, "Memory Profile:", "profile", self.wgp.memory_profile_choices, self.wgp.profile_type.LowRAM_LowVRAM)
- self._create_config_slider(form, "Preload in VRAM (MB):", "preload_in_VRAM", 0, 40000, 0, 100)
- release_ram_btn = QPushButton("Force Release Models from RAM")
- release_ram_btn.clicked.connect(self._on_release_ram)
- form.addRow(release_ram_btn)
- return tab
-
- def _create_extensions_config_tab(self):
- tab, form = self._create_scrollable_form_tab()
- self._create_config_combo(form, "Prompt Enhancer:", "enhancer_enabled", [("Off", 0), ("Florence 2 + Llama 3.2", 1), ("Florence 2 + Joy Caption (uncensored)", 2)], 0)
- self._create_config_combo(form, "Enhancer Mode:", "enhancer_mode", [("Automatic on Generate", 0), ("On Demand Only", 1)], 0)
- self._create_config_combo(form, "MMAudio:", "mmaudio_enabled", [("Off", 0), ("Enabled (unloaded after use)", 1), ("Enabled (persistent in RAM)", 2)], 0)
- return tab
-
- def _create_outputs_config_tab(self):
- tab, form = self._create_scrollable_form_tab()
- self._create_config_combo(form, "Video Codec:", "video_output_codec", [("x265 Balanced", 'libx265_28'), ("x264 Balanced", 'libx264_8'), ("x265 High Quality", 'libx265_8'), ("x264 High Quality", 'libx264_10'), ("x264 Lossless", 'libx264_lossless')], 'libx264_8')
- self._create_config_combo(form, "Image Codec:", "image_output_codec", [("JPEG Q85", 'jpeg_85'), ("WEBP Q85", 'webp_85'), ("JPEG Q95", 'jpeg_95'), ("WEBP Q95", 'webp_95'), ("WEBP Lossless", 'webp_lossless'), ("PNG Lossless", 'png')], 'jpeg_95')
- self._create_config_textbox(form, "Video Output Folder:", "save_path", "outputs")
- self._create_config_textbox(form, "Image Output Folder:", "image_save_path", "outputs")
- return tab
-
- def _create_notifications_config_tab(self):
- tab, form = self._create_scrollable_form_tab()
- self._create_config_combo(form, "Notification Sound:", "notification_sound_enabled", [("On", 1), ("Off", 0)], 0)
- self._create_config_slider(form, "Sound Volume:", "notification_sound_volume", 0, 100, 50, 5)
- return tab
-
- def init_wgp_state(self):
- wgp = self.wgp
- initial_model = wgp.server_config.get("last_model_type", wgp.transformer_type)
- dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types
- _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False)
- all_model_ids = [m[1] for m in all_models]
- if initial_model not in all_model_ids: initial_model = wgp.transformer_type
- state_dict = {}
- state_dict["model_filename"] = wgp.get_model_filename(initial_model, wgp.transformer_quantization, wgp.transformer_dtype_policy)
- state_dict["model_type"] = initial_model
- state_dict["advanced"] = wgp.advanced
- state_dict["last_model_per_family"] = wgp.server_config.get("last_model_per_family", {})
- state_dict["last_model_per_type"] = wgp.server_config.get("last_model_per_type", {})
- state_dict["last_resolution_per_group"] = wgp.server_config.get("last_resolution_per_group", {})
- state_dict["gen"] = {"queue": []}
- self.state = state_dict
- self.advanced_group.setChecked(wgp.advanced)
- self.update_model_dropdowns(initial_model)
- self.refresh_ui_from_model_change(initial_model)
- self._update_input_visibility()
-
- def update_model_dropdowns(self, current_model_type):
- wgp = self.wgp
- family_mock, base_type_mock, choice_mock = wgp.generate_dropdown_model_list(current_model_type)
- for combo_name, mock in [('model_family', family_mock), ('model_base_type_choice', base_type_mock), ('model_choice', choice_mock)]:
- combo = self.widgets[combo_name]
- combo.blockSignals(True)
- combo.clear()
- if mock.choices:
- for display_name, internal_key in mock.choices: combo.addItem(display_name, internal_key)
- index = combo.findData(mock.value)
- if index != -1: combo.setCurrentIndex(index)
-
- is_visible = True
- if hasattr(mock, 'kwargs') and isinstance(mock.kwargs, dict):
- is_visible = mock.kwargs.get('visible', True)
- elif hasattr(mock, 'visible'):
- is_visible = mock.visible
- combo.setVisible(is_visible)
-
- combo.blockSignals(False)
-
- def refresh_ui_from_model_change(self, model_type):
- """Update UI controls with default settings when the model is changed."""
- wgp = self.wgp
- self.header_info.setText(wgp.generate_header(model_type, wgp.compile, wgp.attention_mode))
- ui_defaults = wgp.get_default_settings(model_type)
- wgp.set_model_settings(self.state, model_type, ui_defaults)
-
- model_def = wgp.get_model_def(model_type)
- base_model_type = wgp.get_base_model_type(model_type)
- model_filename = self.state.get('model_filename', '')
-
- image_outputs = model_def.get("image_outputs", False)
- vace = wgp.test_vace_module(model_type)
- t2v = base_model_type in ['t2v', 't2v_2_2']
- i2v = wgp.test_class_i2v(model_type)
- fantasy = base_model_type in ["fantasy"]
- multitalk = model_def.get("multitalk_class", False)
- any_audio_guidance = fantasy or multitalk
- sliding_window_enabled = wgp.test_any_sliding_window(model_type)
- recammaster = base_model_type in ["recam_1.3B"]
- ltxv = "ltxv" in model_filename
- diffusion_forcing = "diffusion_forcing" in model_filename
- any_skip_layer_guidance = model_def.get("skip_layer_guidance", False)
- any_cfg_zero = model_def.get("cfg_zero", False)
- any_cfg_star = model_def.get("cfg_star", False)
- any_apg = model_def.get("adaptive_projected_guidance", False)
- v2i_switch_supported = model_def.get("v2i_switch_supported", False)
-
- self._update_generation_mode_visibility(model_def)
-
- for widget in self.widgets.values():
- if hasattr(widget, 'blockSignals'): widget.blockSignals(True)
-
- self.widgets['prompt'].setText(ui_defaults.get("prompt", ""))
- self.widgets['negative_prompt'].setText(ui_defaults.get("negative_prompt", ""))
- self.widgets['seed'].setText(str(ui_defaults.get("seed", -1)))
-
- video_length_val = ui_defaults.get("video_length", 81)
- self.widgets['video_length'].setValue(video_length_val)
- self.widgets['video_length_label'].setText(str(video_length_val))
-
- steps_val = ui_defaults.get("num_inference_steps", 30)
- self.widgets['num_inference_steps'].setValue(steps_val)
- self.widgets['num_inference_steps_label'].setText(str(steps_val))
-
- self.widgets['resolution_group'].blockSignals(True)
- self.widgets['resolution'].blockSignals(True)
-
- current_res_choice = ui_defaults.get("resolution")
- model_resolutions = model_def.get("resolutions", None)
- self.full_resolution_choices, current_res_choice = wgp.get_resolution_choices(current_res_choice, model_resolutions)
- available_groups, selected_group_resolutions, selected_group = wgp.group_resolutions(model_def, self.full_resolution_choices, current_res_choice)
-
- self.widgets['resolution_group'].clear()
- self.widgets['resolution_group'].addItems(available_groups)
- group_index = self.widgets['resolution_group'].findText(selected_group)
- if group_index != -1:
- self.widgets['resolution_group'].setCurrentIndex(group_index)
-
- self.widgets['resolution'].clear()
- for label, value in selected_group_resolutions:
- self.widgets['resolution'].addItem(label, value)
- res_index = self.widgets['resolution'].findData(current_res_choice)
- if res_index != -1:
- self.widgets['resolution'].setCurrentIndex(res_index)
-
- self.widgets['resolution_group'].blockSignals(False)
- self.widgets['resolution'].blockSignals(False)
-
- for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'image_refs', 'audio_source']:
- if name in self.widgets: self.widgets[name].clear()
-
- guidance_layout = self.widgets['guidance_layout']
- guidance_max = model_def.get("guidance_max_phases", 1)
- guidance_layout.setRowVisible(self.widgets['guidance_phases_row_index'], guidance_max > 1)
-
- adv_general_layout = self.widgets['adv_general_layout']
- adv_general_layout.setRowVisible(self.widgets['flow_shift_row_index'], not image_outputs)
- adv_general_layout.setRowVisible(self.widgets['audio_guidance_row_index'], any_audio_guidance)
- adv_general_layout.setRowVisible(self.widgets['repeat_generation_row_index'], not image_outputs)
- adv_general_layout.setRowVisible(self.widgets['multi_images_gen_type_row_index'], i2v)
-
- self.widgets['slg_group'].setVisible(any_skip_layer_guidance)
- quality_form_layout = self.widgets['quality_form_layout']
- quality_form_layout.setRowVisible(self.widgets['apg_switch_row_index'], any_apg)
- quality_form_layout.setRowVisible(self.widgets['cfg_star_switch_row_index'], any_cfg_star)
- quality_form_layout.setRowVisible(self.widgets['cfg_zero_step_row_index'], any_cfg_zero)
- quality_form_layout.setRowVisible(self.widgets['min_frames_if_references_row_index'], v2i_switch_supported and image_outputs)
-
- self.widgets['advanced_tabs'].setTabVisible(self.widgets['sliding_window_tab_index'], sliding_window_enabled and not image_outputs)
-
- misc_layout = self.widgets['misc_layout']
- misc_layout.setRowVisible(self.widgets['riflex_row_index'], not (recammaster or ltxv or diffusion_forcing))
-
- index = self.widgets['multi_images_gen_type'].findData(ui_defaults.get('multi_images_gen_type', 0))
- if index != -1: self.widgets['multi_images_gen_type'].setCurrentIndex(index)
-
- guidance_val = ui_defaults.get("guidance_scale", 5.0)
- self.widgets['guidance_scale'].setValue(int(guidance_val * 10))
- self.widgets['guidance_scale_label'].setText(f"{guidance_val:.1f}")
-
- guidance2_val = ui_defaults.get("guidance2_scale", 5.0)
- self.widgets['guidance2_scale'].setValue(int(guidance2_val * 10))
- self.widgets['guidance2_scale_label'].setText(f"{guidance2_val:.1f}")
-
- guidance3_val = ui_defaults.get("guidance3_scale", 5.0)
- self.widgets['guidance3_scale'].setValue(int(guidance3_val * 10))
- self.widgets['guidance3_scale_label'].setText(f"{guidance3_val:.1f}")
-
- self.widgets['guidance_phases'].clear()
- if guidance_max >= 1: self.widgets['guidance_phases'].addItem("One Phase", 1)
- if guidance_max >= 2: self.widgets['guidance_phases'].addItem("Two Phases", 2)
- if guidance_max >= 3: self.widgets['guidance_phases'].addItem("Three Phases", 3)
- index = self.widgets['guidance_phases'].findData(ui_defaults.get("guidance_phases", 1))
- if index != -1: self.widgets['guidance_phases'].setCurrentIndex(index)
-
- switch_thresh_val = ui_defaults.get("switch_threshold", 0)
- self.widgets['switch_threshold'].setValue(switch_thresh_val)
- self.widgets['switch_threshold_label'].setText(str(switch_thresh_val))
-
- nag_scale_val = ui_defaults.get('NAG_scale', 1.0)
- self.widgets['NAG_scale'].setValue(int(nag_scale_val * 10))
- self.widgets['NAG_scale_label'].setText(f"{nag_scale_val:.1f}")
-
- nag_tau_val = ui_defaults.get('NAG_tau', 3.5)
- self.widgets['NAG_tau'].setValue(int(nag_tau_val * 10))
- self.widgets['NAG_tau_label'].setText(f"{nag_tau_val:.1f}")
-
- nag_alpha_val = ui_defaults.get('NAG_alpha', 0.5)
- self.widgets['NAG_alpha'].setValue(int(nag_alpha_val * 10))
- self.widgets['NAG_alpha_label'].setText(f"{nag_alpha_val:.1f}")
-
- self.widgets['nag_group'].setVisible(vace or t2v or i2v)
-
- self.widgets['sample_solver'].clear()
- sampler_choices = model_def.get("sample_solvers", [])
- self.widgets['solver_row_container'].setVisible(bool(sampler_choices))
- if sampler_choices:
- for label, value in sampler_choices: self.widgets['sample_solver'].addItem(label, value)
- solver_val = ui_defaults.get('sample_solver', sampler_choices[0][1])
- index = self.widgets['sample_solver'].findData(solver_val)
- if index != -1: self.widgets['sample_solver'].setCurrentIndex(index)
-
- flow_val = ui_defaults.get("flow_shift", 3.0)
- self.widgets['flow_shift'].setValue(int(flow_val * 10))
- self.widgets['flow_shift_label'].setText(f"{flow_val:.1f}")
-
- audio_guidance_val = ui_defaults.get("audio_guidance_scale", 4.0)
- self.widgets['audio_guidance_scale'].setValue(int(audio_guidance_val * 10))
- self.widgets['audio_guidance_scale_label'].setText(f"{audio_guidance_val:.1f}")
-
- repeat_val = ui_defaults.get("repeat_generation", 1)
- self.widgets['repeat_generation'].setValue(repeat_val)
- self.widgets['repeat_generation_label'].setText(str(repeat_val))
-
- available_loras, _, _, _, _, _ = wgp.setup_loras(model_type, None, wgp.get_lora_dir(model_type), "")
- self.state['loras'] = available_loras
- self.lora_map = {os.path.basename(p): p for p in available_loras}
- lora_list_widget = self.widgets['activated_loras']
- lora_list_widget.clear()
- lora_list_widget.addItems(sorted(self.lora_map.keys()))
- selected_loras = ui_defaults.get('activated_loras', [])
- for i in range(lora_list_widget.count()):
- item = lora_list_widget.item(i)
- if any(item.text() == os.path.basename(p) for p in selected_loras): item.setSelected(True)
- self.widgets['loras_multipliers'].setText(ui_defaults.get('loras_multipliers', ''))
-
- skip_cache_val = ui_defaults.get('skip_steps_cache_type', "")
- index = self.widgets['skip_steps_cache_type'].findData(skip_cache_val)
- if index != -1: self.widgets['skip_steps_cache_type'].setCurrentIndex(index)
-
- skip_mult = ui_defaults.get('skip_steps_multiplier', 1.5)
- index = self.widgets['skip_steps_multiplier'].findData(skip_mult)
- if index != -1: self.widgets['skip_steps_multiplier'].setCurrentIndex(index)
-
- skip_perc_val = ui_defaults.get('skip_steps_start_step_perc', 0)
- self.widgets['skip_steps_start_step_perc'].setValue(skip_perc_val)
- self.widgets['skip_steps_start_step_perc_label'].setText(str(skip_perc_val))
-
- temp_up_val = ui_defaults.get('temporal_upsampling', "")
- index = self.widgets['temporal_upsampling'].findData(temp_up_val)
- if index != -1: self.widgets['temporal_upsampling'].setCurrentIndex(index)
-
- spat_up_val = ui_defaults.get('spatial_upsampling', "")
- index = self.widgets['spatial_upsampling'].findData(spat_up_val)
- if index != -1: self.widgets['spatial_upsampling'].setCurrentIndex(index)
-
- film_grain_i = ui_defaults.get('film_grain_intensity', 0)
- self.widgets['film_grain_intensity'].setValue(int(film_grain_i * 100))
- self.widgets['film_grain_intensity_label'].setText(f"{film_grain_i:.2f}")
-
- film_grain_s = ui_defaults.get('film_grain_saturation', 0.5)
- self.widgets['film_grain_saturation'].setValue(int(film_grain_s * 100))
- self.widgets['film_grain_saturation_label'].setText(f"{film_grain_s:.2f}")
-
- self.widgets['MMAudio_setting'].setCurrentIndex(ui_defaults.get('MMAudio_setting', 0))
- self.widgets['MMAudio_prompt'].setText(ui_defaults.get('MMAudio_prompt', ''))
- self.widgets['MMAudio_neg_prompt'].setText(ui_defaults.get('MMAudio_neg_prompt', ''))
-
- self.widgets['slg_switch'].setCurrentIndex(ui_defaults.get('slg_switch', 0))
- slg_start_val = ui_defaults.get('slg_start_perc', 10)
- self.widgets['slg_start_perc'].setValue(slg_start_val)
- self.widgets['slg_start_perc_label'].setText(str(slg_start_val))
- slg_end_val = ui_defaults.get('slg_end_perc', 90)
- self.widgets['slg_end_perc'].setValue(slg_end_val)
- self.widgets['slg_end_perc_label'].setText(str(slg_end_val))
-
- self.widgets['apg_switch'].setCurrentIndex(ui_defaults.get('apg_switch', 0))
- self.widgets['cfg_star_switch'].setCurrentIndex(ui_defaults.get('cfg_star_switch', 0))
-
- cfg_zero_val = ui_defaults.get('cfg_zero_step', -1)
- self.widgets['cfg_zero_step'].setValue(cfg_zero_val)
- self.widgets['cfg_zero_step_label'].setText(str(cfg_zero_val))
-
- min_frames_val = ui_defaults.get('min_frames_if_references', 1)
- index = self.widgets['min_frames_if_references'].findData(min_frames_val)
- if index != -1: self.widgets['min_frames_if_references'].setCurrentIndex(index)
-
- self.widgets['RIFLEx_setting'].setCurrentIndex(ui_defaults.get('RIFLEx_setting', 0))
-
- fps = wgp.get_model_fps(model_type)
- force_fps_choices = [
- (f"Model Default ({fps} fps)", ""), ("Auto", "auto"), ("Control Video fps", "control"),
- ("Source Video fps", "source"), ("15", "15"), ("16", "16"), ("23", "23"),
- ("24", "24"), ("25", "25"), ("30", "30")
- ]
- self.widgets['force_fps'].clear()
- for label, value in force_fps_choices: self.widgets['force_fps'].addItem(label, value)
- force_fps_val = ui_defaults.get('force_fps', "")
- index = self.widgets['force_fps'].findData(force_fps_val)
- if index != -1: self.widgets['force_fps'].setCurrentIndex(index)
-
- override_prof_val = ui_defaults.get('override_profile', -1)
- index = self.widgets['override_profile'].findData(override_prof_val)
- if index != -1: self.widgets['override_profile'].setCurrentIndex(index)
-
- self.widgets['multi_prompts_gen_type'].setCurrentIndex(ui_defaults.get('multi_prompts_gen_type', 0))
-
- denoising_val = ui_defaults.get("denoising_strength", 0.5)
- self.widgets['denoising_strength'].setValue(int(denoising_val * 100))
- self.widgets['denoising_strength_label'].setText(f"{denoising_val:.2f}")
-
- sw_size = ui_defaults.get("sliding_window_size", 129)
- self.widgets['sliding_window_size'].setValue(sw_size)
- self.widgets['sliding_window_size_label'].setText(str(sw_size))
-
- sw_overlap = ui_defaults.get("sliding_window_overlap", 5)
- self.widgets['sliding_window_overlap'].setValue(sw_overlap)
- self.widgets['sliding_window_overlap_label'].setText(str(sw_overlap))
-
- sw_color = ui_defaults.get("sliding_window_color_correction_strength", 0)
- self.widgets['sliding_window_color_correction_strength'].setValue(int(sw_color * 100))
- self.widgets['sliding_window_color_correction_strength_label'].setText(f"{sw_color:.2f}")
-
- sw_noise = ui_defaults.get("sliding_window_overlap_noise", 20)
- self.widgets['sliding_window_overlap_noise'].setValue(sw_noise)
- self.widgets['sliding_window_overlap_noise_label'].setText(str(sw_noise))
-
- sw_discard = ui_defaults.get("sliding_window_discard_last_frames", 0)
- self.widgets['sliding_window_discard_last_frames'].setValue(sw_discard)
- self.widgets['sliding_window_discard_last_frames_label'].setText(str(sw_discard))
-
- for widget in self.widgets.values():
- if hasattr(widget, 'blockSignals'): widget.blockSignals(False)
-
- self._update_dynamic_ui()
- self._update_input_visibility()
-
- def _update_dynamic_ui(self):
- phases = self.widgets['guidance_phases'].currentData() or 1
- guidance_layout = self.widgets['guidance_layout']
- guidance_layout.setRowVisible(self.widgets['guidance2_row_index'], phases >= 2)
- guidance_layout.setRowVisible(self.widgets['guidance3_row_index'], phases >= 3)
- guidance_layout.setRowVisible(self.widgets['switch_thresh_row_index'], phases >= 2)
-
- def _update_generation_mode_visibility(self, model_def):
- allowed = model_def.get("image_prompt_types_allowed", "")
- choices = []
- if "T" in allowed or not allowed: choices.append(("Text Prompt Only" if "S" in allowed else "New Video", "T"))
- if "S" in allowed: choices.append(("Start Video with Image", "S"))
- if "V" in allowed: choices.append(("Continue Video", "V"))
- if "L" in allowed: choices.append(("Continue Last Video", "L"))
- button_map = { "T": self.widgets['mode_t'], "S": self.widgets['mode_s'], "V": self.widgets['mode_v'], "L": self.widgets['mode_l'] }
- for btn in button_map.values(): btn.setVisible(False)
- allowed_values = [c[1] for c in choices]
- for label, value in choices:
- if value in button_map:
- btn = button_map[value]
- btn.setText(label)
- btn.setVisible(True)
- current_checked_value = next((value for value, btn in button_map.items() if btn.isChecked()), None)
- if current_checked_value is None or not button_map[current_checked_value].isVisible():
- if allowed_values: button_map[allowed_values[0]].setChecked(True)
- end_image_visible = "E" in allowed
- self.widgets['image_end_checkbox'].setVisible(end_image_visible)
- if not end_image_visible: self.widgets['image_end_checkbox'].setChecked(False)
- control_video_visible = model_def.get("guide_preprocessing") is not None
- self.widgets['control_video_checkbox'].setVisible(control_video_visible)
- if not control_video_visible: self.widgets['control_video_checkbox'].setChecked(False)
- ref_image_visible = model_def.get("image_ref_choices") is not None
- self.widgets['ref_image_checkbox'].setVisible(ref_image_visible)
- if not ref_image_visible: self.widgets['ref_image_checkbox'].setChecked(False)
-
- def _update_input_visibility(self):
- is_s_mode = self.widgets['mode_s'].isChecked()
- is_v_mode = self.widgets['mode_v'].isChecked()
- is_l_mode = self.widgets['mode_l'].isChecked()
- use_end = self.widgets['image_end_checkbox'].isChecked() and self.widgets['image_end_checkbox'].isVisible()
- use_control = self.widgets['control_video_checkbox'].isChecked() and self.widgets['control_video_checkbox'].isVisible()
- use_ref = self.widgets['ref_image_checkbox'].isChecked() and self.widgets['ref_image_checkbox'].isVisible()
- self.widgets['image_start_container'].setVisible(is_s_mode)
- self.widgets['video_source_container'].setVisible(is_v_mode)
- end_checkbox_enabled = is_s_mode or is_v_mode or is_l_mode
- self.widgets['image_end_checkbox'].setEnabled(end_checkbox_enabled)
- self.widgets['image_end_container'].setVisible(use_end and end_checkbox_enabled)
- self.widgets['video_guide_container'].setVisible(use_control)
- self.widgets['video_mask_container'].setVisible(use_control)
- self.widgets['image_refs_container'].setVisible(use_ref)
-
- def connect_signals(self):
- self.widgets['model_family'].currentIndexChanged.connect(self._on_family_changed)
- self.widgets['model_base_type_choice'].currentIndexChanged.connect(self._on_base_type_changed)
- self.widgets['model_choice'].currentIndexChanged.connect(self._on_model_changed)
- self.widgets['resolution_group'].currentIndexChanged.connect(self._on_resolution_group_changed)
- self.widgets['guidance_phases'].currentIndexChanged.connect(self._update_dynamic_ui)
- self.widgets['mode_t'].toggled.connect(self._update_input_visibility)
- self.widgets['mode_s'].toggled.connect(self._update_input_visibility)
- self.widgets['mode_v'].toggled.connect(self._update_input_visibility)
- self.widgets['mode_l'].toggled.connect(self._update_input_visibility)
- self.widgets['image_end_checkbox'].toggled.connect(self._update_input_visibility)
- self.widgets['control_video_checkbox'].toggled.connect(self._update_input_visibility)
- self.widgets['ref_image_checkbox'].toggled.connect(self._update_input_visibility)
- self.widgets['preview_group'].toggled.connect(self._on_preview_toggled)
- self.generate_btn.clicked.connect(self._on_generate)
- self.add_to_queue_btn.clicked.connect(self._on_add_to_queue)
- self.remove_queue_btn.clicked.connect(self._on_remove_selected_from_queue)
- self.clear_queue_btn.clicked.connect(self._on_clear_queue)
- self.abort_btn.clicked.connect(self._on_abort)
- self.queue_table.rowsMoved.connect(self._on_queue_rows_moved)
-
- def load_main_config(self):
- try:
- with open('main_config.json', 'r') as f: self.main_config = json.load(f)
- except (FileNotFoundError, json.JSONDecodeError): self.main_config = {'preview_visible': False}
-
- def save_main_config(self):
- try:
- with open('main_config.json', 'w') as f: json.dump(self.main_config, f, indent=4)
- except Exception as e: print(f"Error saving main_config.json: {e}")
-
- def apply_initial_config(self):
- is_visible = self.main_config.get('preview_visible', True)
- self.widgets['preview_group'].setChecked(is_visible)
- self.widgets['preview_image'].setVisible(is_visible)
-
- def _on_preview_toggled(self, checked):
- self.widgets['preview_image'].setVisible(checked)
- self.main_config['preview_visible'] = checked
- self.save_main_config()
-
- def _on_family_changed(self):
- family = self.widgets['model_family'].currentData()
- if not family or not self.state: return
- base_type_mock, choice_mock = self.wgp.change_model_family(self.state, family)
-
- if hasattr(base_type_mock, 'kwargs') and isinstance(base_type_mock.kwargs, dict):
- is_visible_base = base_type_mock.kwargs.get('visible', True)
- elif hasattr(base_type_mock, 'visible'):
- is_visible_base = base_type_mock.visible
- else:
- is_visible_base = True
-
- self.widgets['model_base_type_choice'].blockSignals(True)
- self.widgets['model_base_type_choice'].clear()
- if base_type_mock.choices:
- for label, value in base_type_mock.choices: self.widgets['model_base_type_choice'].addItem(label, value)
- self.widgets['model_base_type_choice'].setCurrentIndex(self.widgets['model_base_type_choice'].findData(base_type_mock.value))
- self.widgets['model_base_type_choice'].setVisible(is_visible_base)
- self.widgets['model_base_type_choice'].blockSignals(False)
-
- if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict):
- is_visible_choice = choice_mock.kwargs.get('visible', True)
- elif hasattr(choice_mock, 'visible'):
- is_visible_choice = choice_mock.visible
- else:
- is_visible_choice = True
-
- self.widgets['model_choice'].blockSignals(True)
- self.widgets['model_choice'].clear()
- if choice_mock.choices:
- for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value)
- self.widgets['model_choice'].setCurrentIndex(self.widgets['model_choice'].findData(choice_mock.value))
- self.widgets['model_choice'].setVisible(is_visible_choice)
- self.widgets['model_choice'].blockSignals(False)
-
- self._on_model_changed()
-
- def _on_base_type_changed(self):
- family = self.widgets['model_family'].currentData()
- base_type = self.widgets['model_base_type_choice'].currentData()
- if not family or not base_type or not self.state: return
- base_type_mock, choice_mock = self.wgp.change_model_base_types(self.state, family, base_type)
-
- if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict):
- is_visible_choice = choice_mock.kwargs.get('visible', True)
- elif hasattr(choice_mock, 'visible'):
- is_visible_choice = choice_mock.visible
- else:
- is_visible_choice = True
-
- self.widgets['model_choice'].blockSignals(True)
- self.widgets['model_choice'].clear()
- if choice_mock.choices:
- for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value)
- self.widgets['model_choice'].setCurrentIndex(self.widgets['model_choice'].findData(choice_mock.value))
- self.widgets['model_choice'].setVisible(is_visible_choice)
- self.widgets['model_choice'].blockSignals(False)
- self._on_model_changed()
-
- def _on_model_changed(self):
- model_type = self.widgets['model_choice'].currentData()
- if not model_type or model_type == self.state.get('model_type'): return
- self.wgp.change_model(self.state, model_type)
- self.refresh_ui_from_model_change(model_type)
-
- def _on_resolution_group_changed(self):
- selected_group = self.widgets['resolution_group'].currentText()
- if not selected_group or not hasattr(self, 'full_resolution_choices'): return
- model_type = self.state['model_type']
- model_def = self.wgp.get_model_def(model_type)
- model_resolutions = model_def.get("resolutions", None)
- group_resolution_choices = []
- if model_resolutions is None:
- group_resolution_choices = [res for res in self.full_resolution_choices if self.wgp.categorize_resolution(res[1]) == selected_group]
- else: return
- last_resolution = self.state.get("last_resolution_per_group", {}).get(selected_group, "")
- if not any(last_resolution == res[1] for res in group_resolution_choices) and group_resolution_choices:
- last_resolution = group_resolution_choices[0][1]
- self.widgets['resolution'].blockSignals(True)
- self.widgets['resolution'].clear()
- for label, value in group_resolution_choices: self.widgets['resolution'].addItem(label, value)
- self.widgets['resolution'].setCurrentIndex(self.widgets['resolution'].findData(last_resolution))
- self.widgets['resolution'].blockSignals(False)
-
- def collect_inputs(self):
- full_inputs = self.wgp.get_current_model_settings(self.state).copy()
- full_inputs['lset_name'] = ""
- full_inputs['image_mode'] = 0
- expected_keys = { "audio_guide": None, "audio_guide2": None, "image_guide": None, "image_mask": None, "speakers_locations": "", "frames_positions": "", "keep_frames_video_guide": "", "keep_frames_video_source": "", "video_guide_outpainting": "", "switch_threshold2": 0, "model_switch_phase": 1, "batch_size": 1, "control_net_weight_alt": 1.0, "image_refs_relative_size": 50, }
- for key, default_value in expected_keys.items():
- if key not in full_inputs: full_inputs[key] = default_value
- full_inputs['prompt'] = self.widgets['prompt'].toPlainText()
- full_inputs['negative_prompt'] = self.widgets['negative_prompt'].toPlainText()
- full_inputs['resolution'] = self.widgets['resolution'].currentData()
- full_inputs['video_length'] = self.widgets['video_length'].value()
- full_inputs['num_inference_steps'] = self.widgets['num_inference_steps'].value()
- full_inputs['seed'] = int(self.widgets['seed'].text())
- image_prompt_type = ""
- video_prompt_type = ""
- if self.widgets['mode_s'].isChecked(): image_prompt_type = 'S'
- elif self.widgets['mode_v'].isChecked(): image_prompt_type = 'V'
- elif self.widgets['mode_l'].isChecked(): image_prompt_type = 'L'
- if self.widgets['image_end_checkbox'].isVisible() and self.widgets['image_end_checkbox'].isChecked(): image_prompt_type += 'E'
- if self.widgets['control_video_checkbox'].isVisible() and self.widgets['control_video_checkbox'].isChecked(): video_prompt_type += 'V'
- if self.widgets['ref_image_checkbox'].isVisible() and self.widgets['ref_image_checkbox'].isChecked(): video_prompt_type += 'I'
- full_inputs['image_prompt_type'] = image_prompt_type
- full_inputs['video_prompt_type'] = video_prompt_type
- for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'audio_source']:
- if name in self.widgets: full_inputs[name] = self.widgets[name].text() or None
- paths = self.widgets['image_refs'].text().split(';')
- full_inputs['image_refs'] = [p.strip() for p in paths if p.strip()] if paths and paths[0] else None
- full_inputs['denoising_strength'] = self.widgets['denoising_strength'].value() / 100.0
- if self.advanced_group.isChecked():
- full_inputs['guidance_scale'] = self.widgets['guidance_scale'].value() / 10.0
- full_inputs['guidance_phases'] = self.widgets['guidance_phases'].currentData()
- full_inputs['guidance2_scale'] = self.widgets['guidance2_scale'].value() / 10.0
- full_inputs['guidance3_scale'] = self.widgets['guidance3_scale'].value() / 10.0
- full_inputs['switch_threshold'] = self.widgets['switch_threshold'].value()
- full_inputs['NAG_scale'] = self.widgets['NAG_scale'].value() / 10.0
- full_inputs['NAG_tau'] = self.widgets['NAG_tau'].value() / 10.0
- full_inputs['NAG_alpha'] = self.widgets['NAG_alpha'].value() / 10.0
- full_inputs['sample_solver'] = self.widgets['sample_solver'].currentData()
- full_inputs['flow_shift'] = self.widgets['flow_shift'].value() / 10.0
- full_inputs['audio_guidance_scale'] = self.widgets['audio_guidance_scale'].value() / 10.0
- full_inputs['repeat_generation'] = self.widgets['repeat_generation'].value()
- full_inputs['multi_images_gen_type'] = self.widgets['multi_images_gen_type'].currentData()
- selected_items = self.widgets['activated_loras'].selectedItems()
- full_inputs['activated_loras'] = [self.lora_map[item.text()] for item in selected_items if item.text() in self.lora_map]
- full_inputs['loras_multipliers'] = self.widgets['loras_multipliers'].toPlainText()
- full_inputs['skip_steps_cache_type'] = self.widgets['skip_steps_cache_type'].currentData()
- full_inputs['skip_steps_multiplier'] = self.widgets['skip_steps_multiplier'].currentData()
- full_inputs['skip_steps_start_step_perc'] = self.widgets['skip_steps_start_step_perc'].value()
- full_inputs['temporal_upsampling'] = self.widgets['temporal_upsampling'].currentData()
- full_inputs['spatial_upsampling'] = self.widgets['spatial_upsampling'].currentData()
- full_inputs['film_grain_intensity'] = self.widgets['film_grain_intensity'].value() / 100.0
- full_inputs['film_grain_saturation'] = self.widgets['film_grain_saturation'].value() / 100.0
- full_inputs['MMAudio_setting'] = self.widgets['MMAudio_setting'].currentData()
- full_inputs['MMAudio_prompt'] = self.widgets['MMAudio_prompt'].text()
- full_inputs['MMAudio_neg_prompt'] = self.widgets['MMAudio_neg_prompt'].text()
- full_inputs['RIFLEx_setting'] = self.widgets['RIFLEx_setting'].currentData()
- full_inputs['force_fps'] = self.widgets['force_fps'].currentData()
- full_inputs['override_profile'] = self.widgets['override_profile'].currentData()
- full_inputs['multi_prompts_gen_type'] = self.widgets['multi_prompts_gen_type'].currentData()
- full_inputs['slg_switch'] = self.widgets['slg_switch'].currentData()
- full_inputs['slg_start_perc'] = self.widgets['slg_start_perc'].value()
- full_inputs['slg_end_perc'] = self.widgets['slg_end_perc'].value()
- full_inputs['apg_switch'] = self.widgets['apg_switch'].currentData()
- full_inputs['cfg_star_switch'] = self.widgets['cfg_star_switch'].currentData()
- full_inputs['cfg_zero_step'] = self.widgets['cfg_zero_step'].value()
- full_inputs['min_frames_if_references'] = self.widgets['min_frames_if_references'].currentData()
- full_inputs['sliding_window_size'] = self.widgets['sliding_window_size'].value()
- full_inputs['sliding_window_overlap'] = self.widgets['sliding_window_overlap'].value()
- full_inputs['sliding_window_color_correction_strength'] = self.widgets['sliding_window_color_correction_strength'].value() / 100.0
- full_inputs['sliding_window_overlap_noise'] = self.widgets['sliding_window_overlap_noise'].value()
- full_inputs['sliding_window_discard_last_frames'] = self.widgets['sliding_window_discard_last_frames'].value()
- return full_inputs
-
- def _prepare_state_for_generation(self):
- if 'gen' in self.state:
- self.state['gen'].pop('abort', None)
- self.state['gen'].pop('in_progress', None)
-
- def _on_generate(self):
- try:
- is_running = self.thread and self.thread.isRunning()
- self._add_task_to_queue()
- if not is_running: self.start_generation()
- except Exception as e:
- import traceback; traceback.print_exc()
-
- def _on_add_to_queue(self):
- try:
- self._add_task_to_queue()
- except Exception as e:
- import traceback; traceback.print_exc()
-
- def _add_task_to_queue(self):
- all_inputs = self.collect_inputs()
- for key in ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality', 'base_model_type']: all_inputs.pop(key, None)
- all_inputs['state'] = self.state
- self.wgp.set_model_settings(self.state, self.state['model_type'], all_inputs)
- self.state["validate_success"] = 1
- self.wgp.process_prompt_and_add_tasks(self.state, self.state['model_type'])
- self.update_queue_table()
-
- def start_generation(self):
- if not self.state['gen']['queue']: return
- self._prepare_state_for_generation()
- self.generate_btn.setEnabled(False)
- self.add_to_queue_btn.setEnabled(True)
- self.thread = QThread()
- self.worker = Worker(self.plugin, self.state)
- self.worker.moveToThread(self.thread)
- self.thread.started.connect(self.worker.run)
- self.worker.finished.connect(self.thread.quit)
- self.worker.finished.connect(self.worker.deleteLater)
- self.thread.finished.connect(self.thread.deleteLater)
- self.thread.finished.connect(self.on_generation_finished)
- self.worker.status.connect(self.status_label.setText)
- self.worker.progress.connect(self.update_progress)
- self.worker.preview.connect(self.update_preview)
- self.worker.output.connect(self.update_queue_and_results)
- self.worker.error.connect(self.on_generation_error)
- self.thread.start()
- self.update_queue_table()
-
- def on_generation_finished(self):
- time.sleep(0.1)
- self.status_label.setText("Finished.")
- self.progress_bar.setValue(0)
- self.generate_btn.setEnabled(True)
- self.add_to_queue_btn.setEnabled(False)
- self.thread = None; self.worker = None
- self.update_queue_table()
-
- def on_generation_error(self, err_msg):
- QMessageBox.critical(self, "Generation Error", str(err_msg))
- self.on_generation_finished()
-
- def update_progress(self, data):
- if len(data) > 1 and isinstance(data[0], tuple):
- step, total = data[0]
- self.progress_bar.setMaximum(total)
- self.progress_bar.setValue(step)
- self.status_label.setText(str(data[1]))
- if step <= 1: self.update_queue_table()
- elif len(data) > 1: self.status_label.setText(str(data[1]))
-
- def update_preview(self, pil_image):
- if pil_image and self.widgets['preview_group'].isChecked():
- q_image = ImageQt(pil_image)
- pixmap = QPixmap.fromImage(q_image)
- self.preview_image.setPixmap(pixmap.scaled(self.preview_image.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation))
-
- def update_queue_and_results(self):
- self.update_queue_table()
- file_list = self.state.get('gen', {}).get('file_list', [])
- for file_path in file_list:
- if file_path not in self.processed_files:
- self.add_result_item(file_path)
- self.processed_files.add(file_path)
-
- def add_result_item(self, video_path):
- item_widget = VideoResultItemWidget(video_path, self.plugin)
- list_item = QListWidgetItem(self.results_list)
- list_item.setSizeHint(item_widget.sizeHint())
- self.results_list.addItem(list_item)
- self.results_list.setItemWidget(list_item, item_widget)
-
- def update_queue_table(self):
- with self.wgp.lock:
- queue = self.state.get('gen', {}).get('queue', [])
- is_running = self.thread and self.thread.isRunning()
- queue_to_display = queue if is_running else [None] + queue
- table_data = self.wgp.get_queue_table(queue_to_display)
- self.queue_table.setRowCount(0)
- self.queue_table.setRowCount(len(table_data))
- self.queue_table.setColumnCount(4)
- self.queue_table.setHorizontalHeaderLabels(["Qty", "Prompt", "Length", "Steps"])
- for row_idx, row_data in enumerate(table_data):
- prompt_text = str(row_data[1]).split('>')[1].split('<')[0] if '>' in str(row_data[1]) else str(row_data[1])
- for col_idx, cell_data in enumerate([row_data[0], prompt_text, row_data[2], row_data[3]]):
- self.queue_table.setItem(row_idx, col_idx, QTableWidgetItem(str(cell_data)))
- self.queue_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
- self.queue_table.resizeColumnsToContents()
-
- def _on_remove_selected_from_queue(self):
- selected_row = self.queue_table.currentRow()
- if selected_row < 0: return
- with self.wgp.lock:
- is_running = self.thread and self.thread.isRunning()
- offset = 1 if is_running else 0
- queue = self.state.get('gen', {}).get('queue', [])
- if len(queue) > selected_row + offset: queue.pop(selected_row + offset)
- self.update_queue_table()
-
- def _on_queue_rows_moved(self, source_row, dest_row):
- with self.wgp.lock:
- queue = self.state.get('gen', {}).get('queue', [])
- is_running = self.thread and self.thread.isRunning()
- offset = 1 if is_running else 0
- real_source_idx = source_row + offset
- real_dest_idx = dest_row + offset
- moved_item = queue.pop(real_source_idx)
- queue.insert(real_dest_idx, moved_item)
- self.update_queue_table()
-
- def _on_clear_queue(self):
- self.wgp.clear_queue_action(self.state)
- self.update_queue_table()
-
- def _on_abort(self):
- if self.worker:
- self.wgp.abort_generation(self.state)
- self.status_label.setText("Aborting...")
- self.worker._is_running = False
-
- def _on_release_ram(self):
- self.wgp.release_RAM()
- QMessageBox.information(self, "RAM Released", "Models stored in RAM have been released.")
-
- def _on_apply_config_changes(self):
- changes = {}
- list_widget = self.widgets['config_transformer_types']
- changes['transformer_types_choices'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked]
- list_widget = self.widgets['config_preload_model_policy']
- changes['preload_model_policy_choice'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked]
- changes['model_hierarchy_type_choice'] = self.widgets['config_model_hierarchy_type'].currentData()
- changes['checkpoints_paths'] = self.widgets['config_checkpoints_paths'].toPlainText()
- for key in ["fit_canvas", "attention_mode", "metadata_type", "clear_file_list", "display_stats", "max_frames_multiplier", "UI_theme"]:
- changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData()
- for key in ["transformer_quantization", "transformer_dtype_policy", "mixed_precision", "text_encoder_quantization", "vae_precision", "compile", "depth_anything_v2_variant", "vae_config", "boost", "profile"]:
- changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData()
- changes['preload_in_VRAM_choice'] = self.widgets['config_preload_in_VRAM'].value()
- for key in ["enhancer_enabled", "enhancer_mode", "mmaudio_enabled"]:
- changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData()
- for key in ["video_output_codec", "image_output_codec", "save_path", "image_save_path"]:
- widget = self.widgets[f'config_{key}']
- changes[f'{key}_choice'] = widget.currentData() if isinstance(widget, QComboBox) else widget.text()
- changes['notification_sound_enabled_choice'] = self.widgets['config_notification_sound_enabled'].currentData()
- changes['notification_sound_volume_choice'] = self.widgets['config_notification_sound_volume'].value()
- changes['last_resolution_choice'] = self.widgets['resolution'].currentData()
- try:
- msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = self.wgp.apply_changes(self.state, **changes)
- self.config_status_label.setText("Changes applied successfully. Some settings may require a restart.")
- self.header_info.setText(self.wgp.generate_header(self.state['model_type'], self.wgp.compile, self.wgp.attention_mode))
- if family_mock.choices is not None or choice_mock.choices is not None:
- self.update_model_dropdowns(self.wgp.transformer_type)
- self.refresh_ui_from_model_change(self.wgp.transformer_type)
- except Exception as e:
- self.config_status_label.setText(f"Error applying changes: {e}")
- import traceback; traceback.print_exc()
-
-class Plugin(VideoEditorPlugin):
- def initialize(self):
- self.name = "WanGP AI Generator"
- self.description = "Uses the integrated WanGP library to generate video clips."
- self.client_widget = WgpDesktopPluginWidget(self)
- self.dock_widget = None
- self.active_region = None; self.temp_dir = None
- self.insert_on_new_track = False; self.start_frame_path = None; self.end_frame_path = None
-
- def enable(self):
- if not self.dock_widget: self.dock_widget = self.app.add_dock_widget(self, self.client_widget, self.name)
- self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu)
- self.app.status_label.setText(f"{self.name}: Enabled.")
-
- def disable(self):
- try: self.app.timeline_widget.context_menu_requested.disconnect(self.on_timeline_context_menu)
- except TypeError: pass
- self._cleanup_temp_dir()
- if self.client_widget.worker: self.client_widget._on_abort()
- self.app.status_label.setText(f"{self.name}: Disabled.")
-
- def _cleanup_temp_dir(self):
- if self.temp_dir and os.path.exists(self.temp_dir):
- shutil.rmtree(self.temp_dir)
- self.temp_dir = None
-
- def _reset_state(self):
- self.active_region = None; self.insert_on_new_track = False
- self.start_frame_path = None; self.end_frame_path = None
- self.client_widget.processed_files.clear()
- self.client_widget.results_list.clear()
- self.client_widget.widgets['image_start'].clear()
- self.client_widget.widgets['image_end'].clear()
- self.client_widget.widgets['video_source'].clear()
- self._cleanup_temp_dir()
- self.app.status_label.setText(f"{self.name}: Ready.")
-
- def on_timeline_context_menu(self, menu, event):
- region = self.app.timeline_widget.get_region_at_pos(event.pos())
- if region:
- menu.addSeparator()
- start_sec, end_sec = region
- start_data, _, _ = self.app.get_frame_data_at_time(start_sec)
- end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
- if start_data and end_data:
- join_action = menu.addAction("Join Frames With WanGP")
- join_action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False))
- join_action_new_track = menu.addAction("Join Frames With WanGP (New Track)")
- join_action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True))
- create_action = menu.addAction("Create Video With WanGP")
- create_action.triggered.connect(lambda: self.setup_creator_for_region(region))
-
- def setup_generator_for_region(self, region, on_new_track=False):
- self._reset_state()
- self.active_region = region
- self.insert_on_new_track = on_new_track
-
- model_to_set = 'i2v_2_2'
- dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types
- _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False)
- if any(model_to_set == m[1] for m in all_models):
- if self.client_widget.state.get('model_type') != model_to_set:
- self.client_widget.update_model_dropdowns(model_to_set)
- self.client_widget._on_model_changed()
- else:
- print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.")
-
- start_sec, end_sec = region
- start_data, w, h = self.app.get_frame_data_at_time(start_sec)
- end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
- if not start_data or not end_data:
- QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames.")
- return
- try:
- self.temp_dir = tempfile.mkdtemp(prefix="wgp_plugin_")
- self.start_frame_path = os.path.join(self.temp_dir, "start_frame.png")
- self.end_frame_path = os.path.join(self.temp_dir, "end_frame.png")
- QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path)
- QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path)
-
- duration_sec = end_sec - start_sec
- wgp = self.client_widget.wgp
- model_type = self.client_widget.state['model_type']
- fps = wgp.get_model_fps(model_type)
- video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
-
- self.client_widget.widgets['video_length'].setValue(video_length_frames)
- self.client_widget.widgets['mode_s'].setChecked(True)
- self.client_widget.widgets['image_end_checkbox'].setChecked(True)
- self.client_widget.widgets['image_start'].setText(self.start_frame_path)
- self.client_widget.widgets['image_end'].setText(self.end_frame_path)
-
- self.client_widget._update_input_visibility()
-
- except Exception as e:
- QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame images: {e}")
- self._cleanup_temp_dir()
- return
- self.app.status_label.setText(f"WanGP: Ready to join frames from {start_sec:.2f}s to {end_sec:.2f}s.")
- self.dock_widget.show()
- self.dock_widget.raise_()
-
- def setup_creator_for_region(self, region):
- self._reset_state()
- self.active_region = region
- self.insert_on_new_track = True
-
- model_to_set = 't2v_2_2'
- dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types
- _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False)
- if any(model_to_set == m[1] for m in all_models):
- if self.client_widget.state.get('model_type') != model_to_set:
- self.client_widget.update_model_dropdowns(model_to_set)
- self.client_widget._on_model_changed()
- else:
- print(f"Warning: Default model '{model_to_set}' not found for AI Creator. Using current model.")
-
- start_sec, end_sec = region
- duration_sec = end_sec - start_sec
- wgp = self.client_widget.wgp
- model_type = self.client_widget.state['model_type']
- fps = wgp.get_model_fps(model_type)
- video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
-
- self.client_widget.widgets['video_length'].setValue(video_length_frames)
- self.client_widget.widgets['mode_t'].setChecked(True)
-
- self.app.status_label.setText(f"WanGP: Ready to create video from {start_sec:.2f}s to {end_sec:.2f}s.")
- self.dock_widget.show()
- self.dock_widget.raise_()
-
- def insert_generated_clip(self, video_path):
- from videoeditor import TimelineClip
- if not self.active_region:
- self.app.status_label.setText("WanGP Error: No active region to insert into."); return
- if not os.path.exists(video_path):
- self.app.status_label.setText(f"WanGP Error: Output file not found: {video_path}"); return
- start_sec, end_sec = self.active_region
- def complex_insertion_action():
- self.app._add_media_files_to_project([video_path])
- media_info = self.app.media_properties.get(video_path)
- if not media_info: raise ValueError("Could not probe inserted clip.")
- actual_duration, has_audio = media_info['duration'], media_info['has_audio']
- if self.insert_on_new_track:
- self.app.timeline.num_video_tracks += 1
- video_track_index = self.app.timeline.num_video_tracks
- audio_track_index = self.app.timeline.num_audio_tracks + 1 if has_audio else None
- else:
- for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec)
- for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec)
- clips_to_remove = [c for c in self.app.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec]
- for clip in clips_to_remove:
- if clip in self.app.timeline.clips: self.app.timeline.clips.remove(clip)
- video_track_index, audio_track_index = 1, 1 if has_audio else None
- group_id = str(uuid.uuid4())
- new_clip = TimelineClip(video_path, start_sec, 0, actual_duration, video_track_index, 'video', 'video', group_id)
- self.app.timeline.add_clip(new_clip)
- if audio_track_index:
- if audio_track_index > self.app.timeline.num_audio_tracks: self.app.timeline.num_audio_tracks = audio_track_index
- audio_clip = TimelineClip(video_path, start_sec, 0, actual_duration, audio_track_index, 'audio', 'video', group_id)
- self.app.timeline.add_clip(audio_clip)
- try:
- self.app._perform_complex_timeline_change("Insert AI Clip", complex_insertion_action)
- self.app.prune_empty_tracks()
- self.app.status_label.setText("AI clip inserted successfully.")
- for i in range(self.client_widget.results_list.count()):
- widget = self.client_widget.results_list.itemWidget(self.client_widget.results_list.item(i))
- if widget and widget.video_path == video_path:
- self.client_widget.results_list.takeItem(i); break
- except Exception as e:
- import traceback; traceback.print_exc()
- self.app.status_label.setText(f"WanGP Error during clip insertion: {e}")
\ No newline at end of file
+import sys
+import os
+import threading
+import time
+import json
+import tempfile
+import shutil
+import uuid
+from unittest.mock import MagicMock
+from pathlib import Path
+
+# --- Start of Gradio Hijacking ---
+# This block creates a mock Gradio module. When wgp.py is imported,
+# all calls to `gr.*` will be intercepted by these mock objects,
+# preventing any UI from being built and allowing us to use the
+# backend logic directly. (A short usage sketch follows the hijack block below.)
+
+class MockGradioComponent(MagicMock):
+    """A mock Gradio component that records its constructor arguments so the
+    backend logic can later read back `value`, `choices`, and the raw kwargs."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(name=f"gr.{kwargs.get('elem_id', 'component')}")
+        self.kwargs = kwargs
+        self.value = kwargs.get('value')
+        self.choices = kwargs.get('choices')
+
+        # Gradio's event-wiring methods become chainable no-ops that return the mock itself.
+        for method in ['then', 'change', 'click', 'input', 'select', 'upload', 'mount', 'launch', 'on', 'release']:
+            setattr(self, method, lambda *a, **kw: self)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+class MockGradioError(Exception):
+ pass
+
+class MockGradioModule:
+    def __getattr__(self, name):
+        # gr.Error must remain a raisable exception type.
+        if name == 'Error':
+            return lambda *args, **kwargs: MockGradioError(*args)
+
+        # gr.Info / gr.Warning are downgraded to console messages.
+        if name in ['Info', 'Warning']:
+            return lambda *args, **kwargs: print(f"Intercepted gr.{name}:", *args)
+
+        # Any other attribute (gr.Dropdown, gr.Slider, gr.Row, ...) yields a capturing mock component.
+        return lambda *args, **kwargs: MockGradioComponent(*args, **kwargs)
+
+sys.modules['gradio'] = MockGradioModule()
+sys.modules['gradio.gallery'] = MockGradioModule()
+sys.modules['shared.gradio.gallery'] = MockGradioModule()
+# --- End of Gradio Hijacking ---
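+
+# A minimal sketch of what the hijack provides (the gr.Dropdown call below is only
+# an illustration; wgp.py's real component calls may differ):
+#
+#   import gradio as gr                            # resolves to MockGradioModule
+#   combo = gr.Dropdown(choices=[("A", "a")], value="a", visible=False)
+#   combo.choices            # [("A", "a")] -- read back when populating Qt combos
+#   combo.value              # "a"
+#   combo.kwargs['visible']  # False -- drives setVisible() on the Qt side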
+
+import ffmpeg
+
+from PyQt6.QtWidgets import (
+ QWidget, QVBoxLayout, QHBoxLayout, QTabWidget,
+ QPushButton, QLabel, QLineEdit, QTextEdit, QSlider, QCheckBox, QComboBox,
+ QFileDialog, QGroupBox, QFormLayout, QTableWidget, QTableWidgetItem,
+ QHeaderView, QProgressBar, QScrollArea, QListWidget, QListWidgetItem,
+ QMessageBox, QRadioButton, QSizePolicy
+)
+from PyQt6.QtCore import Qt, QThread, QObject, pyqtSignal, QUrl, QSize, QRectF
+from PyQt6.QtGui import QPixmap, QImage, QDropEvent
+from PyQt6.QtMultimedia import QMediaPlayer
+from PyQt6.QtMultimediaWidgets import QVideoWidget
+from PIL.ImageQt import ImageQt
+
+# Import the base plugin class from the main application's path
+sys.path.append(str(Path(__file__).parent.parent.parent))
+from plugins import VideoEditorPlugin
+
+class VideoResultItemWidget(QWidget):
+ """A widget to display a generated video with a hover-to-play preview and insert button."""
+ def __init__(self, video_path, plugin, parent=None):
+ super().__init__(parent)
+ self.video_path = video_path
+ self.plugin = plugin
+ self.app = plugin.app
+ self.duration = 0.0
+ self.has_audio = False
+
+ self.setMinimumSize(200, 180)
+ self.setMaximumHeight(190)
+
+ layout = QVBoxLayout(self)
+ layout.setContentsMargins(5, 5, 5, 5)
+ layout.setSpacing(5)
+
+ self.media_player = QMediaPlayer()
+ self.video_widget = QVideoWidget()
+ self.video_widget.setFixedSize(160, 90)
+ self.media_player.setVideoOutput(self.video_widget)
+ self.media_player.setSource(QUrl.fromLocalFile(self.video_path))
+ self.media_player.setLoops(QMediaPlayer.Loops.Infinite)
+
+ self.info_label = QLabel(os.path.basename(video_path))
+ self.info_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
+ self.info_label.setWordWrap(True)
+
+ self.insert_button = QPushButton("Insert into Timeline")
+ self.insert_button.clicked.connect(self.on_insert)
+
+ h_layout = QHBoxLayout()
+ h_layout.addStretch()
+ h_layout.addWidget(self.video_widget)
+ h_layout.addStretch()
+
+ layout.addLayout(h_layout)
+ layout.addWidget(self.info_label)
+ layout.addWidget(self.insert_button)
+ self.probe_video()
+
+ def probe_video(self):
+ try:
+ probe = ffmpeg.probe(self.video_path)
+ self.duration = float(probe['format']['duration'])
+ self.has_audio = any(s['codec_type'] == 'audio' for s in probe.get('streams', []))
+ self.info_label.setText(f"{os.path.basename(self.video_path)}\n({self.duration:.2f}s)")
+ except Exception as e:
+ self.info_label.setText(f"Error probing:\n{os.path.basename(self.video_path)}")
+ print(f"Error probing video {self.video_path}: {e}")
+
+    def enterEvent(self, event):
+        super().enterEvent(event)
+        # Hover: start looped playback of the result clip.
+        self.media_player.play()
+        if not self.plugin.active_region or self.duration == 0: return
+        # Also outline, on the timeline, where this clip would be inserted.
+        start_sec, _ = self.plugin.active_region
+        timeline = self.app.timeline_widget
+        video_rect, audio_rect = None, None
+        x = timeline.sec_to_x(start_sec)
+        w = int(self.duration * timeline.pixels_per_second)
+        if self.plugin.insert_on_new_track:
+ video_y = timeline.TIMESCALE_HEIGHT
+ video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT)
+ if self.has_audio:
+ audio_y = timeline.audio_tracks_y_start + self.app.timeline.num_audio_tracks * timeline.TRACK_HEIGHT
+ audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT)
+ else:
+ v_track_idx = 1
+ visual_v_idx = self.app.timeline.num_video_tracks - v_track_idx
+ video_y = timeline.video_tracks_y_start + visual_v_idx * timeline.TRACK_HEIGHT
+ video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT)
+ if self.has_audio:
+ a_track_idx = 1
+ visual_a_idx = a_track_idx - 1
+ audio_y = timeline.audio_tracks_y_start + visual_a_idx * timeline.TRACK_HEIGHT
+ audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT)
+ timeline.set_hover_preview_rects(video_rect, audio_rect)
+
+ def leaveEvent(self, event):
+ super().leaveEvent(event)
+ self.media_player.pause()
+ self.media_player.setPosition(0)
+ self.app.timeline_widget.set_hover_preview_rects(None, None)
+
+ def on_insert(self):
+ self.media_player.stop()
+ self.media_player.setSource(QUrl())
+ self.media_player.setVideoOutput(None)
+ self.app.timeline_widget.set_hover_preview_rects(None, None)
+ self.plugin.insert_generated_clip(self.video_path)
+
+
+class QueueTableWidget(QTableWidget):
+ rowsMoved = pyqtSignal(int, int)
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.setDragEnabled(True)
+ self.setAcceptDrops(True)
+ self.setDropIndicatorShown(True)
+ self.setDragDropMode(self.DragDropMode.InternalMove)
+ self.setSelectionBehavior(self.SelectionBehavior.SelectRows)
+ self.setSelectionMode(self.SelectionMode.SingleSelection)
+
+    def dropEvent(self, event: QDropEvent):
+        if event.source() == self and event.dropAction() == Qt.DropAction.MoveAction:
+            source_row = self.currentRow()
+            target_item = self.itemAt(event.position().toPoint())
+            dest_row = target_item.row() if target_item else self.rowCount()
+            # Moving a row downwards removes it from above the target, so shift the destination up by one.
+            if source_row < dest_row: dest_row -= 1
+            if source_row != dest_row: self.rowsMoved.emit(source_row, dest_row)
+            event.acceptProposedAction()
+        else:
+            super().dropEvent(event)
+
+class HoverVideoPreview(QWidget):
+ def __init__(self, player, video_widget, parent=None):
+ super().__init__(parent)
+ self.player = player
+ layout = QVBoxLayout(self)
+ layout.setContentsMargins(0,0,0,0)
+ layout.addWidget(video_widget)
+ self.setFixedSize(160, 90)
+
+ def enterEvent(self, event):
+ super().enterEvent(event)
+ if self.player.source().isValid():
+ self.player.play()
+
+ def leaveEvent(self, event):
+ super().leaveEvent(event)
+ self.player.pause()
+ self.player.setPosition(0)
+
+class Worker(QObject):
+ progress = pyqtSignal(list)
+ status = pyqtSignal(str)
+ preview = pyqtSignal(object)
+ output = pyqtSignal()
+ finished = pyqtSignal()
+ error = pyqtSignal(str)
+ def __init__(self, plugin, state):
+ super().__init__()
+ self.plugin = plugin
+ self.wgp = plugin.wgp
+ self.state = state
+ self._is_running = True
+ self._last_progress_phase = None
+ self._last_preview = None
+
+    def run(self):
+        # wgp.process_tasks() blocks while generating, so it runs in its own daemon
+        # thread; this loop polls the shared state dict and forwards progress and
+        # preview updates to the GUI thread via Qt signals.
+        def generation_target():
+            try:
+                for _ in self.wgp.process_tasks(self.state):
+                    if self._is_running: self.output.emit()
+                    else: break
+            except Exception as e:
+                import traceback
+                print("Error in generation thread:")
+                traceback.print_exc()
+                # gr.Error is hijacked to MockGradioError, so test the mock type
+                # instead of looking for "gradio.Error" in the class name.
+                if isinstance(e, MockGradioError): self.error.emit(str(e))
+                else: self.error.emit(f"An unexpected error occurred: {e}")
+            finally:
+                self._is_running = False
+        gen_thread = threading.Thread(target=generation_target, daemon=True)
+        gen_thread.start()
+        while self._is_running:
+            gen = self.state.get('gen', {})
+            current_phase = gen.get("progress_phase")
+            if current_phase and current_phase != self._last_progress_phase:
+                self._last_progress_phase = current_phase
+                phase_name, step = current_phase
+                total_steps = gen.get("num_inference_steps", 1)
+                high_level_status = gen.get("progress_status", "")
+                status_msg = self.wgp.merge_status_context(high_level_status, phase_name)
+                progress_args = [(step, total_steps), status_msg]
+                self.progress.emit(progress_args)
+            preview_img = gen.get('preview')
+            if preview_img is not None and preview_img is not self._last_preview:
+                self._last_preview = preview_img
+                self.preview.emit(preview_img)
+                gen['preview'] = None
+            time.sleep(0.1)
+        gen_thread.join()
+        self.finished.emit()
+
+
+class WgpDesktopPluginWidget(QWidget):
+ def __init__(self, plugin):
+ super().__init__()
+ self.plugin = plugin
+ self.wgp = plugin.wgp
+ self.widgets = {}
+ self.state = {}
+ self.worker = None
+ self.thread = None
+ self.lora_map = {}
+ self.full_resolution_choices = []
+ self.main_config = {}
+ self.processed_files = set()
+
+ self.load_main_config()
+ self.setup_ui()
+ self.apply_initial_config()
+ self.connect_signals()
+ self.init_wgp_state()
+
+ def setup_ui(self):
+ main_layout = QVBoxLayout(self)
+ self.header_info = QLabel("Header Info")
+ main_layout.addWidget(self.header_info)
+ self.tabs = QTabWidget()
+ main_layout.addWidget(self.tabs)
+ self.setup_generator_tab()
+ self.setup_config_tab()
+
+ def create_widget(self, widget_class, name, *args, **kwargs):
+ widget = widget_class(*args, **kwargs)
+ self.widgets[name] = widget
+ return widget
+
+ def _create_slider_with_label(self, name, min_val, max_val, initial_val, scale=1.0, precision=1):
+ container = QWidget()
+ hbox = QHBoxLayout(container)
+ hbox.setContentsMargins(0, 0, 0, 0)
+ slider = self.create_widget(QSlider, name, Qt.Orientation.Horizontal)
+ slider.setRange(min_val, max_val)
+ slider.setValue(int(initial_val * scale))
+ value_label = self.create_widget(QLabel, f"{name}_label", f"{initial_val:.{precision}f}")
+ value_label.setMinimumWidth(50)
+ slider.valueChanged.connect(lambda v, lbl=value_label, s=scale, p=precision: lbl.setText(f"{v/s:.{p}f}"))
+ hbox.addWidget(slider)
+ hbox.addWidget(value_label)
+ return container
+
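+    # Builds a labelled path field with Browse/clear buttons plus a contextual preview:
+    # names containing "refs" open a multi-file dialog joined with ';', "image" inputs
+    # show a thumbnail, "video" inputs get a hover-to-play HoverVideoPreview, and
+    # anything else just displays the file name. The preview stays hidden until the
+    # path points at an existing file.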
+ def _create_file_input(self, name, label_text):
+ container = self.create_widget(QWidget, f"{name}_container")
+ vbox = QVBoxLayout(container)
+ vbox.setContentsMargins(0, 0, 0, 0)
+ vbox.setSpacing(5)
+
+ input_widget = QWidget()
+ hbox = QHBoxLayout(input_widget)
+ hbox.setContentsMargins(0, 0, 0, 0)
+
+ line_edit = self.create_widget(QLineEdit, name)
+ line_edit.setPlaceholderText("No file selected or path pasted")
+ button = QPushButton("Browse...")
+
+ def open_dialog():
+ if "refs" in name:
+ filenames, _ = QFileDialog.getOpenFileNames(self, f"Select {label_text}", filter="Image Files (*.png *.jpg *.jpeg *.bmp *.webp);;All Files (*)")
+ if filenames: line_edit.setText(";".join(filenames))
+ else:
+ filter_str = "All Files (*)"
+ if 'video' in name:
+ filter_str = "Video Files (*.mp4 *.mkv *.mov *.avi);;All Files (*)"
+ elif 'image' in name:
+ filter_str = "Image Files (*.png *.jpg *.jpeg *.bmp *.webp);;All Files (*)"
+ elif 'audio' in name:
+ filter_str = "Audio Files (*.wav *.mp3 *.flac);;All Files (*)"
+
+ filename, _ = QFileDialog.getOpenFileName(self, f"Select {label_text}", filter=filter_str)
+ if filename: line_edit.setText(filename)
+
+ button.clicked.connect(open_dialog)
+ clear_button = QPushButton("X")
+ clear_button.setFixedWidth(30)
+ clear_button.clicked.connect(lambda: line_edit.clear())
+
+ hbox.addWidget(QLabel(f"{label_text}:"))
+ hbox.addWidget(line_edit, 1)
+ hbox.addWidget(button)
+ hbox.addWidget(clear_button)
+ vbox.addWidget(input_widget)
+
+ # Preview Widget
+ preview_container = self.create_widget(QWidget, f"{name}_preview_container")
+ preview_hbox = QHBoxLayout(preview_container)
+ preview_hbox.setContentsMargins(0, 0, 0, 0)
+ preview_hbox.addStretch()
+
+ is_image_input = 'image' in name and 'audio' not in name
+ is_video_input = 'video' in name and 'audio' not in name
+
+ if is_image_input:
+ preview_widget = self.create_widget(QLabel, f"{name}_preview")
+ preview_widget.setAlignment(Qt.AlignmentFlag.AlignCenter)
+ preview_widget.setFixedSize(160, 90)
+ preview_widget.setStyleSheet("border: 1px solid #cccccc; background-color: #f0f0f0;")
+ preview_widget.setText("Image Preview")
+ preview_hbox.addWidget(preview_widget)
+ elif is_video_input:
+ media_player = QMediaPlayer()
+ video_widget = QVideoWidget()
+ video_widget.setFixedSize(160, 90)
+ media_player.setVideoOutput(video_widget)
+ media_player.setLoops(QMediaPlayer.Loops.Infinite)
+
+ self.widgets[f"{name}_player"] = media_player
+
+ preview_widget = HoverVideoPreview(media_player, video_widget)
+ preview_hbox.addWidget(preview_widget)
+ else:
+ preview_widget = self.create_widget(QLabel, f"{name}_preview")
+ preview_widget.setText("No preview available")
+ preview_hbox.addWidget(preview_widget)
+
+ preview_hbox.addStretch()
+ vbox.addWidget(preview_container)
+ preview_container.setVisible(False)
+
+ def update_preview(path):
+ container = self.widgets.get(f"{name}_preview_container")
+ if not container: return
+
+ first_path = path.split(';')[0] if path else ''
+
+ if not first_path or not os.path.exists(first_path):
+ container.setVisible(False)
+ if is_video_input:
+ player = self.widgets.get(f"{name}_player")
+ if player: player.setSource(QUrl())
+ return
+
+ container.setVisible(True)
+
+ if is_image_input:
+ preview_label = self.widgets.get(f"{name}_preview")
+ if preview_label:
+ pixmap = QPixmap(first_path)
+ if not pixmap.isNull():
+ preview_label.setPixmap(pixmap.scaled(preview_label.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation))
+ else:
+ preview_label.setText("Invalid Image")
+
+ elif is_video_input:
+ player = self.widgets.get(f"{name}_player")
+ if player:
+ player.setSource(QUrl.fromLocalFile(first_path))
+
+ else: # Audio or other
+ preview_label = self.widgets.get(f"{name}_preview")
+ if preview_label:
+ preview_label.setText(os.path.basename(path))
+
+ line_edit.textChanged.connect(update_preview)
+
+ return container
+
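+    # Generator tab layout: the left half is a scrollable column of options (model
+    # selection, prompts, basic settings, generation mode, inputs, advanced tabs);
+    # the right half holds the Generate/Add-to-Queue buttons, status, progress bar,
+    # preview, generated clips list and the queue table with its controls.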
+ def setup_generator_tab(self):
+ gen_tab = QWidget()
+ self.tabs.addTab(gen_tab, "Video Generator")
+ gen_layout = QHBoxLayout(gen_tab)
+ left_panel = QWidget()
+ left_layout = QVBoxLayout(left_panel)
+ gen_layout.addWidget(left_panel, 1)
+ right_panel = QWidget()
+ right_layout = QVBoxLayout(right_panel)
+ gen_layout.addWidget(right_panel, 1)
+ scroll_area = QScrollArea()
+ scroll_area.setWidgetResizable(True)
+ left_layout.addWidget(scroll_area)
+ options_widget = QWidget()
+ scroll_area.setWidget(options_widget)
+ options_layout = QVBoxLayout(options_widget)
+ model_layout = QHBoxLayout()
+ self.widgets['model_family'] = QComboBox()
+ self.widgets['model_base_type_choice'] = QComboBox()
+ self.widgets['model_choice'] = QComboBox()
+ model_layout.addWidget(QLabel("Model:"))
+ model_layout.addWidget(self.widgets['model_family'], 2)
+ model_layout.addWidget(self.widgets['model_base_type_choice'], 3)
+ model_layout.addWidget(self.widgets['model_choice'], 3)
+ options_layout.addLayout(model_layout)
+ options_layout.addWidget(QLabel("Prompt:"))
+ self.create_widget(QTextEdit, 'prompt').setMinimumHeight(100)
+ options_layout.addWidget(self.widgets['prompt'])
+ options_layout.addWidget(QLabel("Negative Prompt:"))
+ self.create_widget(QTextEdit, 'negative_prompt').setMinimumHeight(60)
+ options_layout.addWidget(self.widgets['negative_prompt'])
+ basic_group = QGroupBox("Basic Options")
+ basic_layout = QFormLayout(basic_group)
+ res_container = QWidget()
+ res_hbox = QHBoxLayout(res_container)
+ res_hbox.setContentsMargins(0, 0, 0, 0)
+ res_hbox.addWidget(self.create_widget(QComboBox, 'resolution_group'), 2)
+ res_hbox.addWidget(self.create_widget(QComboBox, 'resolution'), 3)
+ basic_layout.addRow("Resolution:", res_container)
+ basic_layout.addRow("Video Length:", self._create_slider_with_label('video_length', 1, 737, 81, 1.0, 0))
+ basic_layout.addRow("Inference Steps:", self._create_slider_with_label('num_inference_steps', 1, 100, 30, 1.0, 0))
+ basic_layout.addRow("Seed:", self.create_widget(QLineEdit, 'seed', '-1'))
+ options_layout.addWidget(basic_group)
+ mode_options_group = QGroupBox("Generation Mode & Input Options")
+ mode_options_layout = QVBoxLayout(mode_options_group)
+ mode_hbox = QHBoxLayout()
+ mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_t', "Text Prompt Only"))
+ mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_s', "Start with Image"))
+ mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_v', "Continue Video"))
+ mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_l', "Continue Last Video"))
+ self.widgets['mode_t'].setChecked(True)
+ mode_options_layout.addLayout(mode_hbox)
+ options_hbox = QHBoxLayout()
+ options_hbox.addWidget(self.create_widget(QCheckBox, 'image_end_checkbox', "Use End Image"))
+ options_hbox.addWidget(self.create_widget(QCheckBox, 'control_video_checkbox', "Use Control Video"))
+ options_hbox.addWidget(self.create_widget(QCheckBox, 'ref_image_checkbox', "Use Reference Image(s)"))
+ mode_options_layout.addLayout(options_hbox)
+ options_layout.addWidget(mode_options_group)
+ inputs_group = QGroupBox("Inputs")
+ inputs_layout = QVBoxLayout(inputs_group)
+ inputs_layout.addWidget(self._create_file_input('image_start', "Start Image"))
+ inputs_layout.addWidget(self._create_file_input('image_end', "End Image"))
+ inputs_layout.addWidget(self._create_file_input('video_source', "Source Video"))
+ inputs_layout.addWidget(self._create_file_input('video_guide', "Control Video"))
+ inputs_layout.addWidget(self._create_file_input('video_mask', "Video Mask"))
+ inputs_layout.addWidget(self._create_file_input('image_refs', "Reference Image(s)"))
+ denoising_row = QFormLayout()
+ denoising_row.addRow("Denoising Strength:", self._create_slider_with_label('denoising_strength', 0, 100, 50, 100.0, 2))
+ inputs_layout.addLayout(denoising_row)
+ options_layout.addWidget(inputs_group)
+ self.advanced_group = self.create_widget(QGroupBox, 'advanced_group', "Advanced Options")
+ self.advanced_group.setCheckable(True)
+ self.advanced_group.setChecked(False)
+ advanced_layout = QVBoxLayout(self.advanced_group)
+ advanced_tabs = self.create_widget(QTabWidget, 'advanced_tabs')
+ advanced_layout.addWidget(advanced_tabs)
+ self._setup_adv_tab_general(advanced_tabs)
+ self._setup_adv_tab_loras(advanced_tabs)
+ self._setup_adv_tab_speed(advanced_tabs)
+ self._setup_adv_tab_postproc(advanced_tabs)
+ self._setup_adv_tab_audio(advanced_tabs)
+ self._setup_adv_tab_quality(advanced_tabs)
+ self._setup_adv_tab_sliding_window(advanced_tabs)
+ self._setup_adv_tab_misc(advanced_tabs)
+ options_layout.addWidget(self.advanced_group)
+
+ btn_layout = QHBoxLayout()
+ self.generate_btn = self.create_widget(QPushButton, 'generate_btn', "Generate")
+ self.add_to_queue_btn = self.create_widget(QPushButton, 'add_to_queue_btn', "Add to Queue")
+ self.generate_btn.setEnabled(True)
+ self.add_to_queue_btn.setEnabled(False)
+ btn_layout.addWidget(self.generate_btn)
+ btn_layout.addWidget(self.add_to_queue_btn)
+ right_layout.addLayout(btn_layout)
+ self.status_label = self.create_widget(QLabel, 'status_label', "Idle")
+ right_layout.addWidget(self.status_label)
+ self.progress_bar = self.create_widget(QProgressBar, 'progress_bar')
+ right_layout.addWidget(self.progress_bar)
+ preview_group = self.create_widget(QGroupBox, 'preview_group', "Preview")
+ preview_group.setCheckable(True)
+ preview_group.setStyleSheet("QGroupBox { border: 1px solid #cccccc; }")
+ preview_group_layout = QVBoxLayout(preview_group)
+ self.preview_image = self.create_widget(QLabel, 'preview_image', "")
+ self.preview_image.setAlignment(Qt.AlignmentFlag.AlignCenter)
+ self.preview_image.setMinimumSize(200, 200)
+ preview_group_layout.addWidget(self.preview_image)
+ right_layout.addWidget(preview_group)
+
+ results_group = QGroupBox("Generated Clips")
+ results_group.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding)
+ results_layout = QVBoxLayout(results_group)
+ self.results_list = self.create_widget(QListWidget, 'results_list')
+ self.results_list.setFlow(QListWidget.Flow.LeftToRight)
+ self.results_list.setWrapping(True)
+ self.results_list.setResizeMode(QListWidget.ResizeMode.Adjust)
+ self.results_list.setSpacing(10)
+ results_layout.addWidget(self.results_list)
+ right_layout.addWidget(results_group)
+
+ right_layout.addWidget(QLabel("Queue:"))
+ self.queue_table = self.create_widget(QueueTableWidget, 'queue_table')
+ right_layout.addWidget(self.queue_table)
+ queue_btn_layout = QHBoxLayout()
+ self.remove_queue_btn = self.create_widget(QPushButton, 'remove_queue_btn', "Remove Selected")
+ self.clear_queue_btn = self.create_widget(QPushButton, 'clear_queue_btn', "Clear Queue")
+ self.abort_btn = self.create_widget(QPushButton, 'abort_btn', "Abort")
+ queue_btn_layout.addWidget(self.remove_queue_btn)
+ queue_btn_layout.addWidget(self.clear_queue_btn)
+ queue_btn_layout.addWidget(self.abort_btn)
+ right_layout.addLayout(queue_btn_layout)
+
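+    # The advanced tabs record QFormLayout row indices (the *_row_index entries) just
+    # before adding each row so refresh_ui_from_model_change() can later toggle
+    # individual rows with setRowVisible() depending on the selected model's features.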
+ def _setup_adv_tab_general(self, tabs):
+ tab = QWidget()
+ tabs.addTab(tab, "General")
+ layout = QFormLayout(tab)
+ self.widgets['adv_general_layout'] = layout
+ guidance_group = QGroupBox("Guidance")
+ guidance_layout = self.create_widget(QFormLayout, 'guidance_layout', guidance_group)
+ guidance_layout.addRow("Guidance (CFG):", self._create_slider_with_label('guidance_scale', 10, 200, 5.0, 10.0, 1))
+ self.widgets['guidance_phases_row_index'] = guidance_layout.rowCount()
+ guidance_layout.addRow("Guidance Phases:", self.create_widget(QComboBox, 'guidance_phases'))
+ self.widgets['guidance2_row_index'] = guidance_layout.rowCount()
+ guidance_layout.addRow("Guidance 2:", self._create_slider_with_label('guidance2_scale', 10, 200, 5.0, 10.0, 1))
+ self.widgets['guidance3_row_index'] = guidance_layout.rowCount()
+ guidance_layout.addRow("Guidance 3:", self._create_slider_with_label('guidance3_scale', 10, 200, 5.0, 10.0, 1))
+ self.widgets['switch_thresh_row_index'] = guidance_layout.rowCount()
+ guidance_layout.addRow("Switch Threshold:", self._create_slider_with_label('switch_threshold', 0, 1000, 0, 1.0, 0))
+ layout.addRow(guidance_group)
+ nag_group = self.create_widget(QGroupBox, 'nag_group', "NAG (Negative Adversarial Guidance)")
+ nag_layout = QFormLayout(nag_group)
+ nag_layout.addRow("NAG Scale:", self._create_slider_with_label('NAG_scale', 10, 200, 1.0, 10.0, 1))
+ nag_layout.addRow("NAG Tau:", self._create_slider_with_label('NAG_tau', 10, 50, 3.5, 10.0, 1))
+ nag_layout.addRow("NAG Alpha:", self._create_slider_with_label('NAG_alpha', 0, 20, 0.5, 10.0, 1))
+ layout.addRow(nag_group)
+ self.widgets['solver_row_container'] = QWidget()
+ solver_hbox = QHBoxLayout(self.widgets['solver_row_container'])
+ solver_hbox.setContentsMargins(0,0,0,0)
+ solver_hbox.addWidget(QLabel("Sampler Solver:"))
+ solver_hbox.addWidget(self.create_widget(QComboBox, 'sample_solver'))
+ layout.addRow(self.widgets['solver_row_container'])
+ self.widgets['flow_shift_row_index'] = layout.rowCount()
+ layout.addRow("Shift Scale:", self._create_slider_with_label('flow_shift', 10, 250, 3.0, 10.0, 1))
+ self.widgets['audio_guidance_row_index'] = layout.rowCount()
+ layout.addRow("Audio Guidance:", self._create_slider_with_label('audio_guidance_scale', 10, 200, 4.0, 10.0, 1))
+ self.widgets['repeat_generation_row_index'] = layout.rowCount()
+ layout.addRow("Repeat Generations:", self._create_slider_with_label('repeat_generation', 1, 25, 1, 1.0, 0))
+ combo = self.create_widget(QComboBox, 'multi_images_gen_type')
+ combo.addItem("Generate all combinations", 0)
+ combo.addItem("Match images and texts", 1)
+ self.widgets['multi_images_gen_type_row_index'] = layout.rowCount()
+ layout.addRow("Multi-Image Mode:", combo)
+
+ def _setup_adv_tab_loras(self, tabs):
+ tab = QWidget()
+ tabs.addTab(tab, "Loras")
+ layout = QVBoxLayout(tab)
+ layout.addWidget(QLabel("Available Loras (Ctrl+Click to select multiple):"))
+ lora_list = self.create_widget(QListWidget, 'activated_loras')
+ lora_list.setSelectionMode(QListWidget.SelectionMode.MultiSelection)
+ layout.addWidget(lora_list)
+ layout.addWidget(QLabel("Loras Multipliers:"))
+ layout.addWidget(self.create_widget(QTextEdit, 'loras_multipliers'))
+
+ def _setup_adv_tab_speed(self, tabs):
+ tab = QWidget()
+ tabs.addTab(tab, "Speed")
+ layout = QFormLayout(tab)
+ combo = self.create_widget(QComboBox, 'skip_steps_cache_type')
+ combo.addItem("None", "")
+ combo.addItem("Tea Cache", "tea")
+ combo.addItem("Mag Cache", "mag")
+ layout.addRow("Cache Type:", combo)
+ combo = self.create_widget(QComboBox, 'skip_steps_multiplier')
+ combo.addItem("x1.5 speed up", 1.5)
+ combo.addItem("x1.75 speed up", 1.75)
+ combo.addItem("x2.0 speed up", 2.0)
+ combo.addItem("x2.25 speed up", 2.25)
+ combo.addItem("x2.5 speed up", 2.5)
+ layout.addRow("Acceleration:", combo)
+ layout.addRow("Start %:", self._create_slider_with_label('skip_steps_start_step_perc', 0, 100, 0, 1.0, 0))
+
+ def _setup_adv_tab_postproc(self, tabs):
+ tab = QWidget()
+ tabs.addTab(tab, "Post-Processing")
+ layout = QFormLayout(tab)
+ combo = self.create_widget(QComboBox, 'temporal_upsampling')
+ combo.addItem("Disabled", "")
+ combo.addItem("Rife x2 frames/s", "rife2")
+ combo.addItem("Rife x4 frames/s", "rife4")
+ layout.addRow("Temporal Upsampling:", combo)
+ combo = self.create_widget(QComboBox, 'spatial_upsampling')
+ combo.addItem("Disabled", "")
+ combo.addItem("Lanczos x1.5", "lanczos1.5")
+ combo.addItem("Lanczos x2.0", "lanczos2")
+ layout.addRow("Spatial Upsampling:", combo)
+ layout.addRow("Film Grain Intensity:", self._create_slider_with_label('film_grain_intensity', 0, 100, 0, 100.0, 2))
+ layout.addRow("Film Grain Saturation:", self._create_slider_with_label('film_grain_saturation', 0, 100, 0.5, 100.0, 2))
+
+ def _setup_adv_tab_audio(self, tabs):
+ tab = QWidget()
+ tabs.addTab(tab, "Audio")
+ layout = QFormLayout(tab)
+ combo = self.create_widget(QComboBox, 'MMAudio_setting')
+ combo.addItem("Disabled", 0)
+ combo.addItem("Enabled", 1)
+ layout.addRow("MMAudio:", combo)
+        layout.addRow(self.create_widget(QLineEdit, 'MMAudio_prompt', placeholderText="MMAudio Prompt"))
+        layout.addRow(self.create_widget(QLineEdit, 'MMAudio_neg_prompt', placeholderText="MMAudio Negative Prompt"))
+ layout.addRow(self._create_file_input('audio_source', "Custom Soundtrack"))
+
+ def _setup_adv_tab_quality(self, tabs):
+ tab = QWidget()
+ tabs.addTab(tab, "Quality")
+ layout = QVBoxLayout(tab)
+ slg_group = self.create_widget(QGroupBox, 'slg_group', "Skip Layer Guidance")
+ slg_layout = QFormLayout(slg_group)
+ slg_combo = self.create_widget(QComboBox, 'slg_switch')
+ slg_combo.addItem("OFF", 0)
+ slg_combo.addItem("ON", 1)
+ slg_layout.addRow("Enable SLG:", slg_combo)
+ slg_layout.addRow("Start %:", self._create_slider_with_label('slg_start_perc', 0, 100, 10, 1.0, 0))
+ slg_layout.addRow("End %:", self._create_slider_with_label('slg_end_perc', 0, 100, 90, 1.0, 0))
+ layout.addWidget(slg_group)
+ quality_form = QFormLayout()
+ self.widgets['quality_form_layout'] = quality_form
+ apg_combo = self.create_widget(QComboBox, 'apg_switch')
+ apg_combo.addItem("OFF", 0)
+ apg_combo.addItem("ON", 1)
+ self.widgets['apg_switch_row_index'] = quality_form.rowCount()
+ quality_form.addRow("Adaptive Projected Guidance:", apg_combo)
+ cfg_star_combo = self.create_widget(QComboBox, 'cfg_star_switch')
+ cfg_star_combo.addItem("OFF", 0)
+ cfg_star_combo.addItem("ON", 1)
+ self.widgets['cfg_star_switch_row_index'] = quality_form.rowCount()
+ quality_form.addRow("Classifier-Free Guidance Star:", cfg_star_combo)
+ self.widgets['cfg_zero_step_row_index'] = quality_form.rowCount()
+ quality_form.addRow("CFG Zero below Layer:", self._create_slider_with_label('cfg_zero_step', -1, 39, -1, 1.0, 0))
+ combo = self.create_widget(QComboBox, 'min_frames_if_references')
+ combo.addItem("Disabled (1 frame)", 1)
+ combo.addItem("Generate 5 frames", 5)
+ combo.addItem("Generate 9 frames", 9)
+ combo.addItem("Generate 13 frames", 13)
+ combo.addItem("Generate 17 frames", 17)
+ self.widgets['min_frames_if_references_row_index'] = quality_form.rowCount()
+ quality_form.addRow("Min Frames for Quality:", combo)
+ layout.addLayout(quality_form)
+
+ def _setup_adv_tab_sliding_window(self, tabs):
+ tab = QWidget()
+ self.widgets['sliding_window_tab_index'] = tabs.count()
+ tabs.addTab(tab, "Sliding Window")
+ layout = QFormLayout(tab)
+ layout.addRow("Window Size:", self._create_slider_with_label('sliding_window_size', 5, 257, 129, 1.0, 0))
+ layout.addRow("Overlap:", self._create_slider_with_label('sliding_window_overlap', 1, 97, 5, 1.0, 0))
+ layout.addRow("Color Correction:", self._create_slider_with_label('sliding_window_color_correction_strength', 0, 100, 0, 100.0, 2))
+ layout.addRow("Overlap Noise:", self._create_slider_with_label('sliding_window_overlap_noise', 0, 150, 20, 1.0, 0))
+ layout.addRow("Discard Last Frames:", self._create_slider_with_label('sliding_window_discard_last_frames', 0, 20, 0, 1.0, 0))
+
+ def _setup_adv_tab_misc(self, tabs):
+ tab = QWidget()
+ tabs.addTab(tab, "Misc")
+ layout = QFormLayout(tab)
+ self.widgets['misc_layout'] = layout
+ riflex_combo = self.create_widget(QComboBox, 'RIFLEx_setting')
+ riflex_combo.addItem("Auto", 0)
+ riflex_combo.addItem("Always ON", 1)
+ riflex_combo.addItem("Always OFF", 2)
+ self.widgets['riflex_row_index'] = layout.rowCount()
+ layout.addRow("RIFLEx Setting:", riflex_combo)
+ fps_combo = self.create_widget(QComboBox, 'force_fps')
+ layout.addRow("Force FPS:", fps_combo)
+ profile_combo = self.create_widget(QComboBox, 'override_profile')
+ profile_combo.addItem("Default Profile", -1)
+ for text, val in self.wgp.memory_profile_choices: profile_combo.addItem(text.split(':')[0], val)
+ layout.addRow("Override Memory Profile:", profile_combo)
+ combo = self.create_widget(QComboBox, 'multi_prompts_gen_type')
+ combo.addItem("Generate new Video per line", 0)
+ combo.addItem("Use line for new Sliding Window", 1)
+ layout.addRow("Multi-Prompt Mode:", combo)
+
+ def setup_config_tab(self):
+ config_tab = QWidget()
+ self.tabs.addTab(config_tab, "Configuration")
+ main_layout = QVBoxLayout(config_tab)
+ self.config_status_label = QLabel("Apply changes for them to take effect. Some may require a restart.")
+ main_layout.addWidget(self.config_status_label)
+ config_tabs = QTabWidget()
+ main_layout.addWidget(config_tabs)
+ config_tabs.addTab(self._create_general_config_tab(), "General")
+ config_tabs.addTab(self._create_performance_config_tab(), "Performance")
+ config_tabs.addTab(self._create_extensions_config_tab(), "Extensions")
+ config_tabs.addTab(self._create_outputs_config_tab(), "Outputs")
+ config_tabs.addTab(self._create_notifications_config_tab(), "Notifications")
+ self.apply_config_btn = QPushButton("Apply Changes")
+ self.apply_config_btn.clicked.connect(self._on_apply_config_changes)
+ main_layout.addWidget(self.apply_config_btn)
+
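+    # Configuration tab helpers: each _create_config_* call registers its widget as
+    # self.widgets['config_<key>'] and seeds its initial value from wgp.server_config[key],
+    # falling back to the given default, e.g.
+    # _create_config_combo(form, "Boost:", "boost", [("On", 1), ("Off", 2)], 1).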
+ def _create_scrollable_form_tab(self):
+ tab_widget = QWidget()
+ scroll_area = QScrollArea()
+ scroll_area.setWidgetResizable(True)
+ layout = QVBoxLayout(tab_widget)
+ layout.addWidget(scroll_area)
+ content_widget = QWidget()
+ form_layout = QFormLayout(content_widget)
+ scroll_area.setWidget(content_widget)
+ return tab_widget, form_layout
+
+ def _create_config_combo(self, form_layout, label, key, choices, default_value):
+ combo = QComboBox()
+ for text, data in choices: combo.addItem(text, data)
+ index = combo.findData(self.wgp.server_config.get(key, default_value))
+ if index != -1: combo.setCurrentIndex(index)
+ self.widgets[f'config_{key}'] = combo
+ form_layout.addRow(label, combo)
+
+ def _create_config_slider(self, form_layout, label, key, min_val, max_val, default_value, step=1):
+ container = QWidget()
+ hbox = QHBoxLayout(container)
+ hbox.setContentsMargins(0,0,0,0)
+ slider = QSlider(Qt.Orientation.Horizontal)
+ slider.setRange(min_val, max_val)
+ slider.setSingleStep(step)
+ slider.setValue(self.wgp.server_config.get(key, default_value))
+ value_label = QLabel(str(slider.value()))
+ value_label.setMinimumWidth(40)
+ slider.valueChanged.connect(lambda v, lbl=value_label: lbl.setText(str(v)))
+ hbox.addWidget(slider)
+ hbox.addWidget(value_label)
+ self.widgets[f'config_{key}'] = slider
+ form_layout.addRow(label, container)
+
+ def _create_config_checklist(self, form_layout, label, key, choices, default_value):
+ list_widget = QListWidget()
+ list_widget.setMinimumHeight(100)
+ current_values = self.wgp.server_config.get(key, default_value)
+ for text, data in choices:
+ item = QListWidgetItem(text)
+ item.setData(Qt.ItemDataRole.UserRole, data)
+ item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable)
+ item.setCheckState(Qt.CheckState.Checked if data in current_values else Qt.CheckState.Unchecked)
+ list_widget.addItem(item)
+ self.widgets[f'config_{key}'] = list_widget
+ form_layout.addRow(label, list_widget)
+
+    def _create_config_textbox(self, form_layout, label, key, default_value, multi_line=False):
+        # Seed from the current server config so the field shows the active setting,
+        # matching the behaviour of the combo/slider/checklist helpers.
+        value = str(self.wgp.server_config.get(key, default_value))
+        if multi_line:
+            textbox = QTextEdit()
+            textbox.setPlainText(value)
+            textbox.setAcceptRichText(False)
+        else:
+            textbox = QLineEdit(value)
+        self.widgets[f'config_{key}'] = textbox
+        form_layout.addRow(label, textbox)
+
+ def _create_general_config_tab(self):
+ tab, form = self._create_scrollable_form_tab()
+ _, _, dropdown_choices = self.wgp.get_sorted_dropdown(self.wgp.displayed_model_types, None, None, False)
+ self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, self.wgp.transformer_types)
+ self._create_config_combo(form, "Model Hierarchy:", "model_hierarchy_type", [("Two Levels (Family > Model)", 0), ("Three Levels (Family > Base > Finetune)", 1)], 1)
+ self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1), ("Dimensions are Output Width/Height (Cropped)", 2)], 0)
+ self._create_config_combo(form, "Attention Type:", "attention_mode", [("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"), ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto")
+ self._create_config_combo(form, "Metadata Handling:", "metadata_type", [("Embed in file (Exif/Comment)", "metadata"), ("Export separate JSON", "json"), ("None", "none")], "metadata")
+ self._create_config_checklist(form, "RAM Loading Policy:", "preload_model_policy", [("Preload on App Launch", "P"), ("Preload on Model Switch", "S"), ("Unload when Queue is Done", "U")], [])
+ self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10), ("Keep last 20", 20), ("Keep last 30", 30)], 5)
+ self._create_config_combo(form, "Display RAM/VRAM Stats:", "display_stats", [("Disabled", 0), ("Enabled", 1)], 0)
+ self._create_config_combo(form, "Max Frames Multiplier:", "max_frames_multiplier", [(f"x{i}", i) for i in range(1, 8)], 1)
+ checkpoints_paths_text = "\n".join(self.wgp.server_config.get("checkpoints_paths", self.wgp.fl.default_checkpoints_paths))
+ checkpoints_textbox = QTextEdit()
+ checkpoints_textbox.setPlainText(checkpoints_paths_text)
+ checkpoints_textbox.setAcceptRichText(False)
+ checkpoints_textbox.setMinimumHeight(60)
+ self.widgets['config_checkpoints_paths'] = checkpoints_textbox
+ form.addRow("Checkpoints Paths:", checkpoints_textbox)
+ self._create_config_combo(form, "UI Theme (requires restart):", "UI_theme", [("Blue Sky", "default"), ("Classic Gradio", "gradio")], "default")
+ return tab
+
+ def _create_performance_config_tab(self):
+ tab, form = self._create_scrollable_form_tab()
+ self._create_config_combo(form, "Transformer Quantization:", "transformer_quantization", [("Scaled Int8 (recommended)", "int8"), ("16-bit (no quantization)", "bf16")], "int8")
+ self._create_config_combo(form, "Transformer Data Type:", "transformer_dtype_policy", [("Best Supported by Hardware", ""), ("FP16", "fp16"), ("BF16", "bf16")], "")
+ self._create_config_combo(form, "Transformer Calculation:", "mixed_precision", [("16-bit only", "0"), ("Mixed 16/32-bit (better quality)", "1")], "0")
+ self._create_config_combo(form, "Text Encoder:", "text_encoder_quantization", [("16-bit (more RAM, better quality)", "bf16"), ("8-bit (less RAM)", "int8")], "int8")
+ self._create_config_combo(form, "VAE Precision:", "vae_precision", [("16-bit (faster, less VRAM)", "16"), ("32-bit (slower, better quality)", "32")], "16")
+ self._create_config_combo(form, "Compile Transformer:", "compile", [("On (requires Triton)", "transformer"), ("Off", "")], "")
+ self._create_config_combo(form, "DepthAnything v2 Variant:", "depth_anything_v2_variant", [("Large (more precise)", "vitl"), ("Big (faster)", "vitb")], "vitl")
+ self._create_config_combo(form, "VAE Tiling:", "vae_config", [("Auto", 0), ("Disabled", 1), ("256x256 (~8GB VRAM)", 2), ("128x128 (~6GB VRAM)", 3)], 0)
+ self._create_config_combo(form, "Boost:", "boost", [("On", 1), ("Off", 2)], 1)
+ self._create_config_combo(form, "Memory Profile:", "profile", self.wgp.memory_profile_choices, self.wgp.profile_type.LowRAM_LowVRAM)
+ self._create_config_slider(form, "Preload in VRAM (MB):", "preload_in_VRAM", 0, 40000, 0, 100)
+ release_ram_btn = QPushButton("Force Release Models from RAM")
+ release_ram_btn.clicked.connect(self._on_release_ram)
+ form.addRow(release_ram_btn)
+ return tab
+
+ def _create_extensions_config_tab(self):
+ tab, form = self._create_scrollable_form_tab()
+ self._create_config_combo(form, "Prompt Enhancer:", "enhancer_enabled", [("Off", 0), ("Florence 2 + Llama 3.2", 1), ("Florence 2 + Joy Caption (uncensored)", 2)], 0)
+ self._create_config_combo(form, "Enhancer Mode:", "enhancer_mode", [("Automatic on Generate", 0), ("On Demand Only", 1)], 0)
+ self._create_config_combo(form, "MMAudio:", "mmaudio_enabled", [("Off", 0), ("Enabled (unloaded after use)", 1), ("Enabled (persistent in RAM)", 2)], 0)
+ return tab
+
+ def _create_outputs_config_tab(self):
+ tab, form = self._create_scrollable_form_tab()
+ self._create_config_combo(form, "Video Codec:", "video_output_codec", [("x265 Balanced", 'libx265_28'), ("x264 Balanced", 'libx264_8'), ("x265 High Quality", 'libx265_8'), ("x264 High Quality", 'libx264_10'), ("x264 Lossless", 'libx264_lossless')], 'libx264_8')
+ self._create_config_combo(form, "Image Codec:", "image_output_codec", [("JPEG Q85", 'jpeg_85'), ("WEBP Q85", 'webp_85'), ("JPEG Q95", 'jpeg_95'), ("WEBP Q95", 'webp_95'), ("WEBP Lossless", 'webp_lossless'), ("PNG Lossless", 'png')], 'jpeg_95')
+ self._create_config_textbox(form, "Video Output Folder:", "save_path", "outputs")
+ self._create_config_textbox(form, "Image Output Folder:", "image_save_path", "outputs")
+ return tab
+
+ def _create_notifications_config_tab(self):
+ tab, form = self._create_scrollable_form_tab()
+ self._create_config_combo(form, "Notification Sound:", "notification_sound_enabled", [("On", 1), ("Off", 0)], 0)
+ self._create_config_slider(form, "Sound Volume:", "notification_sound_volume", 0, 100, 50, 5)
+ return tab
+
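+    # Builds the same state dict wgp expects from its own UI (current model filename
+    # and type, advanced flag, per-family/type/resolution-group memories, and the
+    # generation queue), then pushes the initial model's defaults into the widgets.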
+ def init_wgp_state(self):
+ wgp = self.wgp
+ initial_model = wgp.server_config.get("last_model_type", wgp.transformer_type)
+ dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types
+ _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False)
+ all_model_ids = [m[1] for m in all_models]
+ if initial_model not in all_model_ids: initial_model = wgp.transformer_type
+ state_dict = {}
+ state_dict["model_filename"] = wgp.get_model_filename(initial_model, wgp.transformer_quantization, wgp.transformer_dtype_policy)
+ state_dict["model_type"] = initial_model
+ state_dict["advanced"] = wgp.advanced
+ state_dict["last_model_per_family"] = wgp.server_config.get("last_model_per_family", {})
+ state_dict["last_model_per_type"] = wgp.server_config.get("last_model_per_type", {})
+ state_dict["last_resolution_per_group"] = wgp.server_config.get("last_resolution_per_group", {})
+ state_dict["gen"] = {"queue": []}
+ self.state = state_dict
+ self.advanced_group.setChecked(wgp.advanced)
+ self.update_model_dropdowns(initial_model)
+ self.refresh_ui_from_model_change(initial_model)
+ self._update_input_visibility()
+
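+    # wgp.generate_dropdown_model_list() returns Gradio-style update objects ("mocks");
+    # they are treated here simply as carriers of .choices, .value and an optional
+    # visibility flag (via .kwargs['visible'] or .visible). The three model combos are
+    # refilled with signals blocked so no model change fires while repopulating.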
+ def update_model_dropdowns(self, current_model_type):
+ wgp = self.wgp
+ family_mock, base_type_mock, choice_mock = wgp.generate_dropdown_model_list(current_model_type)
+ for combo_name, mock in [('model_family', family_mock), ('model_base_type_choice', base_type_mock), ('model_choice', choice_mock)]:
+ combo = self.widgets[combo_name]
+ combo.blockSignals(True)
+ combo.clear()
+ if mock.choices:
+ for display_name, internal_key in mock.choices: combo.addItem(display_name, internal_key)
+ index = combo.findData(mock.value)
+ if index != -1: combo.setCurrentIndex(index)
+
+ is_visible = True
+ if hasattr(mock, 'kwargs') and isinstance(mock.kwargs, dict):
+ is_visible = mock.kwargs.get('visible', True)
+ elif hasattr(mock, 'visible'):
+ is_visible = mock.visible
+ combo.setVisible(is_visible)
+
+ combo.blockSignals(False)
+
+ def refresh_ui_from_model_change(self, model_type):
+ """Update UI controls with default settings when the model is changed."""
+ wgp = self.wgp
+ self.header_info.setText(wgp.generate_header(model_type, wgp.compile, wgp.attention_mode))
+ ui_defaults = wgp.get_default_settings(model_type)
+ wgp.set_model_settings(self.state, model_type, ui_defaults)
+
+ model_def = wgp.get_model_def(model_type)
+ base_model_type = wgp.get_base_model_type(model_type)
+ model_filename = self.state.get('model_filename', '')
+
+ image_outputs = model_def.get("image_outputs", False)
+ vace = wgp.test_vace_module(model_type)
+ t2v = base_model_type in ['t2v', 't2v_2_2']
+ i2v = wgp.test_class_i2v(model_type)
+ fantasy = base_model_type in ["fantasy"]
+ multitalk = model_def.get("multitalk_class", False)
+ any_audio_guidance = fantasy or multitalk
+ sliding_window_enabled = wgp.test_any_sliding_window(model_type)
+ recammaster = base_model_type in ["recam_1.3B"]
+ ltxv = "ltxv" in model_filename
+ diffusion_forcing = "diffusion_forcing" in model_filename
+ any_skip_layer_guidance = model_def.get("skip_layer_guidance", False)
+ any_cfg_zero = model_def.get("cfg_zero", False)
+ any_cfg_star = model_def.get("cfg_star", False)
+ any_apg = model_def.get("adaptive_projected_guidance", False)
+ v2i_switch_supported = model_def.get("v2i_switch_supported", False)
+
+ self._update_generation_mode_visibility(model_def)
+
+ for widget in self.widgets.values():
+ if hasattr(widget, 'blockSignals'): widget.blockSignals(True)
+
+ self.widgets['prompt'].setText(ui_defaults.get("prompt", ""))
+ self.widgets['negative_prompt'].setText(ui_defaults.get("negative_prompt", ""))
+ self.widgets['seed'].setText(str(ui_defaults.get("seed", -1)))
+
+ video_length_val = ui_defaults.get("video_length", 81)
+ self.widgets['video_length'].setValue(video_length_val)
+ self.widgets['video_length_label'].setText(str(video_length_val))
+
+ steps_val = ui_defaults.get("num_inference_steps", 30)
+ self.widgets['num_inference_steps'].setValue(steps_val)
+ self.widgets['num_inference_steps_label'].setText(str(steps_val))
+
+ self.widgets['resolution_group'].blockSignals(True)
+ self.widgets['resolution'].blockSignals(True)
+
+ current_res_choice = ui_defaults.get("resolution")
+ model_resolutions = model_def.get("resolutions", None)
+ self.full_resolution_choices, current_res_choice = wgp.get_resolution_choices(current_res_choice, model_resolutions)
+ available_groups, selected_group_resolutions, selected_group = wgp.group_resolutions(model_def, self.full_resolution_choices, current_res_choice)
+
+ self.widgets['resolution_group'].clear()
+ self.widgets['resolution_group'].addItems(available_groups)
+ group_index = self.widgets['resolution_group'].findText(selected_group)
+ if group_index != -1:
+ self.widgets['resolution_group'].setCurrentIndex(group_index)
+
+ self.widgets['resolution'].clear()
+ for label, value in selected_group_resolutions:
+ self.widgets['resolution'].addItem(label, value)
+ res_index = self.widgets['resolution'].findData(current_res_choice)
+ if res_index != -1:
+ self.widgets['resolution'].setCurrentIndex(res_index)
+
+ self.widgets['resolution_group'].blockSignals(False)
+ self.widgets['resolution'].blockSignals(False)
+
+ for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'image_refs', 'audio_source']:
+ if name in self.widgets: self.widgets[name].clear()
+
+ guidance_layout = self.widgets['guidance_layout']
+ guidance_max = model_def.get("guidance_max_phases", 1)
+ guidance_layout.setRowVisible(self.widgets['guidance_phases_row_index'], guidance_max > 1)
+
+ adv_general_layout = self.widgets['adv_general_layout']
+ adv_general_layout.setRowVisible(self.widgets['flow_shift_row_index'], not image_outputs)
+ adv_general_layout.setRowVisible(self.widgets['audio_guidance_row_index'], any_audio_guidance)
+ adv_general_layout.setRowVisible(self.widgets['repeat_generation_row_index'], not image_outputs)
+ adv_general_layout.setRowVisible(self.widgets['multi_images_gen_type_row_index'], i2v)
+
+ self.widgets['slg_group'].setVisible(any_skip_layer_guidance)
+ quality_form_layout = self.widgets['quality_form_layout']
+ quality_form_layout.setRowVisible(self.widgets['apg_switch_row_index'], any_apg)
+ quality_form_layout.setRowVisible(self.widgets['cfg_star_switch_row_index'], any_cfg_star)
+ quality_form_layout.setRowVisible(self.widgets['cfg_zero_step_row_index'], any_cfg_zero)
+ quality_form_layout.setRowVisible(self.widgets['min_frames_if_references_row_index'], v2i_switch_supported and image_outputs)
+
+ self.widgets['advanced_tabs'].setTabVisible(self.widgets['sliding_window_tab_index'], sliding_window_enabled and not image_outputs)
+
+ misc_layout = self.widgets['misc_layout']
+ misc_layout.setRowVisible(self.widgets['riflex_row_index'], not (recammaster or ltxv or diffusion_forcing))
+
+ index = self.widgets['multi_images_gen_type'].findData(ui_defaults.get('multi_images_gen_type', 0))
+ if index != -1: self.widgets['multi_images_gen_type'].setCurrentIndex(index)
+
+ guidance_val = ui_defaults.get("guidance_scale", 5.0)
+ self.widgets['guidance_scale'].setValue(int(guidance_val * 10))
+ self.widgets['guidance_scale_label'].setText(f"{guidance_val:.1f}")
+
+ guidance2_val = ui_defaults.get("guidance2_scale", 5.0)
+ self.widgets['guidance2_scale'].setValue(int(guidance2_val * 10))
+ self.widgets['guidance2_scale_label'].setText(f"{guidance2_val:.1f}")
+
+ guidance3_val = ui_defaults.get("guidance3_scale", 5.0)
+ self.widgets['guidance3_scale'].setValue(int(guidance3_val * 10))
+ self.widgets['guidance3_scale_label'].setText(f"{guidance3_val:.1f}")
+
+ self.widgets['guidance_phases'].clear()
+ if guidance_max >= 1: self.widgets['guidance_phases'].addItem("One Phase", 1)
+ if guidance_max >= 2: self.widgets['guidance_phases'].addItem("Two Phases", 2)
+ if guidance_max >= 3: self.widgets['guidance_phases'].addItem("Three Phases", 3)
+ index = self.widgets['guidance_phases'].findData(ui_defaults.get("guidance_phases", 1))
+ if index != -1: self.widgets['guidance_phases'].setCurrentIndex(index)
+
+ switch_thresh_val = ui_defaults.get("switch_threshold", 0)
+ self.widgets['switch_threshold'].setValue(switch_thresh_val)
+ self.widgets['switch_threshold_label'].setText(str(switch_thresh_val))
+
+ nag_scale_val = ui_defaults.get('NAG_scale', 1.0)
+ self.widgets['NAG_scale'].setValue(int(nag_scale_val * 10))
+ self.widgets['NAG_scale_label'].setText(f"{nag_scale_val:.1f}")
+
+ nag_tau_val = ui_defaults.get('NAG_tau', 3.5)
+ self.widgets['NAG_tau'].setValue(int(nag_tau_val * 10))
+ self.widgets['NAG_tau_label'].setText(f"{nag_tau_val:.1f}")
+
+ nag_alpha_val = ui_defaults.get('NAG_alpha', 0.5)
+ self.widgets['NAG_alpha'].setValue(int(nag_alpha_val * 10))
+ self.widgets['NAG_alpha_label'].setText(f"{nag_alpha_val:.1f}")
+
+ self.widgets['nag_group'].setVisible(vace or t2v or i2v)
+
+ self.widgets['sample_solver'].clear()
+ sampler_choices = model_def.get("sample_solvers", [])
+ self.widgets['solver_row_container'].setVisible(bool(sampler_choices))
+ if sampler_choices:
+ for label, value in sampler_choices: self.widgets['sample_solver'].addItem(label, value)
+ solver_val = ui_defaults.get('sample_solver', sampler_choices[0][1])
+ index = self.widgets['sample_solver'].findData(solver_val)
+ if index != -1: self.widgets['sample_solver'].setCurrentIndex(index)
+
+ flow_val = ui_defaults.get("flow_shift", 3.0)
+ self.widgets['flow_shift'].setValue(int(flow_val * 10))
+ self.widgets['flow_shift_label'].setText(f"{flow_val:.1f}")
+
+ audio_guidance_val = ui_defaults.get("audio_guidance_scale", 4.0)
+ self.widgets['audio_guidance_scale'].setValue(int(audio_guidance_val * 10))
+ self.widgets['audio_guidance_scale_label'].setText(f"{audio_guidance_val:.1f}")
+
+ repeat_val = ui_defaults.get("repeat_generation", 1)
+ self.widgets['repeat_generation'].setValue(repeat_val)
+ self.widgets['repeat_generation_label'].setText(str(repeat_val))
+
+ available_loras, _, _, _, _, _ = wgp.setup_loras(model_type, None, wgp.get_lora_dir(model_type), "")
+ self.state['loras'] = available_loras
+ self.lora_map = {os.path.basename(p): p for p in available_loras}
+ lora_list_widget = self.widgets['activated_loras']
+ lora_list_widget.clear()
+ lora_list_widget.addItems(sorted(self.lora_map.keys()))
+ selected_loras = ui_defaults.get('activated_loras', [])
+ for i in range(lora_list_widget.count()):
+ item = lora_list_widget.item(i)
+ if any(item.text() == os.path.basename(p) for p in selected_loras): item.setSelected(True)
+ self.widgets['loras_multipliers'].setText(ui_defaults.get('loras_multipliers', ''))
+
+ skip_cache_val = ui_defaults.get('skip_steps_cache_type', "")
+ index = self.widgets['skip_steps_cache_type'].findData(skip_cache_val)
+ if index != -1: self.widgets['skip_steps_cache_type'].setCurrentIndex(index)
+
+ skip_mult = ui_defaults.get('skip_steps_multiplier', 1.5)
+ index = self.widgets['skip_steps_multiplier'].findData(skip_mult)
+ if index != -1: self.widgets['skip_steps_multiplier'].setCurrentIndex(index)
+
+ skip_perc_val = ui_defaults.get('skip_steps_start_step_perc', 0)
+ self.widgets['skip_steps_start_step_perc'].setValue(skip_perc_val)
+ self.widgets['skip_steps_start_step_perc_label'].setText(str(skip_perc_val))
+
+ temp_up_val = ui_defaults.get('temporal_upsampling', "")
+ index = self.widgets['temporal_upsampling'].findData(temp_up_val)
+ if index != -1: self.widgets['temporal_upsampling'].setCurrentIndex(index)
+
+ spat_up_val = ui_defaults.get('spatial_upsampling', "")
+ index = self.widgets['spatial_upsampling'].findData(spat_up_val)
+ if index != -1: self.widgets['spatial_upsampling'].setCurrentIndex(index)
+
+ film_grain_i = ui_defaults.get('film_grain_intensity', 0)
+ self.widgets['film_grain_intensity'].setValue(int(film_grain_i * 100))
+ self.widgets['film_grain_intensity_label'].setText(f"{film_grain_i:.2f}")
+
+ film_grain_s = ui_defaults.get('film_grain_saturation', 0.5)
+ self.widgets['film_grain_saturation'].setValue(int(film_grain_s * 100))
+ self.widgets['film_grain_saturation_label'].setText(f"{film_grain_s:.2f}")
+
+ self.widgets['MMAudio_setting'].setCurrentIndex(ui_defaults.get('MMAudio_setting', 0))
+ self.widgets['MMAudio_prompt'].setText(ui_defaults.get('MMAudio_prompt', ''))
+ self.widgets['MMAudio_neg_prompt'].setText(ui_defaults.get('MMAudio_neg_prompt', ''))
+
+ self.widgets['slg_switch'].setCurrentIndex(ui_defaults.get('slg_switch', 0))
+ slg_start_val = ui_defaults.get('slg_start_perc', 10)
+ self.widgets['slg_start_perc'].setValue(slg_start_val)
+ self.widgets['slg_start_perc_label'].setText(str(slg_start_val))
+ slg_end_val = ui_defaults.get('slg_end_perc', 90)
+ self.widgets['slg_end_perc'].setValue(slg_end_val)
+ self.widgets['slg_end_perc_label'].setText(str(slg_end_val))
+
+ self.widgets['apg_switch'].setCurrentIndex(ui_defaults.get('apg_switch', 0))
+ self.widgets['cfg_star_switch'].setCurrentIndex(ui_defaults.get('cfg_star_switch', 0))
+
+ cfg_zero_val = ui_defaults.get('cfg_zero_step', -1)
+ self.widgets['cfg_zero_step'].setValue(cfg_zero_val)
+ self.widgets['cfg_zero_step_label'].setText(str(cfg_zero_val))
+
+ min_frames_val = ui_defaults.get('min_frames_if_references', 1)
+ index = self.widgets['min_frames_if_references'].findData(min_frames_val)
+ if index != -1: self.widgets['min_frames_if_references'].setCurrentIndex(index)
+
+ self.widgets['RIFLEx_setting'].setCurrentIndex(ui_defaults.get('RIFLEx_setting', 0))
+
+ fps = wgp.get_model_fps(model_type)
+ force_fps_choices = [
+ (f"Model Default ({fps} fps)", ""), ("Auto", "auto"), ("Control Video fps", "control"),
+ ("Source Video fps", "source"), ("15", "15"), ("16", "16"), ("23", "23"),
+ ("24", "24"), ("25", "25"), ("30", "30")
+ ]
+ self.widgets['force_fps'].clear()
+ for label, value in force_fps_choices: self.widgets['force_fps'].addItem(label, value)
+ force_fps_val = ui_defaults.get('force_fps', "")
+ index = self.widgets['force_fps'].findData(force_fps_val)
+ if index != -1: self.widgets['force_fps'].setCurrentIndex(index)
+
+ override_prof_val = ui_defaults.get('override_profile', -1)
+ index = self.widgets['override_profile'].findData(override_prof_val)
+ if index != -1: self.widgets['override_profile'].setCurrentIndex(index)
+
+ self.widgets['multi_prompts_gen_type'].setCurrentIndex(ui_defaults.get('multi_prompts_gen_type', 0))
+
+ denoising_val = ui_defaults.get("denoising_strength", 0.5)
+ self.widgets['denoising_strength'].setValue(int(denoising_val * 100))
+ self.widgets['denoising_strength_label'].setText(f"{denoising_val:.2f}")
+
+ sw_size = ui_defaults.get("sliding_window_size", 129)
+ self.widgets['sliding_window_size'].setValue(sw_size)
+ self.widgets['sliding_window_size_label'].setText(str(sw_size))
+
+ sw_overlap = ui_defaults.get("sliding_window_overlap", 5)
+ self.widgets['sliding_window_overlap'].setValue(sw_overlap)
+ self.widgets['sliding_window_overlap_label'].setText(str(sw_overlap))
+
+ sw_color = ui_defaults.get("sliding_window_color_correction_strength", 0)
+ self.widgets['sliding_window_color_correction_strength'].setValue(int(sw_color * 100))
+ self.widgets['sliding_window_color_correction_strength_label'].setText(f"{sw_color:.2f}")
+
+ sw_noise = ui_defaults.get("sliding_window_overlap_noise", 20)
+ self.widgets['sliding_window_overlap_noise'].setValue(sw_noise)
+ self.widgets['sliding_window_overlap_noise_label'].setText(str(sw_noise))
+
+ sw_discard = ui_defaults.get("sliding_window_discard_last_frames", 0)
+ self.widgets['sliding_window_discard_last_frames'].setValue(sw_discard)
+ self.widgets['sliding_window_discard_last_frames_label'].setText(str(sw_discard))
+
+ for widget in self.widgets.values():
+ if hasattr(widget, 'blockSignals'): widget.blockSignals(False)
+
+ self._update_dynamic_ui()
+ self._update_input_visibility()
+
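+    # Guidance 2/3 and the switch-threshold rows only apply with multiple guidance
+    # phases, so their visibility follows the current 'guidance_phases' selection.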
+ def _update_dynamic_ui(self):
+ phases = self.widgets['guidance_phases'].currentData() or 1
+ guidance_layout = self.widgets['guidance_layout']
+ guidance_layout.setRowVisible(self.widgets['guidance2_row_index'], phases >= 2)
+ guidance_layout.setRowVisible(self.widgets['guidance3_row_index'], phases >= 3)
+ guidance_layout.setRowVisible(self.widgets['switch_thresh_row_index'], phases >= 2)
+
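+    # The model definition's "image_prompt_types_allowed" string is a set of letter
+    # flags (T/S/V/L, plus E for an end image); each flag maps to one of the mode
+    # radio buttons or the end-image checkbox. The control-video and reference-image
+    # checkboxes appear only when the model declares "guide_preprocessing" /
+    # "image_ref_choices".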
+ def _update_generation_mode_visibility(self, model_def):
+ allowed = model_def.get("image_prompt_types_allowed", "")
+ choices = []
+ if "T" in allowed or not allowed: choices.append(("Text Prompt Only" if "S" in allowed else "New Video", "T"))
+ if "S" in allowed: choices.append(("Start Video with Image", "S"))
+ if "V" in allowed: choices.append(("Continue Video", "V"))
+ if "L" in allowed: choices.append(("Continue Last Video", "L"))
+ button_map = { "T": self.widgets['mode_t'], "S": self.widgets['mode_s'], "V": self.widgets['mode_v'], "L": self.widgets['mode_l'] }
+ for btn in button_map.values(): btn.setVisible(False)
+ allowed_values = [c[1] for c in choices]
+ for label, value in choices:
+ if value in button_map:
+ btn = button_map[value]
+ btn.setText(label)
+ btn.setVisible(True)
+ current_checked_value = next((value for value, btn in button_map.items() if btn.isChecked()), None)
+ if current_checked_value is None or not button_map[current_checked_value].isVisible():
+ if allowed_values: button_map[allowed_values[0]].setChecked(True)
+ end_image_visible = "E" in allowed
+ self.widgets['image_end_checkbox'].setVisible(end_image_visible)
+ if not end_image_visible: self.widgets['image_end_checkbox'].setChecked(False)
+ control_video_visible = model_def.get("guide_preprocessing") is not None
+ self.widgets['control_video_checkbox'].setVisible(control_video_visible)
+ if not control_video_visible: self.widgets['control_video_checkbox'].setChecked(False)
+ ref_image_visible = model_def.get("image_ref_choices") is not None
+ self.widgets['ref_image_checkbox'].setVisible(ref_image_visible)
+ if not ref_image_visible: self.widgets['ref_image_checkbox'].setChecked(False)
+
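+    # Shows only the file inputs relevant to the selected mode: start image for "S",
+    # source video for "V", the end image when its checkbox is on, control video plus
+    # mask when the control-video checkbox is on, and reference images when that
+    # checkbox is on.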
+ def _update_input_visibility(self):
+ is_s_mode = self.widgets['mode_s'].isChecked()
+ is_v_mode = self.widgets['mode_v'].isChecked()
+ is_l_mode = self.widgets['mode_l'].isChecked()
+ use_end = self.widgets['image_end_checkbox'].isChecked() and self.widgets['image_end_checkbox'].isVisible()
+ use_control = self.widgets['control_video_checkbox'].isChecked() and self.widgets['control_video_checkbox'].isVisible()
+ use_ref = self.widgets['ref_image_checkbox'].isChecked() and self.widgets['ref_image_checkbox'].isVisible()
+ self.widgets['image_start_container'].setVisible(is_s_mode)
+ self.widgets['video_source_container'].setVisible(is_v_mode)
+ end_checkbox_enabled = is_s_mode or is_v_mode or is_l_mode
+ self.widgets['image_end_checkbox'].setEnabled(end_checkbox_enabled)
+ self.widgets['image_end_container'].setVisible(use_end and end_checkbox_enabled)
+ self.widgets['video_guide_container'].setVisible(use_control)
+ self.widgets['video_mask_container'].setVisible(use_control)
+ self.widgets['image_refs_container'].setVisible(use_ref)
+
+ def connect_signals(self):
+ self.widgets['model_family'].currentIndexChanged.connect(self._on_family_changed)
+ self.widgets['model_base_type_choice'].currentIndexChanged.connect(self._on_base_type_changed)
+ self.widgets['model_choice'].currentIndexChanged.connect(self._on_model_changed)
+ self.widgets['resolution_group'].currentIndexChanged.connect(self._on_resolution_group_changed)
+ self.widgets['guidance_phases'].currentIndexChanged.connect(self._update_dynamic_ui)
+ self.widgets['mode_t'].toggled.connect(self._update_input_visibility)
+ self.widgets['mode_s'].toggled.connect(self._update_input_visibility)
+ self.widgets['mode_v'].toggled.connect(self._update_input_visibility)
+ self.widgets['mode_l'].toggled.connect(self._update_input_visibility)
+ self.widgets['image_end_checkbox'].toggled.connect(self._update_input_visibility)
+ self.widgets['control_video_checkbox'].toggled.connect(self._update_input_visibility)
+ self.widgets['ref_image_checkbox'].toggled.connect(self._update_input_visibility)
+ self.widgets['preview_group'].toggled.connect(self._on_preview_toggled)
+ self.generate_btn.clicked.connect(self._on_generate)
+ self.add_to_queue_btn.clicked.connect(self._on_add_to_queue)
+ self.remove_queue_btn.clicked.connect(self._on_remove_selected_from_queue)
+ self.clear_queue_btn.clicked.connect(self._on_clear_queue)
+ self.abort_btn.clicked.connect(self._on_abort)
+ self.queue_table.rowsMoved.connect(self._on_queue_rows_moved)
+
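+    # Plugin-local UI preferences (currently just the preview panel visibility) are
+    # persisted to main_config.json in the working directory.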
+ def load_main_config(self):
+ try:
+ with open('main_config.json', 'r') as f: self.main_config = json.load(f)
+ except (FileNotFoundError, json.JSONDecodeError): self.main_config = {'preview_visible': False}
+
+ def save_main_config(self):
+ try:
+ with open('main_config.json', 'w') as f: json.dump(self.main_config, f, indent=4)
+ except Exception as e: print(f"Error saving main_config.json: {e}")
+
+ def apply_initial_config(self):
+ is_visible = self.main_config.get('preview_visible', True)
+ self.widgets['preview_group'].setChecked(is_visible)
+ self.widgets['preview_image'].setVisible(is_visible)
+
+ def _on_preview_toggled(self, checked):
+ self.widgets['preview_image'].setVisible(checked)
+ self.main_config['preview_visible'] = checked
+ self.save_main_config()
+
+ def _on_family_changed(self):
+ family = self.widgets['model_family'].currentData()
+ if not family or not self.state: return
+ base_type_mock, choice_mock = self.wgp.change_model_family(self.state, family)
+
+ if hasattr(base_type_mock, 'kwargs') and isinstance(base_type_mock.kwargs, dict):
+ is_visible_base = base_type_mock.kwargs.get('visible', True)
+ elif hasattr(base_type_mock, 'visible'):
+ is_visible_base = base_type_mock.visible
+ else:
+ is_visible_base = True
+
+ self.widgets['model_base_type_choice'].blockSignals(True)
+ self.widgets['model_base_type_choice'].clear()
+ if base_type_mock.choices:
+ for label, value in base_type_mock.choices: self.widgets['model_base_type_choice'].addItem(label, value)
+ self.widgets['model_base_type_choice'].setCurrentIndex(self.widgets['model_base_type_choice'].findData(base_type_mock.value))
+ self.widgets['model_base_type_choice'].setVisible(is_visible_base)
+ self.widgets['model_base_type_choice'].blockSignals(False)
+
+ if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict):
+ is_visible_choice = choice_mock.kwargs.get('visible', True)
+ elif hasattr(choice_mock, 'visible'):
+ is_visible_choice = choice_mock.visible
+ else:
+ is_visible_choice = True
+
+ self.widgets['model_choice'].blockSignals(True)
+ self.widgets['model_choice'].clear()
+ if choice_mock.choices:
+ for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value)
+ self.widgets['model_choice'].setCurrentIndex(self.widgets['model_choice'].findData(choice_mock.value))
+ self.widgets['model_choice'].setVisible(is_visible_choice)
+ self.widgets['model_choice'].blockSignals(False)
+
+ self._on_model_changed()
+
+ def _on_base_type_changed(self):
+ family = self.widgets['model_family'].currentData()
+ base_type = self.widgets['model_base_type_choice'].currentData()
+ if not family or not base_type or not self.state: return
+ base_type_mock, choice_mock = self.wgp.change_model_base_types(self.state, family, base_type)
+
+ if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict):
+ is_visible_choice = choice_mock.kwargs.get('visible', True)
+ elif hasattr(choice_mock, 'visible'):
+ is_visible_choice = choice_mock.visible
+ else:
+ is_visible_choice = True
+
+ self.widgets['model_choice'].blockSignals(True)
+ self.widgets['model_choice'].clear()
+ if choice_mock.choices:
+ for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value)
+ self.widgets['model_choice'].setCurrentIndex(self.widgets['model_choice'].findData(choice_mock.value))
+ self.widgets['model_choice'].setVisible(is_visible_choice)
+ self.widgets['model_choice'].blockSignals(False)
+ self._on_model_changed()
+
+ def _on_model_changed(self):
+ model_type = self.widgets['model_choice'].currentData()
+ if not model_type or model_type == self.state.get('model_type'): return
+ self.wgp.change_model(self.state, model_type)
+ self.refresh_ui_from_model_change(model_type)
+
+ def _on_resolution_group_changed(self):
+ selected_group = self.widgets['resolution_group'].currentText()
+ if not selected_group or not hasattr(self, 'full_resolution_choices'): return
+ model_type = self.state['model_type']
+ model_def = self.wgp.get_model_def(model_type)
+ model_resolutions = model_def.get("resolutions", None)
+ group_resolution_choices = []
+ if model_resolutions is None:
+ group_resolution_choices = [res for res in self.full_resolution_choices if self.wgp.categorize_resolution(res[1]) == selected_group]
+ else: return
+ last_resolution = self.state.get("last_resolution_per_group", {}).get(selected_group, "")
+ if not any(last_resolution == res[1] for res in group_resolution_choices) and group_resolution_choices:
+ last_resolution = group_resolution_choices[0][1]
+ self.widgets['resolution'].blockSignals(True)
+ self.widgets['resolution'].clear()
+ for label, value in group_resolution_choices: self.widgets['resolution'].addItem(label, value)
+ self.widgets['resolution'].setCurrentIndex(self.widgets['resolution'].findData(last_resolution))
+ self.widgets['resolution'].blockSignals(False)
+
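+    # Starts from wgp's current settings for the selected model, back-fills keys the
+    # generator expects but the Qt UI does not expose, then overwrites everything the
+    # widgets do control. Advanced values are only collected while the Advanced
+    # Options group is checked; otherwise the model defaults are kept.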
+ def collect_inputs(self):
+ full_inputs = self.wgp.get_current_model_settings(self.state).copy()
+ full_inputs['lset_name'] = ""
+ full_inputs['image_mode'] = 0
+        expected_keys = {
+            "audio_guide": None, "audio_guide2": None, "image_guide": None, "image_mask": None,
+            "speakers_locations": "", "frames_positions": "", "keep_frames_video_guide": "",
+            "keep_frames_video_source": "", "video_guide_outpainting": "",
+            "switch_threshold2": 0, "model_switch_phase": 1, "batch_size": 1,
+            "control_net_weight_alt": 1.0, "image_refs_relative_size": 50,
+        }
+ for key, default_value in expected_keys.items():
+ if key not in full_inputs: full_inputs[key] = default_value
+ full_inputs['prompt'] = self.widgets['prompt'].toPlainText()
+ full_inputs['negative_prompt'] = self.widgets['negative_prompt'].toPlainText()
+ full_inputs['resolution'] = self.widgets['resolution'].currentData()
+ full_inputs['video_length'] = self.widgets['video_length'].value()
+ full_inputs['num_inference_steps'] = self.widgets['num_inference_steps'].value()
+        try:
+            full_inputs['seed'] = int(self.widgets['seed'].text())
+        except ValueError:
+            full_inputs['seed'] = -1  # fall back to the default seed if the field is not a valid integer
+ image_prompt_type = ""
+ video_prompt_type = ""
+ if self.widgets['mode_s'].isChecked(): image_prompt_type = 'S'
+ elif self.widgets['mode_v'].isChecked(): image_prompt_type = 'V'
+ elif self.widgets['mode_l'].isChecked(): image_prompt_type = 'L'
+ if self.widgets['image_end_checkbox'].isVisible() and self.widgets['image_end_checkbox'].isChecked(): image_prompt_type += 'E'
+ if self.widgets['control_video_checkbox'].isVisible() and self.widgets['control_video_checkbox'].isChecked(): video_prompt_type += 'V'
+ if self.widgets['ref_image_checkbox'].isVisible() and self.widgets['ref_image_checkbox'].isChecked(): video_prompt_type += 'I'
+ full_inputs['image_prompt_type'] = image_prompt_type
+ full_inputs['video_prompt_type'] = video_prompt_type
+ for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'audio_source']:
+ if name in self.widgets: full_inputs[name] = self.widgets[name].text() or None
+ paths = self.widgets['image_refs'].text().split(';')
+ full_inputs['image_refs'] = [p.strip() for p in paths if p.strip()] if paths and paths[0] else None
+ full_inputs['denoising_strength'] = self.widgets['denoising_strength'].value() / 100.0
+ if self.advanced_group.isChecked():
+ full_inputs['guidance_scale'] = self.widgets['guidance_scale'].value() / 10.0
+ full_inputs['guidance_phases'] = self.widgets['guidance_phases'].currentData()
+ full_inputs['guidance2_scale'] = self.widgets['guidance2_scale'].value() / 10.0
+ full_inputs['guidance3_scale'] = self.widgets['guidance3_scale'].value() / 10.0
+ full_inputs['switch_threshold'] = self.widgets['switch_threshold'].value()
+ full_inputs['NAG_scale'] = self.widgets['NAG_scale'].value() / 10.0
+ full_inputs['NAG_tau'] = self.widgets['NAG_tau'].value() / 10.0
+ full_inputs['NAG_alpha'] = self.widgets['NAG_alpha'].value() / 10.0
+ full_inputs['sample_solver'] = self.widgets['sample_solver'].currentData()
+ full_inputs['flow_shift'] = self.widgets['flow_shift'].value() / 10.0
+ full_inputs['audio_guidance_scale'] = self.widgets['audio_guidance_scale'].value() / 10.0
+ full_inputs['repeat_generation'] = self.widgets['repeat_generation'].value()
+ full_inputs['multi_images_gen_type'] = self.widgets['multi_images_gen_type'].currentData()
+ selected_items = self.widgets['activated_loras'].selectedItems()
+ full_inputs['activated_loras'] = [self.lora_map[item.text()] for item in selected_items if item.text() in self.lora_map]
+ full_inputs['loras_multipliers'] = self.widgets['loras_multipliers'].toPlainText()
+ full_inputs['skip_steps_cache_type'] = self.widgets['skip_steps_cache_type'].currentData()
+ full_inputs['skip_steps_multiplier'] = self.widgets['skip_steps_multiplier'].currentData()
+ full_inputs['skip_steps_start_step_perc'] = self.widgets['skip_steps_start_step_perc'].value()
+ full_inputs['temporal_upsampling'] = self.widgets['temporal_upsampling'].currentData()
+ full_inputs['spatial_upsampling'] = self.widgets['spatial_upsampling'].currentData()
+ full_inputs['film_grain_intensity'] = self.widgets['film_grain_intensity'].value() / 100.0
+ full_inputs['film_grain_saturation'] = self.widgets['film_grain_saturation'].value() / 100.0
+ full_inputs['MMAudio_setting'] = self.widgets['MMAudio_setting'].currentData()
+ full_inputs['MMAudio_prompt'] = self.widgets['MMAudio_prompt'].text()
+ full_inputs['MMAudio_neg_prompt'] = self.widgets['MMAudio_neg_prompt'].text()
+ full_inputs['RIFLEx_setting'] = self.widgets['RIFLEx_setting'].currentData()
+ full_inputs['force_fps'] = self.widgets['force_fps'].currentData()
+ full_inputs['override_profile'] = self.widgets['override_profile'].currentData()
+ full_inputs['multi_prompts_gen_type'] = self.widgets['multi_prompts_gen_type'].currentData()
+ full_inputs['slg_switch'] = self.widgets['slg_switch'].currentData()
+ full_inputs['slg_start_perc'] = self.widgets['slg_start_perc'].value()
+ full_inputs['slg_end_perc'] = self.widgets['slg_end_perc'].value()
+ full_inputs['apg_switch'] = self.widgets['apg_switch'].currentData()
+ full_inputs['cfg_star_switch'] = self.widgets['cfg_star_switch'].currentData()
+ full_inputs['cfg_zero_step'] = self.widgets['cfg_zero_step'].value()
+ full_inputs['min_frames_if_references'] = self.widgets['min_frames_if_references'].currentData()
+ full_inputs['sliding_window_size'] = self.widgets['sliding_window_size'].value()
+ full_inputs['sliding_window_overlap'] = self.widgets['sliding_window_overlap'].value()
+ full_inputs['sliding_window_color_correction_strength'] = self.widgets['sliding_window_color_correction_strength'].value() / 100.0
+ full_inputs['sliding_window_overlap_noise'] = self.widgets['sliding_window_overlap_noise'].value()
+ full_inputs['sliding_window_discard_last_frames'] = self.widgets['sliding_window_discard_last_frames'].value()
+ return full_inputs
+
+ def _prepare_state_for_generation(self):
+ if 'gen' in self.state:
+ self.state['gen'].pop('abort', None)
+ self.state['gen'].pop('in_progress', None)
+
+ def _on_generate(self):
+ try:
+ is_running = self.thread and self.thread.isRunning()
+ self._add_task_to_queue()
+ if not is_running: self.start_generation()
+ except Exception as e:
+ import traceback; traceback.print_exc()
+
+ def _on_add_to_queue(self):
+ try:
+ self._add_task_to_queue()
+ except Exception as e:
+ import traceback; traceback.print_exc()
+
+ def _add_task_to_queue(self):
+ all_inputs = self.collect_inputs()
+ for key in ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality', 'base_model_type']: all_inputs.pop(key, None)
+ all_inputs['state'] = self.state
+ self.wgp.set_model_settings(self.state, self.state['model_type'], all_inputs)
+ self.state["validate_success"] = 1
+ self.wgp.process_prompt_and_add_tasks(self.state, self.state['model_type'])
+ self.update_queue_table()
+
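+ # Generation runs on a background QThread using the usual Qt worker
+ # pattern: the Worker is moved onto the thread, its signals update the UI
+ # (status/progress/preview/output/error), and both objects are released
+ # via deleteLater() when the thread finishes.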
+ def start_generation(self):
+ if not self.state['gen']['queue']: return
+ self._prepare_state_for_generation()
+ self.generate_btn.setEnabled(False)
+ self.add_to_queue_btn.setEnabled(True)
+ self.thread = QThread()
+ self.worker = Worker(self.plugin, self.state)
+ self.worker.moveToThread(self.thread)
+ self.thread.started.connect(self.worker.run)
+ self.worker.finished.connect(self.thread.quit)
+ self.worker.finished.connect(self.worker.deleteLater)
+ self.thread.finished.connect(self.thread.deleteLater)
+ self.thread.finished.connect(self.on_generation_finished)
+ self.worker.status.connect(self.status_label.setText)
+ self.worker.progress.connect(self.update_progress)
+ self.worker.preview.connect(self.update_preview)
+ self.worker.output.connect(self.update_queue_and_results)
+ self.worker.error.connect(self.on_generation_error)
+ self.thread.start()
+ self.update_queue_table()
+
+ def on_generation_finished(self):
+ time.sleep(0.1)
+ self.status_label.setText("Finished.")
+ self.progress_bar.setValue(0)
+ self.generate_btn.setEnabled(True)
+ self.add_to_queue_btn.setEnabled(False)
+ self.thread = None; self.worker = None
+ self.update_queue_table()
+
+ def on_generation_error(self, err_msg):
+ QMessageBox.critical(self, "Generation Error", str(err_msg))
+ self.on_generation_finished()
+
+ def update_progress(self, data):
+ if len(data) > 1 and isinstance(data[0], tuple):
+ step, total = data[0]
+ self.progress_bar.setMaximum(total)
+ self.progress_bar.setValue(step)
+ self.status_label.setText(str(data[1]))
+ if step <= 1: self.update_queue_table()
+ elif len(data) > 1: self.status_label.setText(str(data[1]))
+
+ def update_preview(self, pil_image):
+ if pil_image and self.widgets['preview_group'].isChecked():
+ q_image = ImageQt(pil_image)
+ pixmap = QPixmap.fromImage(q_image)
+ self.preview_image.setPixmap(pixmap.scaled(self.preview_image.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation))
+
+ def update_queue_and_results(self):
+ self.update_queue_table()
+ file_list = self.state.get('gen', {}).get('file_list', [])
+ for file_path in file_list:
+ if file_path not in self.processed_files:
+ self.add_result_item(file_path)
+ self.processed_files.add(file_path)
+
+ def add_result_item(self, video_path):
+ item_widget = VideoResultItemWidget(video_path, self.plugin)
+ list_item = QListWidgetItem(self.results_list)
+ list_item.setSizeHint(item_widget.sizeHint())
+ self.results_list.addItem(list_item)
+ self.results_list.setItemWidget(list_item, item_widget)
+
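+ # Queue display convention: the first queue entry is the task currently
+ # being processed, so a None placeholder is prepended while idle and the
+ # remove/reorder handlers below apply an offset of 1 while a worker runs.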
+ def update_queue_table(self):
+ with self.wgp.lock:
+ queue = self.state.get('gen', {}).get('queue', [])
+ is_running = self.thread and self.thread.isRunning()
+ queue_to_display = queue if is_running else [None] + queue
+ table_data = self.wgp.get_queue_table(queue_to_display)
+ self.queue_table.setRowCount(0)
+ self.queue_table.setRowCount(len(table_data))
+ self.queue_table.setColumnCount(4)
+ self.queue_table.setHorizontalHeaderLabels(["Qty", "Prompt", "Length", "Steps"])
+ for row_idx, row_data in enumerate(table_data):
+ prompt_text = str(row_data[1]).split('>')[1].split('<')[0] if '>' in str(row_data[1]) else str(row_data[1])
+ for col_idx, cell_data in enumerate([row_data[0], prompt_text, row_data[2], row_data[3]]):
+ self.queue_table.setItem(row_idx, col_idx, QTableWidgetItem(str(cell_data)))
+ self.queue_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
+ self.queue_table.resizeColumnsToContents()
+
+ def _on_remove_selected_from_queue(self):
+ selected_row = self.queue_table.currentRow()
+ if selected_row < 0: return
+ with self.wgp.lock:
+ is_running = self.thread and self.thread.isRunning()
+ offset = 1 if is_running else 0
+ queue = self.state.get('gen', {}).get('queue', [])
+ if len(queue) > selected_row + offset: queue.pop(selected_row + offset)
+ self.update_queue_table()
+
+ def _on_queue_rows_moved(self, source_row, dest_row):
+ with self.wgp.lock:
+ queue = self.state.get('gen', {}).get('queue', [])
+ is_running = self.thread and self.thread.isRunning()
+ offset = 1 if is_running else 0
+ real_source_idx = source_row + offset
+ real_dest_idx = dest_row + offset
+ moved_item = queue.pop(real_source_idx)
+ queue.insert(real_dest_idx, moved_item)
+ self.update_queue_table()
+
+ def _on_clear_queue(self):
+ self.wgp.clear_queue_action(self.state)
+ self.update_queue_table()
+
+ def _on_abort(self):
+ if self.worker:
+ self.wgp.abort_generation(self.state)
+ self.status_label.setText("Aborting...")
+ self.worker._is_running = False
+
+ def _on_release_ram(self):
+ self.wgp.release_RAM()
+ QMessageBox.information(self, "RAM Released", "Models stored in RAM have been released.")
+
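+ # Collects every config_* widget into *_choice keyword arguments and
+ # forwards them to wgp.apply_changes(), refreshing the header and the
+ # model dropdowns when the applied changes require it.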
+ def _on_apply_config_changes(self):
+ changes = {}
+ list_widget = self.widgets['config_transformer_types']
+ changes['transformer_types_choices'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked]
+ list_widget = self.widgets['config_preload_model_policy']
+ changes['preload_model_policy_choice'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked]
+ changes['model_hierarchy_type_choice'] = self.widgets['config_model_hierarchy_type'].currentData()
+ changes['checkpoints_paths'] = self.widgets['config_checkpoints_paths'].toPlainText()
+ for key in ["fit_canvas", "attention_mode", "metadata_type", "clear_file_list", "display_stats", "max_frames_multiplier", "UI_theme"]:
+ changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData()
+ for key in ["transformer_quantization", "transformer_dtype_policy", "mixed_precision", "text_encoder_quantization", "vae_precision", "compile", "depth_anything_v2_variant", "vae_config", "boost", "profile"]:
+ changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData()
+ changes['preload_in_VRAM_choice'] = self.widgets['config_preload_in_VRAM'].value()
+ for key in ["enhancer_enabled", "enhancer_mode", "mmaudio_enabled"]:
+ changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData()
+ for key in ["video_output_codec", "image_output_codec", "save_path", "image_save_path"]:
+ widget = self.widgets[f'config_{key}']
+ changes[f'{key}_choice'] = widget.currentData() if isinstance(widget, QComboBox) else widget.text()
+ changes['notification_sound_enabled_choice'] = self.widgets['config_notification_sound_enabled'].currentData()
+ changes['notification_sound_volume_choice'] = self.widgets['config_notification_sound_volume'].value()
+ changes['last_resolution_choice'] = self.widgets['resolution'].currentData()
+ try:
+ msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = self.wgp.apply_changes(self.state, **changes)
+ self.config_status_label.setText("Changes applied successfully. Some settings may require a restart.")
+ self.header_info.setText(self.wgp.generate_header(self.state['model_type'], self.wgp.compile, self.wgp.attention_mode))
+ if family_mock.choices is not None or choice_mock.choices is not None:
+ self.update_model_dropdowns(self.wgp.transformer_type)
+ self.refresh_ui_from_model_change(self.wgp.transformer_type)
+ except Exception as e:
+ self.config_status_label.setText(f"Error applying changes: {e}")
+ import traceback; traceback.print_exc()
+
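+# Video editor integration: exposes the generator widget as a dock and adds
+# "Join Frames With AI" / "Create Frames With AI" actions to the timeline
+# context menu.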
+class Plugin(VideoEditorPlugin):
+ def initialize(self):
+ self.name = "AI Generator"
+ self.description = "Uses the integrated Wan2GP library to generate video clips."
+ self.client_widget = WgpDesktopPluginWidget(self)
+ self.dock_widget = None
+ self.active_region = None; self.temp_dir = None
+ self.insert_on_new_track = False; self.start_frame_path = None; self.end_frame_path = None
+
+ def enable(self):
+ if not self.dock_widget: self.dock_widget = self.app.add_dock_widget(self, self.client_widget, self.name)
+ self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu)
+ self.app.status_label.setText(f"{self.name}: Enabled.")
+
+ def disable(self):
+ try: self.app.timeline_widget.context_menu_requested.disconnect(self.on_timeline_context_menu)
+ except TypeError: pass
+ self._cleanup_temp_dir()
+ if self.client_widget.worker: self.client_widget._on_abort()
+ self.app.status_label.setText(f"{self.name}: Disabled.")
+
+ def _cleanup_temp_dir(self):
+ if self.temp_dir and os.path.exists(self.temp_dir):
+ shutil.rmtree(self.temp_dir)
+ self.temp_dir = None
+
+ def _reset_state(self):
+ self.active_region = None; self.insert_on_new_track = False
+ self.start_frame_path = None; self.end_frame_path = None
+ self.client_widget.processed_files.clear()
+ self.client_widget.results_list.clear()
+ self.client_widget.widgets['image_start'].clear()
+ self.client_widget.widgets['image_end'].clear()
+ self.client_widget.widgets['video_source'].clear()
+ self._cleanup_temp_dir()
+ self.app.status_label.setText(f"{self.name}: Ready.")
+
+ def on_timeline_context_menu(self, menu, event):
+ region = self.app.timeline_widget.get_region_at_pos(event.pos())
+ if region:
+ menu.addSeparator()
+ start_sec, end_sec = region
+ start_data, _, _ = self.app.get_frame_data_at_time(start_sec)
+ end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
+ if start_data and end_data:
+ join_action = menu.addAction("Join Frames With AI")
+ join_action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False))
+ join_action_new_track = menu.addAction("Join Frames With AI (New Track)")
+ join_action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True))
+ create_action = menu.addAction("Create Frames With AI")
+ create_action.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=False))
+ create_action_new_track = menu.addAction("Create Frames With AI (New Track)")
+ create_action_new_track.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=True))
+
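+ # Joining frames: switch to the default image-to-video model, save the
+ # frames at the region boundaries as temporary PNGs, and preconfigure the
+ # generator (start/end images, video length derived from the region
+ # duration and the model's fps).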
+ def setup_generator_for_region(self, region, on_new_track=False):
+ self._reset_state()
+ self.active_region = region
+ self.insert_on_new_track = on_new_track
+
+ model_to_set = 'i2v_2_2'
+ dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types
+ _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False)
+ if any(model_to_set == m[1] for m in all_models):
+ if self.client_widget.state.get('model_type') != model_to_set:
+ self.client_widget.update_model_dropdowns(model_to_set)
+ self.client_widget._on_model_changed()
+ else:
+ print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.")
+
+ start_sec, end_sec = region
+ start_data, w, h = self.app.get_frame_data_at_time(start_sec)
+ end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
+ if not start_data or not end_data:
+ QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames.")
+ return
+ try:
+ self.temp_dir = tempfile.mkdtemp(prefix="wgp_plugin_")
+ self.start_frame_path = os.path.join(self.temp_dir, "start_frame.png")
+ self.end_frame_path = os.path.join(self.temp_dir, "end_frame.png")
+ QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path)
+ QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path)
+
+ duration_sec = end_sec - start_sec
+ wgp = self.client_widget.wgp
+ model_type = self.client_widget.state['model_type']
+ fps = wgp.get_model_fps(model_type)
+ video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
+
+ self.client_widget.widgets['video_length'].setValue(video_length_frames)
+ self.client_widget.widgets['mode_s'].setChecked(True)
+ self.client_widget.widgets['image_end_checkbox'].setChecked(True)
+ self.client_widget.widgets['image_start'].setText(self.start_frame_path)
+ self.client_widget.widgets['image_end'].setText(self.end_frame_path)
+
+ self.client_widget._update_input_visibility()
+
+ except Exception as e:
+ QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame images: {e}")
+ self._cleanup_temp_dir()
+ return
+ self.app.status_label.setText(f"Ready to join frames from {start_sec:.2f}s to {end_sec:.2f}s.")
+ self.dock_widget.show()
+ self.dock_widget.raise_()
+
+ def setup_creator_for_region(self, region, on_new_track=False):
+ self._reset_state()
+ self.active_region = region
+ self.insert_on_new_track = on_new_track
+
+ model_to_set = 't2v_2_2'
+ dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types
+ _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False)
+ if any(model_to_set == m[1] for m in all_models):
+ if self.client_widget.state.get('model_type') != model_to_set:
+ self.client_widget.update_model_dropdowns(model_to_set)
+ self.client_widget._on_model_changed()
+ else:
+ print(f"Warning: Default model '{model_to_set}' not found for AI Creator. Using current model.")
+
+ start_sec, end_sec = region
+ duration_sec = end_sec - start_sec
+ wgp = self.client_widget.wgp
+ model_type = self.client_widget.state['model_type']
+ fps = wgp.get_model_fps(model_type)
+ video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
+
+ self.client_widget.widgets['video_length'].setValue(video_length_frames)
+ self.client_widget.widgets['mode_t'].setChecked(True)
+
+ self.app.status_label.setText(f"Ready to create video from {start_sec:.2f}s to {end_sec:.2f}s.")
+ self.dock_widget.show()
+ self.dock_widget.raise_()
+
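+ # Insert the generated clip either on a new track, or by splitting the
+ # existing clips at the region boundaries, removing whatever falls inside
+ # the region and placing the new clip there; the video and audio parts
+ # share a group_id so they stay linked.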
+ def insert_generated_clip(self, video_path):
+ from videoeditor import TimelineClip
+ if not self.active_region:
+ self.app.status_label.setText("Error: No active region to insert into."); return
+ if not os.path.exists(video_path):
+ self.app.status_label.setText(f"Error: Output file not found: {video_path}"); return
+ start_sec, end_sec = self.active_region
+ def complex_insertion_action():
+ self.app._add_media_files_to_project([video_path])
+ media_info = self.app.media_properties.get(video_path)
+ if not media_info: raise ValueError("Could not probe inserted clip.")
+ actual_duration, has_audio = media_info['duration'], media_info['has_audio']
+ if self.insert_on_new_track:
+ self.app.timeline.num_video_tracks += 1
+ video_track_index = self.app.timeline.num_video_tracks
+ audio_track_index = self.app.timeline.num_audio_tracks + 1 if has_audio else None
+ else:
+ for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec)
+ for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec)
+ clips_to_remove = [c for c in self.app.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec]
+ for clip in clips_to_remove:
+ if clip in self.app.timeline.clips: self.app.timeline.clips.remove(clip)
+ video_track_index, audio_track_index = 1, 1 if has_audio else None
+ group_id = str(uuid.uuid4())
+ new_clip = TimelineClip(video_path, start_sec, 0, actual_duration, video_track_index, 'video', 'video', group_id)
+ self.app.timeline.add_clip(new_clip)
+ if audio_track_index:
+ if audio_track_index > self.app.timeline.num_audio_tracks: self.app.timeline.num_audio_tracks = audio_track_index
+ audio_clip = TimelineClip(video_path, start_sec, 0, actual_duration, audio_track_index, 'audio', 'video', group_id)
+ self.app.timeline.add_clip(audio_clip)
+ try:
+ self.app._perform_complex_timeline_change("Insert AI Clip", complex_insertion_action)
+ self.app.prune_empty_tracks()
+ self.app.status_label.setText("AI clip inserted successfully.")
+ for i in range(self.client_widget.results_list.count()):
+ widget = self.client_widget.results_list.itemWidget(self.client_widget.results_list.item(i))
+ if widget and widget.video_path == video_path:
+ self.client_widget.results_list.takeItem(i); break
+ except Exception as e:
+ import traceback; traceback.print_exc()
+ self.app.status_label.setText(f"Error during clip insertion: {e}")
\ No newline at end of file
From 556280398351454626f953169970dff560131556 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sat, 11 Oct 2025 23:39:55 +1100
Subject: [PATCH 125/155] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index eff58e558..81c93be78 100644
--- a/README.md
+++ b/README.md
@@ -58,7 +58,7 @@ python videoeditor.py
```
## Screenshots
-
+
From 19a5cd85583d5f38edc2dfa4437818d0407c4d8c Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 10:08:29 +1100
Subject: [PATCH 126/155] allow selection region resizing; show all-region context
 menu options only for multiple regions and single-region options only for one
---
videoeditor.py | 105 +++++++++++++++++++++++++++++++++++++++++--------
1 file changed, 88 insertions(+), 17 deletions(-)
diff --git a/videoeditor.py b/videoeditor.py
index 1c7466957..4a124b785 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -134,6 +134,10 @@ def __init__(self, timeline_model, settings, project_fps, parent=None):
self.resize_edge = None
self.resize_start_pos = QPoint()
+ self.resizing_selection_region = None
+ self.resize_selection_edge = None
+ self.resize_selection_start_values = None
+
self.highlighted_track_info = None
self.highlighted_ghost_track_info = None
self.add_video_track_btn_rect = QRect()
@@ -522,6 +526,29 @@ def mousePressEvent(self, event: QMouseEvent):
self.resizing_clip = None
self.resize_edge = None
self.drag_original_clip_states.clear()
+ self.resizing_selection_region = None
+ self.resize_selection_edge = None
+ self.resize_selection_start_values = None
+
+ for region in self.selection_regions:
+ if not region: continue
+ x_start = self.sec_to_x(region[0])
+ x_end = self.sec_to_x(region[1])
+ if event.pos().y() > self.TIMESCALE_HEIGHT:
+ if abs(event.pos().x() - x_start) < self.RESIZE_HANDLE_WIDTH:
+ self.resizing_selection_region = region
+ self.resize_selection_edge = 'left'
+ break
+ elif abs(event.pos().x() - x_end) < self.RESIZE_HANDLE_WIDTH:
+ self.resizing_selection_region = region
+ self.resize_selection_edge = 'right'
+ break
+
+ if self.resizing_selection_region:
+ self.resize_selection_start_values = tuple(self.resizing_selection_region)
+ self.drag_start_pos = event.pos()
+ self.update()
+ return
for clip in reversed(self.timeline.clips):
clip_rect = self.get_clip_rect(clip)
@@ -595,6 +622,27 @@ def mousePressEvent(self, event: QMouseEvent):
self.update()
def mouseMoveEvent(self, event: QMouseEvent):
+ if self.resizing_selection_region:
+ current_sec = max(0, self.x_to_sec(event.pos().x()))
+ original_start, original_end = self.resize_selection_start_values
+
+ if self.resize_selection_edge == 'left':
+ new_start = current_sec
+ new_end = original_end
+ else: # right
+ new_start = original_start
+ new_end = current_sec
+
+ self.resizing_selection_region[0] = min(new_start, new_end)
+ self.resizing_selection_region[1] = max(new_start, new_end)
+
+ if (self.resize_selection_edge == 'left' and new_start > new_end) or \
+ (self.resize_selection_edge == 'right' and new_end < new_start):
+ self.resize_selection_edge = 'right' if self.resize_selection_edge == 'left' else 'left'
+ self.resize_selection_start_values = (original_end, original_start)
+
+ self.update()
+ return
if self.resizing_clip:
linked_clip = next((c for c in self.timeline.clips if c.group_id == self.resizing_clip.group_id and c.id != self.resizing_clip.id), None)
delta_x = event.pos().x() - self.resize_start_pos.x()
@@ -686,6 +734,15 @@ def mouseMoveEvent(self, event: QMouseEvent):
self.setCursor(Qt.CursorShape.SizeHorCursor)
cursor_set = True
+ if not cursor_set and is_in_track_area:
+ for region in self.selection_regions:
+ x_start = self.sec_to_x(region[0])
+ x_end = self.sec_to_x(region[1])
+ if abs(event.pos().x() - x_start) < self.RESIZE_HANDLE_WIDTH or \
+ abs(event.pos().x() - x_end) < self.RESIZE_HANDLE_WIDTH:
+ self.setCursor(Qt.CursorShape.SizeHorCursor)
+ cursor_set = True
+ break
if not cursor_set:
for clip in self.timeline.clips:
clip_rect = self.get_clip_rect(clip)
@@ -786,6 +843,12 @@ def mouseMoveEvent(self, event: QMouseEvent):
def mouseReleaseEvent(self, event: QMouseEvent):
if event.button() == Qt.MouseButton.LeftButton:
+ if self.resizing_selection_region:
+ self.resizing_selection_region = None
+ self.resize_selection_edge = None
+ self.resize_selection_start_values = None
+ self.update()
+ return
if self.resizing_clip:
new_state = self.window()._get_current_timeline_state()
command = TimelineStateChangeCommand("Resize Clip", self.timeline, *self.drag_start_state, *new_state)
@@ -1074,23 +1137,31 @@ def contextMenuEvent(self, event: 'QContextMenuEvent'):
region_at_pos = self.get_region_at_pos(event.pos())
if region_at_pos:
- split_this_action = menu.addAction("Split This Region")
- split_all_action = menu.addAction("Split All Regions")
- join_this_action = menu.addAction("Join This Region")
- join_all_action = menu.addAction("Join All Regions")
- delete_this_action = menu.addAction("Delete This Region")
- delete_all_action = menu.addAction("Delete All Regions")
- menu.addSeparator()
- clear_this_action = menu.addAction("Clear This Region")
- clear_all_action = menu.addAction("Clear All Regions")
- split_this_action.triggered.connect(lambda: self.split_region_requested.emit(region_at_pos))
- split_all_action.triggered.connect(lambda: self.split_all_regions_requested.emit(self.selection_regions))
- join_this_action.triggered.connect(lambda: self.join_region_requested.emit(region_at_pos))
- join_all_action.triggered.connect(lambda: self.join_all_regions_requested.emit(self.selection_regions))
- delete_this_action.triggered.connect(lambda: self.delete_region_requested.emit(region_at_pos))
- delete_all_action.triggered.connect(lambda: self.delete_all_regions_requested.emit(self.selection_regions))
- clear_this_action.triggered.connect(lambda: self.clear_region(region_at_pos))
- clear_all_action.triggered.connect(self.clear_all_regions)
+ num_regions = len(self.selection_regions)
+
+ if num_regions == 1:
+ split_this_action = menu.addAction("Split Region")
+ join_this_action = menu.addAction("Join Region")
+ delete_this_action = menu.addAction("Delete Region")
+ menu.addSeparator()
+ clear_this_action = menu.addAction("Clear Region")
+
+ split_this_action.triggered.connect(lambda: self.split_region_requested.emit(region_at_pos))
+ join_this_action.triggered.connect(lambda: self.join_region_requested.emit(region_at_pos))
+ delete_this_action.triggered.connect(lambda: self.delete_region_requested.emit(region_at_pos))
+ clear_this_action.triggered.connect(lambda: self.clear_region(region_at_pos))
+
+ elif num_regions > 1:
+ split_all_action = menu.addAction("Split All Regions")
+ join_all_action = menu.addAction("Join All Regions")
+ delete_all_action = menu.addAction("Delete All Regions")
+ menu.addSeparator()
+ clear_all_action = menu.addAction("Clear All Regions")
+
+ split_all_action.triggered.connect(lambda: self.split_all_regions_requested.emit(self.selection_regions))
+ join_all_action.triggered.connect(lambda: self.join_all_regions_requested.emit(self.selection_regions))
+ delete_all_action.triggered.connect(lambda: self.delete_all_regions_requested.emit(self.selection_regions))
+ clear_all_action.triggered.connect(self.clear_all_regions)
clip_at_pos = None
for clip in self.timeline.clips:
From e45796c6b977176293cbf44178d0c517bab74863 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 13:32:45 +1100
Subject: [PATCH 127/155] add preview video scaling
---
videoeditor.py | 186 ++++++++++++++++++++++++++++++++++---------------
1 file changed, 131 insertions(+), 55 deletions(-)
diff --git a/videoeditor.py b/videoeditor.py
index 4a124b785..15e5f0d7d 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -1363,9 +1363,15 @@ def __init__(self, project_to_load=None):
self.plugin_manager = PluginManager(self, wgp)
self.plugin_manager.discover_and_load_plugins()
+ # Project settings with defaults
self.project_fps = 25.0
self.project_width = 1280
self.project_height = 720
+
+ # Preview settings
+ self.scale_to_fit = True
+ self.current_preview_pixmap = None
+
self.playback_timer = QTimer(self)
self.playback_process = None
self.playback_clip = None
@@ -1380,6 +1386,10 @@ def __init__(self, project_to_load=None):
if not self.settings_file_was_loaded: self._save_settings()
if project_to_load: QTimer.singleShot(100, lambda: self._load_project_from_path(project_to_load))
+ def resizeEvent(self, event):
+ super().resizeEvent(event)
+ self._update_preview_display()
+
def _get_current_timeline_state(self):
return (
copy.deepcopy(self.timeline.clips),
@@ -1395,12 +1405,18 @@ def _setup_ui(self):
self.splitter = QSplitter(Qt.Orientation.Vertical)
+ # --- PREVIEW WIDGET SETUP (CHANGED) ---
+ self.preview_scroll_area = QScrollArea()
+ self.preview_scroll_area.setWidgetResizable(False) # Important for 1:1 scaling
+ self.preview_scroll_area.setStyleSheet("background-color: black; border: 0px;")
+
self.preview_widget = QLabel()
self.preview_widget.setAlignment(Qt.AlignmentFlag.AlignCenter)
self.preview_widget.setMinimumSize(640, 360)
- self.preview_widget.setFrameShape(QFrame.Shape.Box)
- self.preview_widget.setStyleSheet("background-color: black; color: white;")
- self.splitter.addWidget(self.preview_widget)
+ self.preview_widget.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu)
+
+ self.preview_scroll_area.setWidget(self.preview_widget)
+ self.splitter.addWidget(self.preview_scroll_area)
self.timeline_widget = TimelineWidget(self.timeline, self.settings, self.project_fps, self)
self.timeline_scroll_area = QScrollArea()
@@ -1412,6 +1428,9 @@ def _setup_ui(self):
self.timeline_scroll_area.setMinimumHeight(250)
self.splitter.addWidget(self.timeline_scroll_area)
+ self.splitter.setStretchFactor(0, 1)
+ self.splitter.setStretchFactor(1, 0)
+
container_widget = QWidget()
main_layout = QVBoxLayout(container_widget)
main_layout.setContentsMargins(0,0,0,0)
@@ -1445,7 +1464,7 @@ def _setup_ui(self):
self.setCentralWidget(container_widget)
self.managed_widgets = {
- 'preview': {'widget': self.preview_widget, 'name': 'Video Preview', 'action': None},
+ 'preview': {'widget': self.preview_scroll_area, 'name': 'Video Preview', 'action': None},
'timeline': {'widget': self.timeline_scroll_area, 'name': 'Timeline', 'action': None},
'project_media': {'widget': self.media_dock, 'name': 'Project Media', 'action': None}
}
@@ -1459,6 +1478,7 @@ def _setup_ui(self):
def _connect_signals(self):
self.splitter.splitterMoved.connect(self.on_splitter_moved)
+ self.preview_widget.customContextMenuRequested.connect(self._show_preview_context_menu)
self.timeline_widget.split_requested.connect(self.split_clip_at_playhead)
self.timeline_widget.delete_clip_requested.connect(self.delete_clip)
@@ -1487,6 +1507,18 @@ def _connect_signals(self):
self.undo_stack.history_changed.connect(self.update_undo_redo_actions)
self.undo_stack.timeline_changed.connect(self.on_timeline_changed_by_undo)
+ def _show_preview_context_menu(self, pos):
+ menu = QMenu(self)
+ scale_action = QAction("Scale to Fit", self, checkable=True)
+ scale_action.setChecked(self.scale_to_fit)
+ scale_action.toggled.connect(self._toggle_scale_to_fit)
+ menu.addAction(scale_action)
+ menu.exec(self.preview_widget.mapToGlobal(pos))
+
+ def _toggle_scale_to_fit(self, checked):
+ self.scale_to_fit = checked
+ self._update_preview_display()
+
def on_timeline_changed_by_undo(self):
self.prune_empty_tracks()
self.timeline_widget.update()
@@ -1644,7 +1676,7 @@ def _create_menu_bar(self):
self.windows_menu = menu_bar.addMenu("&Windows")
for key, data in self.managed_widgets.items():
- if data['widget'] is self.preview_widget: continue
+ if data['widget'] is self.preview_scroll_area: continue
action = QAction(data['name'], self, checkable=True)
if hasattr(data['widget'], 'visibilityChanged'):
action.toggled.connect(data['widget'].setVisible)
@@ -1700,9 +1732,12 @@ def _get_media_properties(self, file_path):
media_info = {}
if file_ext in ['.png', '.jpg', '.jpeg']:
+ img = QImage(file_path)
media_info['media_type'] = 'image'
media_info['duration'] = 5.0
media_info['has_audio'] = False
+ media_info['width'] = img.width()
+ media_info['height'] = img.height()
else:
probe = ffmpeg.probe(file_path)
video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
@@ -1712,6 +1747,11 @@ def _get_media_properties(self, file_path):
media_info['media_type'] = 'video'
media_info['duration'] = float(video_stream.get('duration', probe['format'].get('duration', 0)))
media_info['has_audio'] = audio_stream is not None
+ media_info['width'] = int(video_stream['width'])
+ media_info['height'] = int(video_stream['height'])
+ if 'r_frame_rate' in video_stream and video_stream['r_frame_rate'] != '0/0':
+ num, den = map(int, video_stream['r_frame_rate'].split('/'))
+ if den > 0: media_info['fps'] = num / den
elif audio_stream:
media_info['media_type'] = 'audio'
media_info['duration'] = float(audio_stream.get('duration', probe['format'].get('duration', 0)))
@@ -1724,20 +1764,37 @@ def _get_media_properties(self, file_path):
print(f"Failed to probe file {os.path.basename(file_path)}: {e}")
return None
- def _set_project_properties_from_clip(self, source_path):
+ def _update_project_properties_from_clip(self, source_path):
try:
- probe = ffmpeg.probe(source_path)
- video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None)
- if video_stream:
- self.project_width = int(video_stream['width']); self.project_height = int(video_stream['height'])
- if 'r_frame_rate' in video_stream and video_stream['r_frame_rate'] != '0/0':
- num, den = map(int, video_stream['r_frame_rate'].split('/'))
- if den > 0:
- self.project_fps = num / den
- self.timeline_widget.set_project_fps(self.project_fps)
- print(f"Project properties set: {self.project_width}x{self.project_height} @ {self.project_fps:.2f} FPS")
- return True
- except Exception as e: print(f"Could not probe for project properties: {e}")
+ media_info = self._get_media_properties(source_path)
+ if not media_info or media_info['media_type'] not in ['video', 'image']:
+ return False
+
+ new_w = media_info.get('width')
+ new_h = media_info.get('height')
+ new_fps = media_info.get('fps')
+
+ if not new_w or not new_h:
+ return False
+
+ is_first_video = not any(c.media_type in ['video', 'image'] for c in self.timeline.clips if c.source_path != source_path)
+
+ if is_first_video:
+ self.project_width = new_w
+ self.project_height = new_h
+ if new_fps: self.project_fps = new_fps
+ self.timeline_widget.set_project_fps(self.project_fps)
+ print(f"Project properties set from first clip: {self.project_width}x{self.project_height} @ {self.project_fps:.2f} FPS")
+ else:
+ current_area = self.project_width * self.project_height
+ new_area = new_w * new_h
+ if new_area > current_area:
+ self.project_width = new_w
+ self.project_height = new_h
+ print(f"Project resolution updated to: {self.project_width}x{self.project_height}")
+ return True
+ except Exception as e:
+ print(f"Could not probe for project properties: {e}")
return False
def _probe_for_drag(self, file_path):
@@ -1774,9 +1831,8 @@ def get_frame_data_at_time(self, time_sec):
return (None, 0, 0)
def get_frame_at_time(self, time_sec):
- black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black"))
clip_at_time = self._get_topmost_video_clip_at(time_sec)
- if not clip_at_time: return black_pixmap
+ if not clip_at_time: return None
try:
w, h = self.project_width, self.project_height
if clip_at_time.media_type == 'image':
@@ -1797,16 +1853,36 @@ def get_frame_at_time(self, time_sec):
image = QImage(out, self.project_width, self.project_height, QImage.Format.Format_RGB888)
return QPixmap.fromImage(image)
- except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return black_pixmap
+ except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return None
def seek_preview(self, time_sec):
self._stop_playback_stream()
self.timeline_widget.playhead_pos_sec = time_sec
self.timeline_widget.update()
- frame_pixmap = self.get_frame_at_time(time_sec)
- if frame_pixmap:
- scaled_pixmap = frame_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
+ self.current_preview_pixmap = self.get_frame_at_time(time_sec)
+ self._update_preview_display()
+
+ def _update_preview_display(self):
+ pixmap_to_show = self.current_preview_pixmap
+ if not pixmap_to_show:
+ pixmap_to_show = QPixmap(self.project_width, self.project_height)
+ pixmap_to_show.fill(QColor("black"))
+
+ if self.scale_to_fit:
+ # When fitting, we want the label to resize with the scroll area
+ self.preview_scroll_area.setWidgetResizable(True)
+ scaled_pixmap = pixmap_to_show.scaled(
+ self.preview_scroll_area.viewport().size(),
+ Qt.AspectRatioMode.KeepAspectRatio,
+ Qt.TransformationMode.SmoothTransformation
+ )
self.preview_widget.setPixmap(scaled_pixmap)
+ else:
+ # For 1:1, the label takes the size of the pixmap, and the scroll area handles overflow
+ self.preview_scroll_area.setWidgetResizable(False)
+ self.preview_widget.setPixmap(pixmap_to_show)
+ self.preview_widget.adjustSize()
+
def toggle_playback(self):
if self.playback_timer.isActive(): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play")
@@ -1835,19 +1911,16 @@ def advance_playback_frame(self):
if not clip_at_new_time:
self._stop_playback_stream()
- black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black"))
- scaled_pixmap = black_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
- self.preview_widget.setPixmap(scaled_pixmap)
+ self.current_preview_pixmap = None
+ self._update_preview_display()
return
if clip_at_new_time.media_type == 'image':
if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id:
self._stop_playback_stream()
self.playback_clip = clip_at_new_time
- frame_pixmap = self.get_frame_at_time(new_time)
- if frame_pixmap:
- scaled_pixmap = frame_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
- self.preview_widget.setPixmap(scaled_pixmap)
+ self.current_preview_pixmap = self.get_frame_at_time(new_time)
+ self._update_preview_display()
return
if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id:
@@ -1858,9 +1931,8 @@ def advance_playback_frame(self):
frame_bytes = self.playback_process.stdout.read(frame_size)
if len(frame_bytes) == frame_size:
image = QImage(frame_bytes, self.project_width, self.project_height, QImage.Format.Format_RGB888)
- pixmap = QPixmap.fromImage(image)
- scaled_pixmap = pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
- self.preview_widget.setPixmap(scaled_pixmap)
+ self.current_preview_pixmap = QPixmap.fromImage(image)
+ self._update_preview_display()
else:
self._stop_playback_stream()
@@ -1892,13 +1964,15 @@ def _apply_loaded_settings(self):
visibility_settings = self.settings.get("window_visibility", {})
for key, data in self.managed_widgets.items():
is_visible = visibility_settings.get(key, False if key == 'project_media' else True)
- if data['widget'] is not self.preview_widget:
+ if data['widget'] is not self.preview_scroll_area:
data['widget'].setVisible(is_visible)
if data['action']: data['action'].setChecked(is_visible)
splitter_state = self.settings.get("splitter_state")
if splitter_state: self.splitter.restoreState(QByteArray.fromHex(splitter_state.encode('ascii')))
- def on_splitter_moved(self, pos, index): self.splitter_save_timer.start(500)
+ def on_splitter_moved(self, pos, index):
+ self.splitter_save_timer.start(500)
+ self._update_preview_display()
def toggle_widget_visibility(self, key, checked):
if self.is_shutting_down: return
@@ -1916,6 +1990,8 @@ def new_project(self):
self.media_pool.clear(); self.media_properties.clear(); self.project_media_widget.clear_list()
self.current_project_path = None; self.stop_playback()
self.project_fps = 25.0
+ self.project_width = 1280
+ self.project_height = 720
self.timeline_widget.set_project_fps(self.project_fps)
self.timeline_widget.update()
self.undo_stack = UndoStack()
@@ -1929,7 +2005,13 @@ def save_project_as(self):
project_data = {
"media_pool": self.media_pool,
"clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "media_type": c.media_type, "group_id": c.group_id} for c in self.timeline.clips],
- "settings": {"num_video_tracks": self.timeline.num_video_tracks, "num_audio_tracks": self.timeline.num_audio_tracks}
+ "settings": {
+ "num_video_tracks": self.timeline.num_video_tracks,
+ "num_audio_tracks": self.timeline.num_audio_tracks,
+ "project_width": self.project_width,
+ "project_height": self.project_height,
+ "project_fps": self.project_fps
+ }
}
try:
with open(path, "w") as f: json.dump(project_data, f, indent=4)
@@ -1946,6 +2028,14 @@ def _load_project_from_path(self, path):
with open(path, "r") as f: project_data = json.load(f)
self.new_project()
+ project_settings = project_data.get("settings", {})
+ self.timeline.num_video_tracks = project_settings.get("num_video_tracks", 1)
+ self.timeline.num_audio_tracks = project_settings.get("num_audio_tracks", 1)
+ self.project_width = project_settings.get("project_width", 1280)
+ self.project_height = project_settings.get("project_height", 720)
+ self.project_fps = project_settings.get("project_fps", 25.0)
+ self.timeline_widget.set_project_fps(self.project_fps)
+
self.media_pool = project_data.get("media_pool", [])
for p in self.media_pool: self._add_media_to_pool(p)
@@ -1962,14 +2052,7 @@ def _load_project_from_path(self, path):
self.timeline.add_clip(TimelineClip(**clip_data))
- project_settings = project_data.get("settings", {})
- self.timeline.num_video_tracks = project_settings.get("num_video_tracks", 1)
- self.timeline.num_audio_tracks = project_settings.get("num_audio_tracks", 1)
-
self.current_project_path = path
- video_clips = [c for c in self.timeline.clips if c.media_type == 'video']
- if video_clips:
- self._set_project_properties_from_clip(video_clips[0].source_path)
self.prune_empty_tracks()
self.timeline_widget.update(); self.stop_playback()
self.status_label.setText(f"Project '{os.path.basename(path)}' loaded.")
@@ -2030,20 +2113,10 @@ def _add_media_files_to_project(self, file_paths):
return []
self.media_dock.show()
-
added_files = []
- first_video_added = any(c.media_type == 'video' for c in self.timeline.clips)
for file_path in file_paths:
- if not self.timeline.clips and not self.media_pool and not first_video_added:
- ext = os.path.splitext(file_path)[1].lower()
- if ext not in ['.png', '.jpg', '.jpeg', '.mp3', '.wav']:
- if self._set_project_properties_from_clip(file_path):
- first_video_added = True
- else:
- self.status_label.setText("Error: Could not determine video properties from file.")
- continue
-
+ self._update_project_properties_from_clip(file_path)
if self._add_media_to_pool(file_path):
added_files.append(file_path)
@@ -2118,6 +2191,9 @@ def add_media_files(self):
self._add_media_files_to_project(file_paths)
def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, media_type, clip_start_sec=0.0, video_track_index=None, audio_track_index=None):
+ if media_type in ['video', 'image']:
+ self._update_project_properties_from_clip(source_path)
+
old_state = self._get_current_timeline_state()
group_id = str(uuid.uuid4())
From e6fd0aa99fe56d2381fc71df997d7db23eab15e6 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 14:02:38 +1100
Subject: [PATCH 128/155] fix scroll centering, overlapping track labels
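
The timeline widget now manages its own horizontal view offset
(view_start_sec) instead of living inside a QScrollArea, and gains
middle-mouse panning. The time <-> pixel mapping used throughout the
widget (see the updated sec_to_x/x_to_sec helpers) is:

    x   = HEADER_WIDTH + (sec - view_start_sec) * pixels_per_second
    sec = view_start_sec + (x - HEADER_WIDTH) / pixels_per_second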
---
videoeditor.py | 120 ++++++++++++++++++++++++++++---------------------
1 file changed, 70 insertions(+), 50 deletions(-)
diff --git a/videoeditor.py b/videoeditor.py
index 15e5f0d7d..a05ff8506 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -106,7 +106,10 @@ def __init__(self, timeline_model, settings, project_fps, parent=None):
self.timeline = timeline_model
self.settings = settings
self.playhead_pos_sec = 0.0
- self.scroll_area = None
+ self.view_start_sec = 0.0
+ self.panning = False
+ self.pan_start_pos = QPoint()
+ self.pan_start_view_sec = 0.0
self.pixels_per_second = 50.0
self.max_pixels_per_second = 1.0
@@ -165,19 +168,19 @@ def set_project_fps(self, fps):
self.pixels_per_second = min(self.pixels_per_second, self.max_pixels_per_second)
self.update()
- def sec_to_x(self, sec): return self.HEADER_WIDTH + int(sec * self.pixels_per_second)
- def x_to_sec(self, x): return float(x - self.HEADER_WIDTH) / self.pixels_per_second if x > self.HEADER_WIDTH and self.pixels_per_second > 0 else 0.0
+ def sec_to_x(self, sec): return self.HEADER_WIDTH + int((sec - self.view_start_sec) * self.pixels_per_second)
+ def x_to_sec(self, x): return self.view_start_sec + float(x - self.HEADER_WIDTH) / self.pixels_per_second if x > self.HEADER_WIDTH and self.pixels_per_second > 0 else self.view_start_sec
def paintEvent(self, event):
painter = QPainter(self)
painter.setRenderHint(QPainter.RenderHint.Antialiasing)
painter.fillRect(self.rect(), QColor("#333"))
- h_offset = 0
- if self.scroll_area and self.scroll_area.horizontalScrollBar():
- h_offset = self.scroll_area.horizontalScrollBar().value()
+ self.draw_headers(painter)
- self.draw_headers(painter, h_offset)
+ painter.save()
+ painter.setClipRect(self.HEADER_WIDTH, 0, self.width() - self.HEADER_WIDTH, self.height())
+
self.draw_timescale(painter)
self.draw_tracks_and_clips(painter)
self.draw_selections(painter)
@@ -201,17 +204,19 @@ def paintEvent(self, event):
painter.drawRect(self.hover_preview_audio_rect)
self.draw_playhead(painter)
+
+ painter.restore()
- total_width = self.sec_to_x(self.timeline.get_total_duration()) + 200
total_height = self.calculate_total_height()
- self.setMinimumSize(max(self.parent().width(), total_width), total_height)
+ if self.minimumHeight() != total_height:
+ self.setMinimumHeight(total_height)
def calculate_total_height(self):
video_tracks_height = (self.timeline.num_video_tracks + 1) * self.TRACK_HEIGHT
audio_tracks_height = (self.timeline.num_audio_tracks + 1) * self.TRACK_HEIGHT
return self.TIMESCALE_HEIGHT + video_tracks_height + self.AUDIO_TRACKS_SEPARATOR_Y + audio_tracks_height + 20
- def draw_headers(self, painter, h_offset):
+ def draw_headers(self, painter):
painter.save()
painter.setPen(QColor("#AAA"))
header_font = QFont("Arial", 9, QFont.Weight.Bold)
@@ -220,9 +225,9 @@ def draw_headers(self, painter, h_offset):
y_cursor = self.TIMESCALE_HEIGHT
rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT)
- painter.fillRect(rect.translated(h_offset, 0), QColor("#3a3a3a"))
- painter.drawRect(rect.translated(h_offset, 0))
- self.add_video_track_btn_rect = QRect(h_offset + rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22)
+ painter.fillRect(rect, QColor("#3a3a3a"))
+ painter.drawRect(rect)
+ self.add_video_track_btn_rect = QRect(rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22)
painter.setFont(button_font)
painter.fillRect(self.add_video_track_btn_rect, QColor("#454"))
painter.drawText(self.add_video_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "Add Track (+)")
@@ -232,13 +237,13 @@ def draw_headers(self, painter, h_offset):
for i in range(self.timeline.num_video_tracks):
track_number = self.timeline.num_video_tracks - i
rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT)
- painter.fillRect(rect.translated(h_offset, 0), QColor("#444"))
- painter.drawRect(rect.translated(h_offset, 0))
+ painter.fillRect(rect, QColor("#444"))
+ painter.drawRect(rect)
painter.setFont(header_font)
- painter.drawText(rect.translated(h_offset, 0), Qt.AlignmentFlag.AlignCenter, f"Video {track_number}")
+ painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, f"Video {track_number}")
if track_number == self.timeline.num_video_tracks and self.timeline.num_video_tracks > 1:
- self.remove_video_track_btn_rect = QRect(h_offset + rect.right() - 25, rect.top() + 5, 20, 20)
+ self.remove_video_track_btn_rect = QRect(rect.right() - 25, rect.top() + 5, 20, 20)
painter.setFont(button_font)
painter.fillRect(self.remove_video_track_btn_rect, QColor("#833"))
painter.drawText(self.remove_video_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "-")
@@ -250,22 +255,22 @@ def draw_headers(self, painter, h_offset):
for i in range(self.timeline.num_audio_tracks):
track_number = i + 1
rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT)
- painter.fillRect(rect.translated(h_offset, 0), QColor("#444"))
- painter.drawRect(rect.translated(h_offset, 0))
+ painter.fillRect(rect, QColor("#444"))
+ painter.drawRect(rect)
painter.setFont(header_font)
- painter.drawText(rect.translated(h_offset, 0), Qt.AlignmentFlag.AlignCenter, f"Audio {track_number}")
+ painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, f"Audio {track_number}")
if track_number == self.timeline.num_audio_tracks and self.timeline.num_audio_tracks > 1:
- self.remove_audio_track_btn_rect = QRect(h_offset + rect.right() - 25, rect.top() + 5, 20, 20)
+ self.remove_audio_track_btn_rect = QRect(rect.right() - 25, rect.top() + 5, 20, 20)
painter.setFont(button_font)
painter.fillRect(self.remove_audio_track_btn_rect, QColor("#833"))
painter.drawText(self.remove_audio_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "-")
y_cursor += self.TRACK_HEIGHT
rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT)
- painter.fillRect(rect.translated(h_offset, 0), QColor("#3a3a3a"))
- painter.drawRect(rect.translated(h_offset, 0))
- self.add_audio_track_btn_rect = QRect(h_offset + rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22)
+ painter.fillRect(rect, QColor("#3a3a3a"))
+ painter.drawRect(rect)
+ self.add_audio_track_btn_rect = QRect(rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22)
painter.setFont(button_font)
painter.fillRect(self.add_audio_track_btn_rect, QColor("#454"))
painter.drawText(self.add_audio_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "Add Track (+)")
@@ -476,39 +481,46 @@ def get_region_at_pos(self, pos: QPoint):
return None
def wheelEvent(self, event: QMouseEvent):
- if not self.scroll_area or event.position().x() < self.HEADER_WIDTH:
+ if event.position().x() < self.HEADER_WIDTH:
event.ignore()
return
- scrollbar = self.scroll_area.horizontalScrollBar()
- mouse_x_abs = event.position().x()
-
- time_at_cursor = self.x_to_sec(mouse_x_abs)
+ mouse_x = event.position().x()
+ time_at_cursor = self.x_to_sec(mouse_x)
delta = event.angleDelta().y()
zoom_factor = 1.15
old_pps = self.pixels_per_second
- if delta > 0: new_pps = old_pps * zoom_factor
- else: new_pps = old_pps / zoom_factor
-
+ if delta > 0:
+ new_pps = old_pps * zoom_factor
+ else:
+ new_pps = old_pps / zoom_factor
+
min_pps = 1 / (3600 * 10)
new_pps = max(min_pps, min(new_pps, self.max_pixels_per_second))
-
+
if abs(new_pps - old_pps) < 1e-9:
return
-
+
self.pixels_per_second = new_pps
- self.update()
- new_mouse_x_abs = self.sec_to_x(time_at_cursor)
- shift_amount = new_mouse_x_abs - mouse_x_abs
- new_scroll_value = scrollbar.value() + shift_amount
- scrollbar.setValue(int(new_scroll_value))
+ new_view_start_sec = time_at_cursor - (mouse_x - self.HEADER_WIDTH) / self.pixels_per_second
+ self.view_start_sec = max(0.0, new_view_start_sec)
+
+ self.update()
event.accept()
def mousePressEvent(self, event: QMouseEvent):
- if event.pos().x() < self.HEADER_WIDTH + self.scroll_area.horizontalScrollBar().value():
+ if event.button() == Qt.MouseButton.MiddleButton:
+ self.panning = True
+ self.pan_start_pos = event.pos()
+ self.pan_start_view_sec = self.view_start_sec
+ self.setCursor(Qt.CursorShape.ClosedHandCursor)
+ event.accept()
+ return
+
+ if event.pos().x() < self.HEADER_WIDTH:
if self.add_video_track_btn_rect.contains(event.pos()): self.add_track.emit('video')
elif self.remove_video_track_btn_rect.contains(event.pos()): self.remove_track.emit('video')
elif self.add_audio_track_btn_rect.contains(event.pos()): self.add_track.emit('audio')
@@ -622,6 +634,14 @@ def mousePressEvent(self, event: QMouseEvent):
self.update()
def mouseMoveEvent(self, event: QMouseEvent):
+ if self.panning:
+ delta_x = event.pos().x() - self.pan_start_pos.x()
+ time_delta = delta_x / self.pixels_per_second
+ new_view_start = self.pan_start_view_sec - time_delta
+ self.view_start_sec = max(0.0, new_view_start)
+ self.update()
+ return
+
if self.resizing_selection_region:
current_sec = max(0, self.x_to_sec(event.pos().x()))
original_start, original_end = self.resize_selection_start_values
@@ -842,6 +862,12 @@ def mouseMoveEvent(self, event: QMouseEvent):
self.update()
def mouseReleaseEvent(self, event: QMouseEvent):
+ if event.button() == Qt.MouseButton.MiddleButton and self.panning:
+ self.panning = False
+ self.unsetCursor()
+ event.accept()
+ return
+
if event.button() == Qt.MouseButton.LeftButton:
if self.resizing_selection_region:
self.resizing_selection_region = None
@@ -1419,14 +1445,8 @@ def _setup_ui(self):
self.splitter.addWidget(self.preview_scroll_area)
self.timeline_widget = TimelineWidget(self.timeline, self.settings, self.project_fps, self)
- self.timeline_scroll_area = QScrollArea()
- self.timeline_widget.scroll_area = self.timeline_scroll_area
- self.timeline_scroll_area.setWidgetResizable(False)
- self.timeline_widget.setMinimumWidth(2000)
- self.timeline_scroll_area.setWidget(self.timeline_widget)
- self.timeline_scroll_area.setFrameShape(QFrame.Shape.NoFrame)
- self.timeline_scroll_area.setMinimumHeight(250)
- self.splitter.addWidget(self.timeline_scroll_area)
+ self.timeline_widget.setMinimumHeight(250)
+ self.splitter.addWidget(self.timeline_widget)
self.splitter.setStretchFactor(0, 1)
self.splitter.setStretchFactor(1, 0)
@@ -1465,7 +1485,7 @@ def _setup_ui(self):
self.managed_widgets = {
'preview': {'widget': self.preview_scroll_area, 'name': 'Video Preview', 'action': None},
- 'timeline': {'widget': self.timeline_scroll_area, 'name': 'Timeline', 'action': None},
+ 'timeline': {'widget': self.timeline_widget, 'name': 'Timeline', 'action': None},
'project_media': {'widget': self.media_dock, 'name': 'Project Media', 'action': None}
}
self.plugin_menu_actions = {}
@@ -1676,7 +1696,7 @@ def _create_menu_bar(self):
self.windows_menu = menu_bar.addMenu("&Windows")
for key, data in self.managed_widgets.items():
- if data['widget'] is self.preview_scroll_area: continue
+ if data['widget'] is self.preview_scroll_area or data['widget'] is self.timeline_widget: continue
action = QAction(data['name'], self, checkable=True)
if hasattr(data['widget'], 'visibilityChanged'):
action.toggled.connect(data['widget'].setVisible)
From 46d2c6c4b45badcb10870eecdb7208dfeaee9c3c Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 14:09:17 +1100
Subject: [PATCH 129/155] fix 1st frame being rendered blank when adding video
to timeline
---
videoeditor.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/videoeditor.py b/videoeditor.py
index a05ff8506..885a58bef 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -1542,6 +1542,7 @@ def _toggle_scale_to_fit(self, checked):
def on_timeline_changed_by_undo(self):
self.prune_empty_tracks()
self.timeline_widget.update()
+ self.seek_preview(self.timeline_widget.playhead_pos_sec)
self.status_label.setText("Operation undone/redone.")
def update_undo_redo_actions(self):
From 0bd6a8d6c83ac2972c8eefd48475abf5f3c6a394 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 14:27:25 +1100
Subject: [PATCH 130/155] add zoom to 0 on timescale
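
Zooming with the cursor over the header column is no longer ignored; it
is anchored at time 0 instead. Keeping time 0 at a fixed screen x only
requires rescaling the view offset (as in the hunk below):

    new_view_start_sec = view_start_sec * (old_pps / new_pps)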
---
videoeditor.py | 18 +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/videoeditor.py b/videoeditor.py
index 885a58bef..a63d2ac99 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -481,13 +481,6 @@ def get_region_at_pos(self, pos: QPoint):
return None
def wheelEvent(self, event: QMouseEvent):
- if event.position().x() < self.HEADER_WIDTH:
- event.ignore()
- return
-
- mouse_x = event.position().x()
- time_at_cursor = self.x_to_sec(mouse_x)
-
delta = event.angleDelta().y()
zoom_factor = 1.15
old_pps = self.pixels_per_second
@@ -503,9 +496,16 @@ def wheelEvent(self, event: QMouseEvent):
if abs(new_pps - old_pps) < 1e-9:
return
- self.pixels_per_second = new_pps
+ if event.position().x() < self.HEADER_WIDTH:
+ # Zoom centered on time 0. Keep screen position of time 0 constant.
+ new_view_start_sec = self.view_start_sec * (old_pps / new_pps)
+ else:
+ # Zoom centered on mouse cursor.
+ mouse_x = event.position().x()
+ time_at_cursor = self.x_to_sec(mouse_x)
+ new_view_start_sec = time_at_cursor - (mouse_x - self.HEADER_WIDTH) / new_pps
- new_view_start_sec = time_at_cursor - (mouse_x - self.HEADER_WIDTH) / self.pixels_per_second
+ self.pixels_per_second = new_pps
self.view_start_sec = max(0.0, new_view_start_sec)
self.update()
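
The reworked wheel handler keeps a chosen anchor point stationary while the time scale changes: time 0 when the cursor is over the track header, otherwise the time under the cursor. A small worked sketch of that arithmetic; `HEADER_WIDTH` is an assumed constant here.

```python
HEADER_WIDTH = 80  # assumed header width in pixels

def zoom_view(view_start_sec, old_pps, new_pps, mouse_x):
    """Return the new view_start_sec so the anchor point keeps its screen position.

    Over the track header (mouse_x < HEADER_WIDTH) the zoom is anchored at
    time 0; otherwise it is anchored at the time under the cursor.
    """
    if mouse_x < HEADER_WIDTH:
        new_view_start = view_start_sec * (old_pps / new_pps)
    else:
        time_at_cursor = view_start_sec + (mouse_x - HEADER_WIDTH) / old_pps
        new_view_start = time_at_cursor - (mouse_x - HEADER_WIDTH) / new_pps
    return max(0.0, new_view_start)

# Zooming in 15% with the cursor 300 px into the tracks: the time under the
# cursor (10 + 300/100 = 13.0 s) is still drawn at x = 380 afterwards.
print(zoom_view(10.0, old_pps=100.0, new_pps=115.0, mouse_x=380))
```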
From 0ca8ba19cf7c40e275fca2e420c0cb18eb9a5c8f Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 14:33:07 +1100
Subject: [PATCH 131/155] fix track header overlaying the 0s timescale label
---
videoeditor.py | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/videoeditor.py b/videoeditor.py
index a63d2ac99..76a4ab495 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -349,7 +349,10 @@ def draw_ticks(interval, height):
painter.drawLine(x, self.TIMESCALE_HEIGHT - 12, x, self.TIMESCALE_HEIGHT)
label = self._format_timecode(t_sec)
label_width = font_metrics.horizontalAdvance(label)
- painter.drawText(x - label_width // 2, self.TIMESCALE_HEIGHT - 14, label)
+ label_x = x - label_width // 2
+ if label_x < self.HEADER_WIDTH:
+ label_x = self.HEADER_WIDTH
+ painter.drawText(label_x, self.TIMESCALE_HEIGHT - 14, label)
painter.restore()
@@ -497,10 +500,8 @@ def wheelEvent(self, event: QMouseEvent):
return
if event.position().x() < self.HEADER_WIDTH:
- # Zoom centered on time 0. Keep screen position of time 0 constant.
new_view_start_sec = self.view_start_sec * (old_pps / new_pps)
else:
- # Zoom centered on mouse cursor.
mouse_x = event.position().x()
time_at_cursor = self.x_to_sec(mouse_x)
new_view_start_sec = time_at_cursor - (mouse_x - self.HEADER_WIDTH) / new_pps
From eef79e40fb20bd30f40271a2d51f28a0197527e8 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 14:53:22 +1100
Subject: [PATCH 132/155] fix add track button not working
---
videoeditor.py | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/videoeditor.py b/videoeditor.py
index 76a4ab495..b4cdb2164 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -1626,12 +1626,18 @@ def add_track(self, track_type):
self.timeline.num_video_tracks += 1
elif track_type == 'audio':
self.timeline.num_audio_tracks += 1
+ else:
+ return
+
new_state = self._get_current_timeline_state()
-
command = TimelineStateChangeCommand(f"Add {track_type.capitalize()} Track", self.timeline, *old_state, *new_state)
+
+ self.undo_stack.blockSignals(True)
self.undo_stack.push(command)
- command.undo()
- self.undo_stack.push(command)
+ self.undo_stack.blockSignals(False)
+
+ self.update_undo_redo_actions()
+ self.timeline_widget.update()
def remove_track(self, track_type):
old_state = self._get_current_timeline_state()
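
The earlier code pushed the command, undid it, and pushed it again, so the track change was effectively applied twice; the fix pushes once while suppressing the stack's change signal. Below is a hedged sketch of the snapshot-command pattern this implies. The real project applies the state change before constructing the command, and its `UndoStack`/`TimelineStateChangeCommand` API may differ; the classes here are illustrative only.

```python
class StateChangeCommand:
    """Command that swaps a whole snapshot of timeline state (assumed pattern)."""

    def __init__(self, description, target, old_state, new_state):
        self.description = description
        self.target = target
        self.old_state = old_state
        self.new_state = new_state

    def redo(self):
        self.target.update(self.new_state)

    def undo(self):
        self.target.update(self.old_state)


class UndoStack:
    """Minimal stack: push() applies the command exactly once."""

    def __init__(self):
        self._undo, self._redo = [], []

    def push(self, command):
        command.redo()               # apply once, here and only here
        self._undo.append(command)
        self._redo.clear()

    def undo(self):
        if self._undo:
            cmd = self._undo.pop()
            cmd.undo()
            self._redo.append(cmd)


timeline = {"num_video_tracks": 1}
old = dict(timeline)
stack = UndoStack()
stack.push(StateChangeCommand("Add Video Track", timeline, old, {"num_video_tracks": 2}))
print(timeline["num_video_tracks"])  # 2
stack.undo()
print(timeline["num_video_tracks"])  # 1
```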
From 088940b0e526c1335b75b87a158ef52529fa7751 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 15:10:34 +1100
Subject: [PATCH 133/155] snap playhead to single frames
---
videoeditor.py | 13 +++++++++++--
1 file changed, 11 insertions(+), 2 deletions(-)
diff --git a/videoeditor.py b/videoeditor.py
index b4cdb2164..6abb4091b 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -473,6 +473,13 @@ def y_to_track_info(self, y):
return None
+ def _snap_time_if_needed(self, time_sec):
+ frame_duration = 1.0 / self.project_fps
+ if frame_duration > 0 and frame_duration * self.pixels_per_second > 4:
+ frame_number = round(time_sec / frame_duration)
+ return frame_number * frame_duration
+ return time_sec
+
def get_region_at_pos(self, pos: QPoint):
if pos.y() <= self.TIMESCALE_HEIGHT or pos.x() <= self.HEADER_WIDTH:
return None
@@ -628,7 +635,8 @@ def mousePressEvent(self, event: QMouseEvent):
self.selection_drag_start_sec = self.x_to_sec(event.pos().x())
self.selection_regions.append([self.selection_drag_start_sec, self.selection_drag_start_sec])
elif is_on_timescale:
- self.playhead_pos_sec = max(0, self.x_to_sec(event.pos().x()))
+ time_sec = max(0, self.x_to_sec(event.pos().x()))
+ self.playhead_pos_sec = self._snap_time_if_needed(time_sec)
self.playhead_moved.emit(self.playhead_pos_sec)
self.dragging_playhead = True
@@ -798,7 +806,8 @@ def mouseMoveEvent(self, event: QMouseEvent):
return
if self.dragging_playhead:
- self.playhead_pos_sec = max(0, self.x_to_sec(event.pos().x()))
+ time_sec = max(0, self.x_to_sec(event.pos().x()))
+ self.playhead_pos_sec = self._snap_time_if_needed(time_sec)
self.playhead_moved.emit(self.playhead_pos_sec)
self.update()
elif self.dragging_clip:
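
Snapping only engages once a single frame occupies more than a few pixels on screen, so coarse zoom levels still allow free scrubbing. A standalone sketch of the rounding, reusing the diff's 4-pixel threshold:

```python
def snap_time_if_needed(time_sec, project_fps, pixels_per_second, min_frame_px=4):
    """Round time_sec to the nearest frame boundary when frames are wide enough on screen."""
    if project_fps <= 0:
        return time_sec
    frame_duration = 1.0 / project_fps
    if frame_duration * pixels_per_second > min_frame_px:
        return round(time_sec / frame_duration) * frame_duration
    return time_sec

# At 25 fps a frame is 0.04 s; at 200 px/s that's 8 px wide, so snapping applies.
print(snap_time_if_needed(1.234, project_fps=25, pixels_per_second=200))   # ~1.24
# At 50 px/s a frame is only 2 px wide, so the raw time is kept.
print(snap_time_if_needed(1.234, project_fps=25, pixels_per_second=50))    # 1.234
```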
From cc553e2e5a26f0e2fe35c3bc8645a5dbe774eb43 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 15:49:55 +1100
Subject: [PATCH 134/155] automatically set resolution from project resolution
or clip resolution depending on mode
---
plugins/wan2gp/main.py | 82 ++++++++++++++++++++++++++++++++++++++----
1 file changed, 75 insertions(+), 7 deletions(-)
diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py
index 0b925227c..568c31fb9 100644
--- a/plugins/wan2gp/main.py
+++ b/plugins/wan2gp/main.py
@@ -1339,6 +1339,55 @@ def _on_resolution_group_changed(self):
self.widgets['resolution'].setCurrentIndex(self.widgets['resolution'].findData(last_resolution))
self.widgets['resolution'].blockSignals(False)
+ def set_resolution_from_target(self, target_w, target_h):
+ if not self.full_resolution_choices:
+ print("Resolution choices not available for AI resolution matching.")
+ return
+
+ target_pixels = target_w * target_h
+ target_ar = target_w / target_h if target_h > 0 else 1.0
+
+ best_res_value = None
+ min_dist = float('inf')
+
+ for label, res_value in self.full_resolution_choices:
+ try:
+ w_str, h_str = res_value.split('x')
+ w, h = int(w_str), int(h_str)
+ except (ValueError, AttributeError):
+ continue
+
+ pixels = w * h
+ ar = w / h if h > 0 else 1.0
+
+ pixel_dist = abs(target_pixels - pixels) / target_pixels
+ ar_dist = abs(target_ar - ar) / target_ar
+
+ dist = pixel_dist * 0.8 + ar_dist * 0.2
+
+ if dist < min_dist:
+ min_dist = dist
+ best_res_value = res_value
+
+ if best_res_value:
+ best_group = self.wgp.categorize_resolution(best_res_value)
+
+ group_combo = self.widgets['resolution_group']
+ group_combo.blockSignals(True)
+ group_index = group_combo.findText(best_group)
+ if group_index != -1:
+ group_combo.setCurrentIndex(group_index)
+ group_combo.blockSignals(False)
+
+ self._on_resolution_group_changed()
+
+ res_combo = self.widgets['resolution']
+ res_index = res_combo.findData(best_res_value)
+ if res_index != -1:
+ res_combo.setCurrentIndex(res_index)
+ else:
+ print(f"Warning: Could not find resolution '{best_res_value}' in dropdown after group change.")
+
def collect_inputs(self):
full_inputs = self.wgp.get_current_model_settings(self.state).copy()
full_inputs['lset_name'] = ""
@@ -1663,6 +1712,9 @@ def setup_generator_for_region(self, region, on_new_track=False):
if not start_data or not end_data:
QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames.")
return
+
+ self.client_widget.set_resolution_from_target(w, h)
+
try:
self.temp_dir = tempfile.mkdtemp(prefix="wgp_plugin_")
self.start_frame_path = os.path.join(self.temp_dir, "start_frame.png")
@@ -1675,13 +1727,25 @@ def setup_generator_for_region(self, region, on_new_track=False):
model_type = self.client_widget.state['model_type']
fps = wgp.get_model_fps(model_type)
video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
-
- self.client_widget.widgets['video_length'].setValue(video_length_frames)
- self.client_widget.widgets['mode_s'].setChecked(True)
- self.client_widget.widgets['image_end_checkbox'].setChecked(True)
- self.client_widget.widgets['image_start'].setText(self.start_frame_path)
- self.client_widget.widgets['image_end'].setText(self.end_frame_path)
-
+ widgets = self.client_widget.widgets
+ widgets['mode_s'].blockSignals(True)
+ widgets['mode_t'].blockSignals(True)
+ widgets['mode_v'].blockSignals(True)
+ widgets['mode_l'].blockSignals(True)
+ widgets['image_end_checkbox'].blockSignals(True)
+
+ widgets['video_length'].setValue(video_length_frames)
+ widgets['mode_s'].setChecked(True)
+ widgets['image_end_checkbox'].setChecked(True)
+ widgets['image_start'].setText(self.start_frame_path)
+ widgets['image_end'].setText(self.end_frame_path)
+
+ widgets['mode_s'].blockSignals(False)
+ widgets['mode_t'].blockSignals(False)
+ widgets['mode_v'].blockSignals(False)
+ widgets['mode_l'].blockSignals(False)
+ widgets['image_end_checkbox'].blockSignals(False)
+
self.client_widget._update_input_visibility()
except Exception as e:
@@ -1707,6 +1771,10 @@ def setup_creator_for_region(self, region, on_new_track=False):
else:
print(f"Warning: Default model '{model_to_set}' not found for AI Creator. Using current model.")
+ target_w = self.app.project_width
+ target_h = self.app.project_height
+ self.client_widget.set_resolution_from_target(target_w, target_h)
+
start_sec, end_sec = region
duration_sec = end_sec - start_sec
wgp = self.client_widget.wgp
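
`set_resolution_from_target` scores every preset by how far its pixel count and aspect ratio are from the source frame, weighting area at 80% and aspect at 20%, and picks the minimum. The same scoring as a standalone function; the candidate list below is made-up example data, not the plugin's real choices:

```python
def closest_resolution(target_w, target_h, choices):
    """Return the 'WxH' choice minimizing a weighted pixel-count/aspect-ratio distance."""
    target_pixels = target_w * target_h
    target_ar = target_w / target_h if target_h else 1.0
    best, best_dist = None, float("inf")
    for res in choices:
        try:
            w, h = (int(v) for v in res.split("x"))
        except ValueError:
            continue
        pixel_dist = abs(target_pixels - w * h) / target_pixels
        ar_dist = abs(target_ar - (w / h if h else 1.0)) / target_ar
        dist = 0.8 * pixel_dist + 0.2 * ar_dist
        if dist < best_dist:
            best, best_dist = res, dist
    return best

# A 1080x1920 portrait clip maps to the portrait preset even though the
# landscape one has an identical pixel count.
print(closest_resolution(1080, 1920, ["1280x720", "720x1280", "832x480", "480x832"]))
```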
From 80ecba27c962b8e5cda8f7a562eb5fb061f934aa Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 16:09:54 +1100
Subject: [PATCH 135/155] add join to start frame with ai and join to end frame
with ai modes
---
plugins/wan2gp/main.py | 156 ++++++++++++++++++++++++++++++++++++++---
1 file changed, 146 insertions(+), 10 deletions(-)
diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py
index 568c31fb9..2c38d30fa 100644
--- a/plugins/wan2gp/main.py
+++ b/plugins/wan2gp/main.py
@@ -1681,11 +1681,23 @@ def on_timeline_context_menu(self, menu, event):
start_sec, end_sec = region
start_data, _, _ = self.app.get_frame_data_at_time(start_sec)
end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
+
if start_data and end_data:
join_action = menu.addAction("Join Frames With AI")
join_action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False))
join_action_new_track = menu.addAction("Join Frames With AI (New Track)")
join_action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True))
+ elif start_data:
+ from_start_action = menu.addAction("Generate from Start Frame with AI")
+ from_start_action.triggered.connect(lambda: self.setup_generator_from_start(region, on_new_track=False))
+ from_start_action_new_track = menu.addAction("Generate from Start Frame with AI (New Track)")
+ from_start_action_new_track.triggered.connect(lambda: self.setup_generator_from_start(region, on_new_track=True))
+ elif end_data:
+ to_end_action = menu.addAction("Generate to End Frame with AI")
+ to_end_action.triggered.connect(lambda: self.setup_generator_to_end(region, on_new_track=False))
+ to_end_action_new_track = menu.addAction("Generate to End Frame with AI (New Track)")
+ to_end_action_new_track.triggered.connect(lambda: self.setup_generator_to_end(region, on_new_track=True))
+
create_action = menu.addAction("Create Frames With AI")
create_action.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=False))
create_action_new_track = menu.addAction("Create Frames With AI (New Track)")
@@ -1728,11 +1740,9 @@ def setup_generator_for_region(self, region, on_new_track=False):
fps = wgp.get_model_fps(model_type)
video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
widgets = self.client_widget.widgets
- widgets['mode_s'].blockSignals(True)
- widgets['mode_t'].blockSignals(True)
- widgets['mode_v'].blockSignals(True)
- widgets['mode_l'].blockSignals(True)
- widgets['image_end_checkbox'].blockSignals(True)
+
+ for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']:
+ widgets[w_name].blockSignals(True)
widgets['video_length'].setValue(video_length_frames)
widgets['mode_s'].setChecked(True)
@@ -1740,11 +1750,8 @@ def setup_generator_for_region(self, region, on_new_track=False):
widgets['image_start'].setText(self.start_frame_path)
widgets['image_end'].setText(self.end_frame_path)
- widgets['mode_s'].blockSignals(False)
- widgets['mode_t'].blockSignals(False)
- widgets['mode_v'].blockSignals(False)
- widgets['mode_l'].blockSignals(False)
- widgets['image_end_checkbox'].blockSignals(False)
+ for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']:
+ widgets[w_name].blockSignals(False)
self.client_widget._update_input_visibility()
@@ -1756,6 +1763,135 @@ def setup_generator_for_region(self, region, on_new_track=False):
self.dock_widget.show()
self.dock_widget.raise_()
+ def setup_generator_from_start(self, region, on_new_track=False):
+ self._reset_state()
+ self.active_region = region
+ self.insert_on_new_track = on_new_track
+
+ model_to_set = 'i2v_2_2'
+ dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types
+ _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False)
+ if any(model_to_set == m[1] for m in all_models):
+ if self.client_widget.state.get('model_type') != model_to_set:
+ self.client_widget.update_model_dropdowns(model_to_set)
+ self.client_widget._on_model_changed()
+ else:
+ print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.")
+
+ start_sec, end_sec = region
+ start_data, w, h = self.app.get_frame_data_at_time(start_sec)
+ if not start_data:
+ QMessageBox.warning(self.app, "Frame Error", "Could not extract start frame.")
+ return
+
+ self.client_widget.set_resolution_from_target(w, h)
+
+ try:
+ self.temp_dir = tempfile.mkdtemp(prefix="wgp_plugin_")
+ self.start_frame_path = os.path.join(self.temp_dir, "start_frame.png")
+ QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path)
+
+ duration_sec = end_sec - start_sec
+ wgp = self.client_widget.wgp
+ model_type = self.client_widget.state['model_type']
+ fps = wgp.get_model_fps(model_type)
+ video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
+ widgets = self.client_widget.widgets
+
+ widgets['video_length'].setValue(video_length_frames)
+
+ for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']:
+ widgets[w_name].blockSignals(True)
+
+ widgets['mode_s'].setChecked(True)
+ widgets['image_end_checkbox'].setChecked(False)
+ widgets['image_start'].setText(self.start_frame_path)
+ widgets['image_end'].clear()
+
+ for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']:
+ widgets[w_name].blockSignals(False)
+
+ self.client_widget._update_input_visibility()
+
+ except Exception as e:
+ QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame image: {e}")
+ self._cleanup_temp_dir()
+ return
+
+ self.app.status_label.setText(f"Ready to generate from frame at {start_sec:.2f}s.")
+ self.dock_widget.show()
+ self.dock_widget.raise_()
+
+ def setup_generator_to_end(self, region, on_new_track=False):
+ self._reset_state()
+ self.active_region = region
+ self.insert_on_new_track = on_new_track
+
+ model_to_set = 'i2v_2_2'
+ dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types
+ _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False)
+ if any(model_to_set == m[1] for m in all_models):
+ if self.client_widget.state.get('model_type') != model_to_set:
+ self.client_widget.update_model_dropdowns(model_to_set)
+ self.client_widget._on_model_changed()
+ else:
+ print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.")
+
+ start_sec, end_sec = region
+ end_data, w, h = self.app.get_frame_data_at_time(end_sec)
+ if not end_data:
+ QMessageBox.warning(self.app, "Frame Error", "Could not extract end frame.")
+ return
+
+ self.client_widget.set_resolution_from_target(w, h)
+
+ try:
+ self.temp_dir = tempfile.mkdtemp(prefix="wgp_plugin_")
+ self.end_frame_path = os.path.join(self.temp_dir, "end_frame.png")
+ QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path)
+
+ duration_sec = end_sec - start_sec
+ wgp = self.client_widget.wgp
+ model_type = self.client_widget.state['model_type']
+ fps = wgp.get_model_fps(model_type)
+ video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
+ widgets = self.client_widget.widgets
+
+ widgets['video_length'].setValue(video_length_frames)
+
+ model_def = self.client_widget.wgp.get_model_def(self.client_widget.state['model_type'])
+ allowed_modes = model_def.get("image_prompt_types_allowed", "")
+
+ if "E" not in allowed_modes:
+ QMessageBox.warning(self.app, "Model Incompatible", "The current model does not support generating to an end frame.")
+ return
+
+ if "S" not in allowed_modes:
+ QMessageBox.warning(self.app, "Model Incompatible", "The current model supports end frames, but not in a way compatible with this UI feature (missing 'Start with Image' mode).")
+ return
+
+ for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']:
+ widgets[w_name].blockSignals(True)
+
+ widgets['mode_s'].setChecked(True)
+ widgets['image_end_checkbox'].setChecked(True)
+ widgets['image_start'].clear()
+ widgets['image_end'].setText(self.end_frame_path)
+
+ for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']:
+ widgets[w_name].blockSignals(False)
+
+ self.client_widget._update_input_visibility()
+
+ except Exception as e:
+ QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame image: {e}")
+ self._cleanup_temp_dir()
+ return
+
+ self.app.status_label.setText(f"Ready to generate to frame at {end_sec:.2f}s.")
+ self.dock_widget.show()
+ self.dock_widget.raise_()
+
def setup_creator_for_region(self, region, on_new_track=False):
self._reset_state()
self.active_region = region
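
The new context-menu branches offer different actions depending on which of the region's boundary frames could be extracted, and `setup_generator_to_end` additionally checks the model's `image_prompt_types_allowed` string for both 'E' (end frame) and 'S' (start image) before proceeding. A compact sketch of both decisions; the model-definition dicts are hypothetical stand-ins for what `get_model_def` returns:

```python
def ai_menu_actions(start_data, end_data):
    """Which AI actions to offer for a region, mirroring the branches above."""
    if start_data and end_data:
        return ["Join Frames With AI"]
    if start_data:
        return ["Generate from Start Frame with AI"]
    if end_data:
        return ["Generate to End Frame with AI"]
    return []

def end_frame_mode_supported(model_def):
    """True if the model exposes both an end-frame ('E') and start-image ('S') mode."""
    allowed = model_def.get("image_prompt_types_allowed", "")
    return "E" in allowed and "S" in allowed

print(ai_menu_actions(start_data=b"...", end_data=None))               # generate from start
print(end_frame_mode_supported({"image_prompt_types_allowed": "SE"}))  # True
print(end_frame_mode_supported({"image_prompt_types_allowed": "S"}))   # False
```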
From bbe9df1eccd10c387e3e536a7aa09cc5fbb7460d Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 16:36:41 +1100
Subject: [PATCH 136/155] make sure advanced options tabs don't need scrolling
---
plugins/wan2gp/main.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py
index 2c38d30fa..12a4ede6e 100644
--- a/plugins/wan2gp/main.py
+++ b/plugins/wan2gp/main.py
@@ -423,6 +423,7 @@ def setup_generator_tab(self):
self.tabs.addTab(gen_tab, "Video Generator")
gen_layout = QHBoxLayout(gen_tab)
left_panel = QWidget()
+ left_panel.setMinimumWidth(628)
left_layout = QVBoxLayout(left_panel)
gen_layout.addWidget(left_panel, 1)
right_panel = QWidget()
From f2324d42424d5a584388a5978b8da43dd0fa29f1 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 16:44:29 +1100
Subject: [PATCH 137/155] allow slider values to be edited as text
---
plugins/wan2gp/main.py | 32 ++++++++++++++++++++++++++++----
1 file changed, 28 insertions(+), 4 deletions(-)
diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py
index 12a4ede6e..41edebd80 100644
--- a/plugins/wan2gp/main.py
+++ b/plugins/wan2gp/main.py
@@ -296,11 +296,35 @@ def _create_slider_with_label(self, name, min_val, max_val, initial_val, scale=1
slider = self.create_widget(QSlider, name, Qt.Orientation.Horizontal)
slider.setRange(min_val, max_val)
slider.setValue(int(initial_val * scale))
- value_label = self.create_widget(QLabel, f"{name}_label", f"{initial_val:.{precision}f}")
- value_label.setMinimumWidth(50)
- slider.valueChanged.connect(lambda v, lbl=value_label, s=scale, p=precision: lbl.setText(f"{v/s:.{p}f}"))
+
+ value_edit = self.create_widget(QLineEdit, f"{name}_label", f"{initial_val:.{precision}f}")
+ value_edit.setFixedWidth(60)
+ value_edit.setAlignment(Qt.AlignmentFlag.AlignCenter)
+
+ def sync_slider_from_text():
+ try:
+ text_value = float(value_edit.text())
+ slider_value = int(round(text_value * scale))
+
+ slider.blockSignals(True)
+ slider.setValue(slider_value)
+ slider.blockSignals(False)
+
+ actual_slider_value = slider.value()
+ if actual_slider_value != slider_value:
+ value_edit.setText(f"{actual_slider_value / scale:.{precision}f}")
+ except (ValueError, TypeError):
+ value_edit.setText(f"{slider.value() / scale:.{precision}f}")
+
+ def sync_text_from_slider(value):
+ if not value_edit.hasFocus():
+ value_edit.setText(f"{value / scale:.{precision}f}")
+
+ value_edit.editingFinished.connect(sync_slider_from_text)
+ slider.valueChanged.connect(sync_text_from_slider)
+
hbox.addWidget(slider)
- hbox.addWidget(value_label)
+ hbox.addWidget(value_edit)
return container
def _create_file_input(self, name, label_text):
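
Patch 137 swaps the read-only value label for a QLineEdit kept in sync with the slider in both directions, using `blockSignals` so the two updates cannot feed back into each other. A self-contained PyQt6 sketch of the same coupling, assuming a float value stored in the slider as an integer times `scale` (widget names here are illustrative, not the plugin's):

```python
import sys
from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import QApplication, QHBoxLayout, QLineEdit, QSlider, QWidget


def make_editable_slider(min_val, max_val, initial, scale=10, precision=1):
    """Slider storing ints, line edit showing floats (value = slider / scale)."""
    container = QWidget()
    layout = QHBoxLayout(container)
    slider = QSlider(Qt.Orientation.Horizontal)
    slider.setRange(min_val, max_val)
    slider.setValue(int(initial * scale))
    edit = QLineEdit(f"{initial:.{precision}f}")
    edit.setFixedWidth(60)

    def text_to_slider():
        try:
            value = int(round(float(edit.text()) * scale))
        except ValueError:
            value = slider.value()              # bad input: keep the current value
        slider.blockSignals(True)               # avoid re-triggering slider_to_text
        slider.setValue(value)
        slider.blockSignals(False)
        edit.setText(f"{slider.value() / scale:.{precision}f}")  # reflect range clamping

    def slider_to_text(value):
        if not edit.hasFocus():                 # don't fight the user while typing
            edit.setText(f"{value / scale:.{precision}f}")

    edit.editingFinished.connect(text_to_slider)
    slider.valueChanged.connect(slider_to_text)
    layout.addWidget(slider)
    layout.addWidget(edit)
    return container


if __name__ == "__main__":
    app = QApplication(sys.argv)
    w = make_editable_slider(0, 100, 3.5)       # e.g. a guidance-scale control
    w.show()
    sys.exit(app.exec())
```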
From 2a867c2bfb54c3a52c2a976d33d2bf12a35fc723 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 16:53:26 +1100
Subject: [PATCH 138/155] save selected regions to project
---
videoeditor.py | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/videoeditor.py b/videoeditor.py
index 6abb4091b..ce49757b1 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -2030,6 +2030,7 @@ def new_project(self):
self.project_width = 1280
self.project_height = 720
self.timeline_widget.set_project_fps(self.project_fps)
+ self.timeline_widget.clear_all_regions()
self.timeline_widget.update()
self.undo_stack = UndoStack()
self.undo_stack.history_changed.connect(self.update_undo_redo_actions)
@@ -2042,6 +2043,7 @@ def save_project_as(self):
project_data = {
"media_pool": self.media_pool,
"clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "media_type": c.media_type, "group_id": c.group_id} for c in self.timeline.clips],
+ "selection_regions": self.timeline_widget.selection_regions,
"settings": {
"num_video_tracks": self.timeline.num_video_tracks,
"num_audio_tracks": self.timeline.num_audio_tracks,
@@ -2073,6 +2075,8 @@ def _load_project_from_path(self, path):
self.project_fps = project_settings.get("project_fps", 25.0)
self.timeline_widget.set_project_fps(self.project_fps)
+ self.timeline_widget.selection_regions = project_data.get("selection_regions", [])
+
self.media_pool = project_data.get("media_pool", [])
for p in self.media_pool: self._add_media_to_pool(p)
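
Selection regions are now serialized into the project file and read back with a safe default so older projects still load. A tiny round-trip sketch of that pattern; the file layout is simplified to the one key, whereas the real project JSON also carries clips and settings:

```python
import json

def save_regions(path, regions):
    with open(path, "w") as f:
        json.dump({"selection_regions": regions}, f)

def load_regions(path):
    with open(path) as f:
        data = json.load(f)
    # Older project files have no key; default to no regions rather than failing.
    return data.get("selection_regions", [])

save_regions("demo_project.json", [[1.0, 2.5], [4.0, 4.8]])
print(load_regions("demo_project.json"))   # [[1.0, 2.5], [4.0, 4.8]]
```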
From 4d79760e83e9579cd3eb2ec65c493c62a42fd02d Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 17:30:20 +1100
Subject: [PATCH 139/155] set min/max heights for prompt boxes and options panel
---
plugins/wan2gp/main.py | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py
index 41edebd80..623d4afa7 100644
--- a/plugins/wan2gp/main.py
+++ b/plugins/wan2gp/main.py
@@ -469,11 +469,13 @@ def setup_generator_tab(self):
model_layout.addWidget(self.widgets['model_choice'], 3)
options_layout.addLayout(model_layout)
options_layout.addWidget(QLabel("Prompt:"))
- self.create_widget(QTextEdit, 'prompt').setMinimumHeight(100)
- options_layout.addWidget(self.widgets['prompt'])
+ prompt_edit = self.create_widget(QTextEdit, 'prompt')
+ prompt_edit.setMaximumHeight(prompt_edit.fontMetrics().lineSpacing() * 5 + 15)
+ options_layout.addWidget(prompt_edit)
options_layout.addWidget(QLabel("Negative Prompt:"))
- self.create_widget(QTextEdit, 'negative_prompt').setMinimumHeight(60)
- options_layout.addWidget(self.widgets['negative_prompt'])
+ neg_prompt_edit = self.create_widget(QTextEdit, 'negative_prompt')
+ neg_prompt_edit.setMaximumHeight(neg_prompt_edit.fontMetrics().lineSpacing() * 3 + 15)
+ options_layout.addWidget(neg_prompt_edit)
basic_group = QGroupBox("Basic Options")
basic_layout = QFormLayout(basic_group)
res_container = QWidget()
@@ -528,7 +530,7 @@ def setup_generator_tab(self):
self._setup_adv_tab_sliding_window(advanced_tabs)
self._setup_adv_tab_misc(advanced_tabs)
options_layout.addWidget(self.advanced_group)
-
+ scroll_area.setMinimumHeight(options_widget.sizeHint().height()-260)
btn_layout = QHBoxLayout()
self.generate_btn = self.create_widget(QPushButton, 'generate_btn', "Generate")
self.add_to_queue_btn = self.create_widget(QPushButton, 'add_to_queue_btn', "Add to Queue")
From a8fd923e49d3e246e6c050b483061a35a2f5f160 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 19:09:10 +1100
Subject: [PATCH 140/155] add the full set of FFmpeg containers and codecs to
 the export options
---
videoeditor.py | 470 ++++++++++++++++++++++++++++++++++++++++++++-----
1 file changed, 430 insertions(+), 40 deletions(-)
diff --git a/videoeditor.py b/videoeditor.py
index ce49757b1..b6673c879 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -10,7 +10,9 @@
from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QHBoxLayout, QPushButton, QFileDialog, QLabel,
QScrollArea, QFrame, QProgressBar, QDialog,
- QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget, QListWidget, QListWidgetItem, QMessageBox)
+ QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget,
+ QListWidget, QListWidgetItem, QMessageBox, QComboBox,
+ QFormLayout, QGroupBox, QLineEdit)
from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction,
QPixmap, QImage, QDrag, QCursor, QKeyEvent)
from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread,
@@ -19,6 +21,165 @@
from plugins import PluginManager, ManagePluginsDialog
from undo import UndoStack, TimelineStateChangeCommand, MoveClipsCommand
+CONTAINER_PRESETS = {
+ 'mp4': {
+ 'vcodec': 'libx264', 'acodec': 'aac',
+ 'allowed_vcodecs': ['libx264', 'libx265', 'mpeg4'],
+ 'allowed_acodecs': ['aac', 'libmp3lame'],
+ 'v_bitrate': '5M', 'a_bitrate': '192k'
+ },
+ 'matroska': {
+ 'vcodec': 'libx264', 'acodec': 'aac',
+ 'allowed_vcodecs': ['libx264', 'libx265', 'libvpx-vp9'],
+ 'allowed_acodecs': ['aac', 'libopus', 'libvorbis', 'flac'],
+ 'v_bitrate': '5M', 'a_bitrate': '192k'
+ },
+ 'mov': {
+ 'vcodec': 'libx264', 'acodec': 'aac',
+ 'allowed_vcodecs': ['libx264', 'prores_ks', 'mpeg4'],
+ 'allowed_acodecs': ['aac', 'pcm_s16le'],
+ 'v_bitrate': '8M', 'a_bitrate': '256k'
+ },
+ 'avi': {
+ 'vcodec': 'mpeg4', 'acodec': 'libmp3lame',
+ 'allowed_vcodecs': ['mpeg4', 'msmpeg4'],
+ 'allowed_acodecs': ['libmp3lame'],
+ 'v_bitrate': '5M', 'a_bitrate': '192k'
+ },
+ 'webm': {
+ 'vcodec': 'libvpx-vp9', 'acodec': 'libopus',
+ 'allowed_vcodecs': ['libvpx-vp9'],
+ 'allowed_acodecs': ['libopus', 'libvorbis'],
+ 'v_bitrate': '4M', 'a_bitrate': '192k'
+ },
+ 'wav': {
+ 'vcodec': None, 'acodec': 'pcm_s16le',
+ 'allowed_vcodecs': [], 'allowed_acodecs': ['pcm_s16le', 'pcm_s24le'],
+ 'v_bitrate': None, 'a_bitrate': None
+ },
+ 'mp3': {
+ 'vcodec': None, 'acodec': 'libmp3lame',
+ 'allowed_vcodecs': [], 'allowed_acodecs': ['libmp3lame'],
+ 'v_bitrate': None, 'a_bitrate': '192k'
+ },
+ 'flac': {
+ 'vcodec': None, 'acodec': 'flac',
+ 'allowed_vcodecs': [], 'allowed_acodecs': ['flac'],
+ 'v_bitrate': None, 'a_bitrate': None
+ },
+ 'gif': {
+ 'vcodec': 'gif', 'acodec': None,
+ 'allowed_vcodecs': ['gif'], 'allowed_acodecs': [],
+ 'v_bitrate': None, 'a_bitrate': None
+ },
+ 'oga': { # Using oga for ogg audio
+ 'vcodec': None, 'acodec': 'libvorbis',
+ 'allowed_vcodecs': [], 'allowed_acodecs': ['libvorbis', 'libopus'],
+ 'v_bitrate': None, 'a_bitrate': '192k'
+ }
+}
+
+_cached_formats = None
+_cached_video_codecs = None
+_cached_audio_codecs = None
+
+def run_ffmpeg_command(args):
+ try:
+ startupinfo = None
+ if hasattr(subprocess, 'STARTUPINFO'):
+ startupinfo = subprocess.STARTUPINFO()
+ startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
+ startupinfo.wShowWindow = subprocess.SW_HIDE
+
+ result = subprocess.run(
+ ['ffmpeg'] + args,
+ capture_output=True, text=True, encoding='utf-8',
+ errors='ignore', startupinfo=startupinfo
+ )
+ if result.returncode != 0 and "Unrecognized option" not in result.stderr:
+ print(f"FFmpeg command failed: {' '.join(args)}\n{result.stderr}")
+ return ""
+ return result.stdout
+ except FileNotFoundError:
+ print("Error: ffmpeg not found. Please ensure it is in your system's PATH.")
+ return None
+ except Exception as e:
+ print(f"An error occurred while running ffmpeg: {e}")
+ return None
+
+def get_available_formats():
+ global _cached_formats
+ if _cached_formats is not None:
+ return _cached_formats
+
+ output = run_ffmpeg_command(['-formats'])
+ if not output:
+ _cached_formats = {}
+ return {}
+
+ formats = {}
+ lines = output.split('\n')
+ header_found = False
+ for line in lines:
+ if "---" in line:
+ header_found = True
+ continue
+ if not header_found or not line.strip():
+ continue
+
+ if line[2] == 'E':
+ parts = line[4:].strip().split(None, 1)
+ if len(parts) == 2:
+ names, description = parts
+ primary_name = names.split(',')[0].strip()
+ formats[primary_name] = description.strip()
+
+ _cached_formats = dict(sorted(formats.items()))
+ return _cached_formats
+
+def get_available_codecs(codec_type='video'):
+ global _cached_video_codecs, _cached_audio_codecs
+
+ if codec_type == 'video' and _cached_video_codecs is not None: return _cached_video_codecs
+ if codec_type == 'audio' and _cached_audio_codecs is not None: return _cached_audio_codecs
+
+ output = run_ffmpeg_command(['-encoders'])
+ if not output:
+ if codec_type == 'video': _cached_video_codecs = {}
+ else: _cached_audio_codecs = {}
+ return {}
+
+ video_codecs = {}
+ audio_codecs = {}
+ lines = output.split('\n')
+
+ header_found = False
+ for line in lines:
+ if "------" in line:
+ header_found = True
+ continue
+
+ if not header_found or not line.strip():
+ continue
+
+ parts = line.strip().split(None, 2)
+ if len(parts) < 3:
+ continue
+
+ flags, name, description = parts
+ type_flag = flags[0]
+ clean_description = re.sub(r'\s*\(codec .*\)$', '', description).strip()
+
+ if type_flag == 'V':
+ video_codecs[name] = clean_description
+ elif type_flag == 'A':
+ audio_codecs[name] = clean_description
+
+ _cached_video_codecs = dict(sorted(video_codecs.items()))
+ _cached_audio_codecs = dict(sorted(audio_codecs.items()))
+
+ return _cached_video_codecs if codec_type == 'video' else _cached_audio_codecs
+
class TimelineClip:
def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec, track_index, track_type, media_type, group_id):
self.id = str(uuid.uuid4())
@@ -1259,19 +1420,235 @@ class SettingsDialog(QDialog):
def __init__(self, parent_settings, parent=None):
super().__init__(parent)
self.setWindowTitle("Settings")
- self.setMinimumWidth(350)
+ self.setMinimumWidth(450)
layout = QVBoxLayout(self)
self.confirm_on_exit_checkbox = QCheckBox("Confirm before exiting")
self.confirm_on_exit_checkbox.setChecked(parent_settings.get("confirm_on_exit", True))
layout.addWidget(self.confirm_on_exit_checkbox)
+
+ export_path_group = QGroupBox("Default Export Path (for new projects)")
+ export_path_layout = QHBoxLayout()
+ self.default_export_path_edit = QLineEdit()
+ self.default_export_path_edit.setPlaceholderText("Optional: e.g., C:/Users/YourUser/Videos/Exports")
+ self.default_export_path_edit.setText(parent_settings.get("default_export_path", ""))
+ browse_button = QPushButton("Browse...")
+ browse_button.clicked.connect(self.browse_default_export_path)
+ export_path_layout.addWidget(self.default_export_path_edit)
+ export_path_layout.addWidget(browse_button)
+ export_path_group.setLayout(export_path_layout)
+ layout.addWidget(export_path_group)
+
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
button_box.accepted.connect(self.accept)
button_box.rejected.connect(self.reject)
layout.addWidget(button_box)
+ def browse_default_export_path(self):
+ path = QFileDialog.getExistingDirectory(self, "Select Default Export Folder", self.default_export_path_edit.text())
+ if path:
+ self.default_export_path_edit.setText(path)
+
def get_settings(self):
return {
- "confirm_on_exit": self.confirm_on_exit_checkbox.isChecked()
+ "confirm_on_exit": self.confirm_on_exit_checkbox.isChecked(),
+ "default_export_path": self.default_export_path_edit.text(),
+ }
+
+class ExportDialog(QDialog):
+ def __init__(self, default_path, parent=None):
+ super().__init__(parent)
+ self.setWindowTitle("Export Settings")
+ self.setMinimumWidth(550)
+
+ self.video_bitrate_options = ["500k", "1M", "2.5M", "5M", "8M", "15M", "Custom..."]
+ self.audio_bitrate_options = ["96k", "128k", "192k", "256k", "320k", "Custom..."]
+ self.display_ext_map = {'matroska': 'mkv', 'oga': 'ogg'}
+
+ self.layout = QVBoxLayout(self)
+ self.formats = get_available_formats()
+ self.video_codecs = get_available_codecs('video')
+ self.audio_codecs = get_available_codecs('audio')
+
+ self._setup_ui()
+
+ self.path_edit.setText(default_path)
+ self.on_advanced_toggled(False)
+
+ def _setup_ui(self):
+ output_group = QGroupBox("Output File")
+ output_layout = QHBoxLayout()
+ self.path_edit = QLineEdit()
+ browse_button = QPushButton("Browse...")
+ browse_button.clicked.connect(self.browse_output_path)
+ output_layout.addWidget(self.path_edit)
+ output_layout.addWidget(browse_button)
+ output_group.setLayout(output_layout)
+ self.layout.addWidget(output_group)
+
+ self.container_combo = QComboBox()
+ self.container_combo.currentIndexChanged.connect(self.on_container_changed)
+
+ self.advanced_formats_checkbox = QCheckBox("Show Advanced Options")
+ self.advanced_formats_checkbox.toggled.connect(self.on_advanced_toggled)
+
+ container_layout = QFormLayout()
+ container_layout.addRow("Format Preset:", self.container_combo)
+ container_layout.addRow(self.advanced_formats_checkbox)
+ self.layout.addLayout(container_layout)
+
+ self.video_group = QGroupBox("Video Settings")
+ video_layout = QFormLayout()
+ self.video_codec_combo = QComboBox()
+ video_layout.addRow("Video Codec:", self.video_codec_combo)
+ self.v_bitrate_combo = QComboBox()
+ self.v_bitrate_combo.addItems(self.video_bitrate_options)
+ self.v_bitrate_custom_edit = QLineEdit()
+ self.v_bitrate_custom_edit.setPlaceholderText("e.g., 6500k")
+ self.v_bitrate_custom_edit.hide()
+ self.v_bitrate_combo.currentTextChanged.connect(self.on_v_bitrate_changed)
+ video_layout.addRow("Video Bitrate:", self.v_bitrate_combo)
+ video_layout.addRow(self.v_bitrate_custom_edit)
+ self.video_group.setLayout(video_layout)
+ self.layout.addWidget(self.video_group)
+
+ self.audio_group = QGroupBox("Audio Settings")
+ audio_layout = QFormLayout()
+ self.audio_codec_combo = QComboBox()
+ audio_layout.addRow("Audio Codec:", self.audio_codec_combo)
+ self.a_bitrate_combo = QComboBox()
+ self.a_bitrate_combo.addItems(self.audio_bitrate_options)
+ self.a_bitrate_custom_edit = QLineEdit()
+ self.a_bitrate_custom_edit.setPlaceholderText("e.g., 256k")
+ self.a_bitrate_custom_edit.hide()
+ self.a_bitrate_combo.currentTextChanged.connect(self.on_a_bitrate_changed)
+ audio_layout.addRow("Audio Bitrate:", self.a_bitrate_combo)
+ audio_layout.addRow(self.a_bitrate_custom_edit)
+ self.audio_group.setLayout(audio_layout)
+ self.layout.addWidget(self.audio_group)
+
+ self.button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel)
+ self.button_box.accepted.connect(self.accept)
+ self.button_box.rejected.connect(self.reject)
+ self.layout.addWidget(self.button_box)
+
+ def _populate_combo(self, combo, data_dict, filter_keys=None):
+ current_selection = combo.currentData()
+ combo.blockSignals(True)
+ combo.clear()
+
+ keys_to_show = filter_keys if filter_keys is not None else data_dict.keys()
+ for codename in keys_to_show:
+ if codename in data_dict:
+ desc = data_dict[codename]
+ display_name = self.display_ext_map.get(codename, codename) if combo is self.container_combo else codename
+ combo.addItem(f"{desc} ({display_name})", codename)
+
+ new_index = combo.findData(current_selection)
+ combo.setCurrentIndex(new_index if new_index != -1 else 0)
+ combo.blockSignals(False)
+
+ def on_advanced_toggled(self, checked):
+ self._populate_combo(self.container_combo, self.formats, None if checked else CONTAINER_PRESETS.keys())
+
+ if not checked:
+ mp4_index = self.container_combo.findData("mp4")
+ if mp4_index != -1: self.container_combo.setCurrentIndex(mp4_index)
+
+ self.on_container_changed(self.container_combo.currentIndex())
+
+ def apply_preset(self, container_codename):
+ preset = CONTAINER_PRESETS.get(container_codename, {})
+
+ vcodec = preset.get('vcodec')
+ self.video_group.setEnabled(vcodec is not None)
+ if vcodec:
+ vcodec_idx = self.video_codec_combo.findData(vcodec)
+ self.video_codec_combo.setCurrentIndex(vcodec_idx if vcodec_idx != -1 else 0)
+
+ v_bitrate = preset.get('v_bitrate')
+ self.v_bitrate_combo.setEnabled(v_bitrate is not None)
+ if v_bitrate:
+ v_bitrate_idx = self.v_bitrate_combo.findText(v_bitrate)
+ if v_bitrate_idx != -1: self.v_bitrate_combo.setCurrentIndex(v_bitrate_idx)
+ else:
+ self.v_bitrate_combo.setCurrentText("Custom...")
+ self.v_bitrate_custom_edit.setText(v_bitrate)
+
+ acodec = preset.get('acodec')
+ self.audio_group.setEnabled(acodec is not None)
+ if acodec:
+ acodec_idx = self.audio_codec_combo.findData(acodec)
+ self.audio_codec_combo.setCurrentIndex(acodec_idx if acodec_idx != -1 else 0)
+
+ a_bitrate = preset.get('a_bitrate')
+ self.a_bitrate_combo.setEnabled(a_bitrate is not None)
+ if a_bitrate:
+ a_bitrate_idx = self.a_bitrate_combo.findText(a_bitrate)
+ if a_bitrate_idx != -1: self.a_bitrate_combo.setCurrentIndex(a_bitrate_idx)
+ else:
+ self.a_bitrate_combo.setCurrentText("Custom...")
+ self.a_bitrate_custom_edit.setText(a_bitrate)
+
+ def on_v_bitrate_changed(self, text):
+ self.v_bitrate_custom_edit.setVisible(text == "Custom...")
+
+ def on_a_bitrate_changed(self, text):
+ self.a_bitrate_custom_edit.setVisible(text == "Custom...")
+
+ def browse_output_path(self):
+ container_codename = self.container_combo.currentData()
+ all_formats_desc = [f"{desc} (*.{name})" for name, desc in self.formats.items()]
+ filter_str = ";;".join(all_formats_desc)
+ current_desc = self.formats.get(container_codename, "Custom Format")
+ specific_filter = f"{current_desc} (*.{container_codename})"
+ final_filter = f"{specific_filter};;{filter_str};;All Files (*)"
+
+ path, _ = QFileDialog.getSaveFileName(self, "Save Video As", self.path_edit.text(), final_filter)
+ if path: self.path_edit.setText(path)
+
+ def on_container_changed(self, index):
+ if index == -1: return
+ new_container_codename = self.container_combo.itemData(index)
+
+ is_advanced = self.advanced_formats_checkbox.isChecked()
+ preset = CONTAINER_PRESETS.get(new_container_codename)
+
+ if not is_advanced and preset:
+ v_filter = preset.get('allowed_vcodecs')
+ a_filter = preset.get('allowed_acodecs')
+ self._populate_combo(self.video_codec_combo, self.video_codecs, v_filter)
+ self._populate_combo(self.audio_codec_combo, self.audio_codecs, a_filter)
+ else:
+ self._populate_combo(self.video_codec_combo, self.video_codecs)
+ self._populate_combo(self.audio_codec_combo, self.audio_codecs)
+
+ if new_container_codename:
+ self.update_output_path_extension(new_container_codename)
+ self.apply_preset(new_container_codename)
+
+ def update_output_path_extension(self, new_container_codename):
+ current_path = self.path_edit.text()
+ if not current_path: return
+ directory, filename = os.path.split(current_path)
+ basename, _ = os.path.splitext(filename)
+ ext = self.display_ext_map.get(new_container_codename, new_container_codename)
+ new_path = os.path.join(directory, f"{basename}.{ext}")
+ self.path_edit.setText(new_path)
+
+ def get_export_settings(self):
+ v_bitrate = self.v_bitrate_combo.currentText()
+ if v_bitrate == "Custom...": v_bitrate = self.v_bitrate_custom_edit.text()
+
+ a_bitrate = self.a_bitrate_combo.currentText()
+ if a_bitrate == "Custom...": a_bitrate = self.a_bitrate_custom_edit.text()
+
+ return {
+ "output_path": self.path_edit.text(),
+ "container": self.container_combo.currentData(),
+ "vcodec": self.video_codec_combo.currentData() if self.video_group.isEnabled() else None,
+ "v_bitrate": v_bitrate if self.v_bitrate_combo.isEnabled() else None,
+ "acodec": self.audio_codec_combo.currentData() if self.audio_group.isEnabled() else None,
+ "a_bitrate": a_bitrate if self.a_bitrate_combo.isEnabled() else None,
}
class MediaListWidget(QListWidget):
@@ -1391,6 +1768,7 @@ def __init__(self, project_to_load=None):
self.export_thread = None
self.export_worker = None
self.current_project_path = None
+ self.last_export_path = None
self.settings = {}
self.settings_file = "settings.json"
self.is_shutting_down = False
@@ -1975,7 +2353,7 @@ def advance_playback_frame(self):
def _load_settings(self):
self.settings_file_was_loaded = False
- defaults = {"window_visibility": {"project_media": False}, "splitter_state": None, "enabled_plugins": [], "recent_files": [], "confirm_on_exit": True}
+ defaults = {"window_visibility": {"project_media": False}, "splitter_state": None, "enabled_plugins": [], "recent_files": [], "confirm_on_exit": True, "default_export_path": ""}
if os.path.exists(self.settings_file):
try:
with open(self.settings_file, "r") as f: self.settings = json.load(f)
@@ -2026,6 +2404,7 @@ def new_project(self):
self.timeline.clips.clear(); self.timeline.num_video_tracks = 1; self.timeline.num_audio_tracks = 1
self.media_pool.clear(); self.media_properties.clear(); self.project_media_widget.clear_list()
self.current_project_path = None; self.stop_playback()
+ self.last_export_path = None
self.project_fps = 25.0
self.project_width = 1280
self.project_height = 720
@@ -2044,6 +2423,7 @@ def save_project_as(self):
"media_pool": self.media_pool,
"clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "media_type": c.media_type, "group_id": c.group_id} for c in self.timeline.clips],
"selection_regions": self.timeline_widget.selection_regions,
+ "last_export_path": self.last_export_path,
"settings": {
"num_video_tracks": self.timeline.num_video_tracks,
"num_audio_tracks": self.timeline.num_audio_tracks,
@@ -2074,11 +2454,11 @@ def _load_project_from_path(self, path):
self.project_height = project_settings.get("project_height", 720)
self.project_fps = project_settings.get("project_fps", 25.0)
self.timeline_widget.set_project_fps(self.project_fps)
-
+ self.last_export_path = project_data.get("last_export_path")
self.timeline_widget.selection_regions = project_data.get("selection_regions", [])
- self.media_pool = project_data.get("media_pool", [])
- for p in self.media_pool: self._add_media_to_pool(p)
+ media_pool_paths = project_data.get("media_pool", [])
+ for p in media_pool_paths: self._add_media_to_pool(p)
for clip_data in project_data["clips"]:
if not os.path.exists(clip_data["source_path"]):
@@ -2520,9 +2900,41 @@ def action():
def export_video(self):
- if not self.timeline.clips: self.status_label.setText("Timeline is empty."); return
- output_path, _ = QFileDialog.getSaveFileName(self, "Save Video As", "", "MP4 Files (*.mp4)")
- if not output_path: return
+ if not self.timeline.clips:
+ self.status_label.setText("Timeline is empty.")
+ return
+
+ default_path = ""
+ if self.last_export_path and os.path.isdir(os.path.dirname(self.last_export_path)):
+ default_path = self.last_export_path
+ elif self.settings.get("default_export_path") and os.path.isdir(self.settings.get("default_export_path")):
+ proj_basename = "output"
+ if self.current_project_path:
+ _, proj_file = os.path.split(self.current_project_path)
+ proj_basename, _ = os.path.splitext(proj_file)
+
+ default_path = os.path.join(self.settings["default_export_path"], f"{proj_basename}_export.mp4")
+ elif self.current_project_path:
+ proj_dir, proj_file = os.path.split(self.current_project_path)
+ proj_basename, _ = os.path.splitext(proj_file)
+ default_path = os.path.join(proj_dir, f"{proj_basename}_export.mp4")
+ else:
+ default_path = "output.mp4"
+
+ default_path = os.path.normpath(default_path)
+
+ dialog = ExportDialog(default_path, self)
+ if dialog.exec() != QDialog.DialogCode.Accepted:
+ self.status_label.setText("Export canceled.")
+ return
+
+ settings = dialog.get_export_settings()
+ output_path = settings["output_path"]
+ if not output_path:
+ self.status_label.setText("Export failed: No output path specified.")
+ return
+
+ self.last_export_path = output_path
w, h, fr_str, total_dur = self.project_width, self.project_height, str(self.project_fps), self.timeline.get_total_duration()
sample_rate, channel_layout = '44100', 'stereo'
@@ -2533,9 +2945,7 @@ def export_video(self):
[c for c in self.timeline.clips if c.track_type == 'video'],
key=lambda c: c.track_index
)
-
input_nodes = {}
-
for clip in all_video_clips:
if clip.source_path not in input_nodes:
if clip.media_type == 'image':
@@ -2544,31 +2954,13 @@ def export_video(self):
input_nodes[clip.source_path] = ffmpeg.input(clip.source_path)
clip_source_node = input_nodes[clip.source_path]
-
if clip.media_type == 'image':
- segment_stream = (
- clip_source_node.video
- .trim(duration=clip.duration_sec)
- .setpts('PTS-STARTPTS')
- )
+ segment_stream = (clip_source_node.video.trim(duration=clip.duration_sec).setpts('PTS-STARTPTS'))
else:
- segment_stream = (
- clip_source_node.video
- .trim(start=clip.clip_start_sec, duration=clip.duration_sec)
- .setpts('PTS-STARTPTS')
- )
+ segment_stream = (clip_source_node.video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS'))
- processed_segment = (
- segment_stream
- .filter('scale', w, h, force_original_aspect_ratio='decrease')
- .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
- )
-
- video_stream = ffmpeg.overlay(
- video_stream,
- processed_segment,
- enable=f'between(t,{clip.timeline_start_sec},{clip.timeline_end_sec})'
- )
+ processed_segment = (segment_stream.filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black'))
+ video_stream = ffmpeg.overlay(video_stream, processed_segment, enable=f'between(t,{clip.timeline_start_sec},{clip.timeline_end_sec})')
final_video = video_stream.filter('format', pix_fmts='yuv420p').filter('fps', fps=self.project_fps)
@@ -2576,22 +2968,17 @@ def export_video(self):
for i in range(self.timeline.num_audio_tracks):
track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i + 1], key=lambda c: c.timeline_start_sec)
if not track_clips: continue
-
track_segments = []
last_end = track_clips[0].timeline_start_sec
-
for clip in track_clips:
if clip.source_path not in input_nodes:
input_nodes[clip.source_path] = ffmpeg.input(clip.source_path)
-
gap = clip.timeline_start_sec - last_end
if gap > 0.01:
track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap}', f='lavfi'))
-
a_seg = input_nodes[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS')
track_segments.append(a_seg)
last_end = clip.timeline_end_sec
-
track_audio_streams.append(ffmpeg.concat(*track_segments, v=0, a=1).filter('adelay', f'{int(track_clips[0].timeline_start_sec * 1000)}ms', all=True))
if track_audio_streams:
@@ -2599,7 +2986,10 @@ def export_video(self):
else:
final_audio = ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={total_dur}', f='lavfi')
- output_args = {'vcodec': 'libx264', 'acodec': 'aac', 'pix_fmt': 'yuv420p', 'b:v': '5M'}
+ output_args = {'vcodec': settings['vcodec'], 'acodec': settings['acodec'], 'pix_fmt': 'yuv420p'}
+ if settings['v_bitrate']: output_args['b:v'] = settings['v_bitrate']
+ if settings['a_bitrate']: output_args['b:a'] = settings['a_bitrate']
+
try:
ffmpeg_cmd = ffmpeg.output(final_video, final_audio, output_path, **output_args).overwrite_output().compile()
self.progress_bar.setVisible(True); self.progress_bar.setValue(0); self.status_label.setText("Exporting...")
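
`get_available_formats` and `get_available_codecs` shell out to ffmpeg and scrape its muxer and encoder tables. A trimmed, hedged sketch of the same probing; the column layout matches typical `ffmpeg -formats` output but can vary between builds, and ffmpeg is assumed to be on PATH:

```python
import subprocess

def _ffmpeg_lines(args):
    """Run ffmpeg with the given args; return stdout lines, or [] if ffmpeg is missing."""
    try:
        result = subprocess.run(["ffmpeg", "-hide_banner"] + args,
                                capture_output=True, text=True, errors="ignore")
    except FileNotFoundError:
        return []
    return result.stdout.splitlines()

def available_muxers():
    """Container formats ffmpeg can write: rows carrying the 'E' flag in `ffmpeg -formats`."""
    muxers, past_header = {}, False
    for line in _ffmpeg_lines(["-formats"]):
        if line.strip().startswith("--"):       # separator between legend and table
            past_header = True
            continue
        if not past_header or len(line) < 5:
            continue
        flags, rest = line[:4], line[4:].strip()
        if "E" in flags:
            name_field, _, description = rest.partition(" ")
            muxers[name_field.split(",")[0]] = description.strip()
    return muxers

# Encoders are scraped the same way from `ffmpeg -encoders`, keyed on the V/A type flag.
formats = available_muxers()
print("mp4 muxer available:", "mp4" in formats)
```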
From cdcde06778a07ee55be95f6bc64dd9d1e9e4bfd3 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 21:26:42 +1100
Subject: [PATCH 141/155] use ms instead of sec globally
---
plugins/wan2gp/main.py | 67 ++---
videoeditor.py | 612 +++++++++++++++++++++--------------------
2 files changed, 347 insertions(+), 332 deletions(-)
diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py
index 623d4afa7..473baf2e4 100644
--- a/plugins/wan2gp/main.py
+++ b/plugins/wan2gp/main.py
@@ -124,11 +124,12 @@ def enterEvent(self, event):
super().enterEvent(event)
self.media_player.play()
if not self.plugin.active_region or self.duration == 0: return
- start_sec, _ = self.plugin.active_region
+ start_ms, _ = self.plugin.active_region
timeline = self.app.timeline_widget
video_rect, audio_rect = None, None
- x = timeline.sec_to_x(start_sec)
- w = int(self.duration * timeline.pixels_per_second)
+ x = timeline.ms_to_x(start_ms)
+ duration_ms = int(self.duration * 1000)
+ w = int(duration_ms * timeline.pixels_per_ms)
if self.plugin.insert_on_new_track:
video_y = timeline.TIMESCALE_HEIGHT
video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT)
@@ -1705,9 +1706,9 @@ def on_timeline_context_menu(self, menu, event):
region = self.app.timeline_widget.get_region_at_pos(event.pos())
if region:
menu.addSeparator()
- start_sec, end_sec = region
- start_data, _, _ = self.app.get_frame_data_at_time(start_sec)
- end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
+ start_ms, end_ms = region
+ start_data, _, _ = self.app.get_frame_data_at_time(start_ms)
+ end_data, _, _ = self.app.get_frame_data_at_time(end_ms)
if start_data and end_data:
join_action = menu.addAction("Join Frames With AI")
@@ -1745,9 +1746,9 @@ def setup_generator_for_region(self, region, on_new_track=False):
else:
print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.")
- start_sec, end_sec = region
- start_data, w, h = self.app.get_frame_data_at_time(start_sec)
- end_data, _, _ = self.app.get_frame_data_at_time(end_sec)
+ start_ms, end_ms = region
+ start_data, w, h = self.app.get_frame_data_at_time(start_ms)
+ end_data, _, _ = self.app.get_frame_data_at_time(end_ms)
if not start_data or not end_data:
QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames.")
return
@@ -1761,11 +1762,11 @@ def setup_generator_for_region(self, region, on_new_track=False):
QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path)
QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path)
- duration_sec = end_sec - start_sec
+ duration_ms = end_ms - start_ms
wgp = self.client_widget.wgp
model_type = self.client_widget.state['model_type']
fps = wgp.get_model_fps(model_type)
- video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
+ video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16)
widgets = self.client_widget.widgets
for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']:
@@ -1786,7 +1787,7 @@ def setup_generator_for_region(self, region, on_new_track=False):
QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame images: {e}")
self._cleanup_temp_dir()
return
- self.app.status_label.setText(f"Ready to join frames from {start_sec:.2f}s to {end_sec:.2f}s.")
+ self.app.status_label.setText(f"Ready to join frames from {start_ms / 1000.0:.2f}s to {end_ms / 1000.0:.2f}s.")
self.dock_widget.show()
self.dock_widget.raise_()
@@ -1805,8 +1806,8 @@ def setup_generator_from_start(self, region, on_new_track=False):
else:
print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.")
- start_sec, end_sec = region
- start_data, w, h = self.app.get_frame_data_at_time(start_sec)
+ start_ms, end_ms = region
+ start_data, w, h = self.app.get_frame_data_at_time(start_ms)
if not start_data:
QMessageBox.warning(self.app, "Frame Error", "Could not extract start frame.")
return
@@ -1818,11 +1819,11 @@ def setup_generator_from_start(self, region, on_new_track=False):
self.start_frame_path = os.path.join(self.temp_dir, "start_frame.png")
QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path)
- duration_sec = end_sec - start_sec
+ duration_ms = end_ms - start_ms
wgp = self.client_widget.wgp
model_type = self.client_widget.state['model_type']
fps = wgp.get_model_fps(model_type)
- video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
+ video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16)
widgets = self.client_widget.widgets
widgets['video_length'].setValue(video_length_frames)
@@ -1845,7 +1846,7 @@ def setup_generator_from_start(self, region, on_new_track=False):
self._cleanup_temp_dir()
return
- self.app.status_label.setText(f"Ready to generate from frame at {start_sec:.2f}s.")
+ self.app.status_label.setText(f"Ready to generate from frame at {start_ms / 1000.0:.2f}s.")
self.dock_widget.show()
self.dock_widget.raise_()
@@ -1864,8 +1865,8 @@ def setup_generator_to_end(self, region, on_new_track=False):
else:
print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.")
- start_sec, end_sec = region
- end_data, w, h = self.app.get_frame_data_at_time(end_sec)
+ start_ms, end_ms = region
+ end_data, w, h = self.app.get_frame_data_at_time(end_ms)
if not end_data:
QMessageBox.warning(self.app, "Frame Error", "Could not extract end frame.")
return
@@ -1877,11 +1878,11 @@ def setup_generator_to_end(self, region, on_new_track=False):
self.end_frame_path = os.path.join(self.temp_dir, "end_frame.png")
QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path)
- duration_sec = end_sec - start_sec
+ duration_ms = end_ms - start_ms
wgp = self.client_widget.wgp
model_type = self.client_widget.state['model_type']
fps = wgp.get_model_fps(model_type)
- video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
+ video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16)
widgets = self.client_widget.widgets
widgets['video_length'].setValue(video_length_frames)
@@ -1915,7 +1916,7 @@ def setup_generator_to_end(self, region, on_new_track=False):
self._cleanup_temp_dir()
return
- self.app.status_label.setText(f"Ready to generate to frame at {end_sec:.2f}s.")
+ self.app.status_label.setText(f"Ready to generate to frame at {end_ms / 1000.0:.2f}s.")
self.dock_widget.show()
self.dock_widget.raise_()
@@ -1938,17 +1939,17 @@ def setup_creator_for_region(self, region, on_new_track=False):
target_h = self.app.project_height
self.client_widget.set_resolution_from_target(target_w, target_h)
- start_sec, end_sec = region
- duration_sec = end_sec - start_sec
+ start_ms, end_ms = region
+ duration_ms = end_ms - start_ms
wgp = self.client_widget.wgp
model_type = self.client_widget.state['model_type']
fps = wgp.get_model_fps(model_type)
- video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16)
+ video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16)
self.client_widget.widgets['video_length'].setValue(video_length_frames)
self.client_widget.widgets['mode_t'].setChecked(True)
- self.app.status_label.setText(f"Ready to create video from {start_sec:.2f}s to {end_sec:.2f}s.")
+ self.app.status_label.setText(f"Ready to create video from {start_ms / 1000.0:.2f}s to {end_ms / 1000.0:.2f}s.")
self.dock_widget.show()
self.dock_widget.raise_()
@@ -1958,29 +1959,29 @@ def insert_generated_clip(self, video_path):
self.app.status_label.setText("Error: No active region to insert into."); return
if not os.path.exists(video_path):
self.app.status_label.setText(f"Error: Output file not found: {video_path}"); return
- start_sec, end_sec = self.active_region
+ start_ms, end_ms = self.active_region
def complex_insertion_action():
self.app._add_media_files_to_project([video_path])
media_info = self.app.media_properties.get(video_path)
if not media_info: raise ValueError("Could not probe inserted clip.")
- actual_duration, has_audio = media_info['duration'], media_info['has_audio']
+ actual_duration_ms, has_audio = media_info['duration_ms'], media_info['has_audio']
if self.insert_on_new_track:
self.app.timeline.num_video_tracks += 1
video_track_index = self.app.timeline.num_video_tracks
audio_track_index = self.app.timeline.num_audio_tracks + 1 if has_audio else None
else:
- for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec)
- for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec)
- clips_to_remove = [c for c in self.app.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec]
+ for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_ms)
+ for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_ms)
+ clips_to_remove = [c for c in self.app.timeline.clips if c.timeline_start_ms >= start_ms and c.timeline_end_ms <= end_ms]
for clip in clips_to_remove:
if clip in self.app.timeline.clips: self.app.timeline.clips.remove(clip)
video_track_index, audio_track_index = 1, 1 if has_audio else None
group_id = str(uuid.uuid4())
- new_clip = TimelineClip(video_path, start_sec, 0, actual_duration, video_track_index, 'video', 'video', group_id)
+ new_clip = TimelineClip(video_path, start_ms, 0, actual_duration_ms, video_track_index, 'video', 'video', group_id)
self.app.timeline.add_clip(new_clip)
if audio_track_index:
if audio_track_index > self.app.timeline.num_audio_tracks: self.app.timeline.num_audio_tracks = audio_track_index
- audio_clip = TimelineClip(video_path, start_sec, 0, actual_duration, audio_track_index, 'audio', 'video', group_id)
+ audio_clip = TimelineClip(video_path, start_ms, 0, actual_duration_ms, audio_track_index, 'audio', 'video', group_id)
self.app.timeline.add_clip(audio_clip)
try:
self.app._perform_complex_timeline_change("Insert AI Clip", complex_insertion_action)
diff --git a/videoeditor.py b/videoeditor.py
index b6673c879..5126b79b5 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -181,20 +181,20 @@ def get_available_codecs(codec_type='video'):
return _cached_video_codecs if codec_type == 'video' else _cached_audio_codecs
class TimelineClip:
- def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec, track_index, track_type, media_type, group_id):
+ def __init__(self, source_path, timeline_start_ms, clip_start_ms, duration_ms, track_index, track_type, media_type, group_id):
self.id = str(uuid.uuid4())
self.source_path = source_path
- self.timeline_start_sec = timeline_start_sec
- self.clip_start_sec = clip_start_sec
- self.duration_sec = duration_sec
+ self.timeline_start_ms = int(timeline_start_ms)
+ self.clip_start_ms = int(clip_start_ms)
+ self.duration_ms = int(duration_ms)
self.track_index = track_index
self.track_type = track_type
self.media_type = media_type
self.group_id = group_id
@property
- def timeline_end_sec(self):
- return self.timeline_start_sec + self.duration_sec
+ def timeline_end_ms(self):
+ return self.timeline_start_ms + self.duration_ms
class Timeline:
def __init__(self):
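
The switch from float seconds to integer milliseconds in `TimelineClip` removes the rounding drift that accumulates when positions are repeatedly added, split, and re-joined. The patch does not state this rationale explicitly, so treat the following as an illustration of the presumed motivation, not part of the change itself:

```python
# Illustration only: repeated float-second arithmetic can drift, while
# integer milliseconds stay exact. Assumes a 25 fps project (40 ms frames).
frame_sec = 1.0 / 25.0
frame_ms = 40

pos_sec = 0.0
pos_ms = 0
for _ in range(1000):          # advance 1000 frames
    pos_sec += frame_sec
    pos_ms += frame_ms

print(pos_sec)                 # usually not exactly 40.0 due to rounding
print(pos_ms)                  # 40000 -- exact
print(pos_ms == 40 * 1000)     # True
```

With integer milliseconds, `timeline_end_ms` and every split/join boundary stays exact, so comparisons such as `clip.timeline_end_ms == other.timeline_start_ms` behave predictably.
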
@@ -204,19 +204,20 @@ def __init__(self):
def add_clip(self, clip):
self.clips.append(clip)
- self.clips.sort(key=lambda c: c.timeline_start_sec)
+ self.clips.sort(key=lambda c: c.timeline_start_ms)
def get_total_duration(self):
if not self.clips: return 0
- return max(c.timeline_end_sec for c in self.clips)
+ return max(c.timeline_end_ms for c in self.clips)
class ExportWorker(QObject):
progress = pyqtSignal(int)
finished = pyqtSignal(str)
- def __init__(self, ffmpeg_cmd, total_duration):
+ def __init__(self, ffmpeg_cmd, total_duration_ms):
super().__init__()
self.ffmpeg_cmd = ffmpeg_cmd
- self.total_duration_secs = total_duration
+ self.total_duration_ms = total_duration_ms
+
def run_export(self):
try:
process = subprocess.Popen(self.ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True, encoding="utf-8")
@@ -224,10 +225,10 @@ def run_export(self):
for line in iter(process.stdout.readline, ""):
match = time_pattern.search(line)
if match:
- hours, minutes, seconds = [int(g) for g in match.groups()[:3]]
- processed_time = hours * 3600 + minutes * 60 + seconds
- if self.total_duration_secs > 0:
- percentage = int((processed_time / self.total_duration_secs) * 100)
+ h, m, s, cs = [int(g) for g in match.groups()]
+ processed_ms = (h * 3600 + m * 60 + s) * 1000 + cs * 10
+ if self.total_duration_ms > 0:
+ percentage = int((processed_ms / self.total_duration_ms) * 100)
self.progress.emit(percentage)
process.stdout.close()
return_code = process.wait()
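
The export progress handler now reads hours, minutes, seconds and centiseconds from ffmpeg's `time=HH:MM:SS.cc` log lines and compares the result against `total_duration_ms`. The actual `time_pattern` is defined outside this hunk; the regex below is an assumed stand-in used only to sketch the conversion:

```python
import re

# Assumed pattern; the patch's own time_pattern lives outside this hunk.
TIME_RE = re.compile(r"time=(\d+):(\d+):(\d+)\.(\d+)")

def ffmpeg_time_to_ms(line: str):
    """Return the processed position in milliseconds, or None if absent."""
    match = TIME_RE.search(line)
    if not match:
        return None
    h, m, s, cs = (int(g) for g in match.groups())
    return (h * 3600 + m * 60 + s) * 1000 + cs * 10   # centiseconds -> ms

print(ffmpeg_time_to_ms("frame= 100 fps=25 time=00:01:02.50 bitrate=N/A"))  # 62500
```
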
@@ -250,7 +251,7 @@ class TimelineWidget(QWidget):
split_requested = pyqtSignal(object)
delete_clip_requested = pyqtSignal(object)
delete_clips_requested = pyqtSignal(list)
- playhead_moved = pyqtSignal(float)
+ playhead_moved = pyqtSignal(int)
split_region_requested = pyqtSignal(list)
split_all_regions_requested = pyqtSignal(list)
join_region_requested = pyqtSignal(list)
@@ -266,14 +267,14 @@ def __init__(self, timeline_model, settings, project_fps, parent=None):
super().__init__(parent)
self.timeline = timeline_model
self.settings = settings
- self.playhead_pos_sec = 0.0
- self.view_start_sec = 0.0
+ self.playhead_pos_ms = 0
+ self.view_start_ms = 0
self.panning = False
self.pan_start_pos = QPoint()
- self.pan_start_view_sec = 0.0
+ self.pan_start_view_ms = 0
- self.pixels_per_second = 50.0
- self.max_pixels_per_second = 1.0
+ self.pixels_per_ms = 0.05
+ self.max_pixels_per_ms = 1.0
self.project_fps = 25.0
self.set_project_fps(project_fps)
@@ -290,7 +291,7 @@ def __init__(self, timeline_model, settings, project_fps, parent=None):
self.dragging_selection_region = None
self.drag_start_pos = QPoint()
self.drag_original_clip_states = {}
- self.selection_drag_start_sec = 0.0
+ self.selection_drag_start_ms = 0
self.drag_selection_start_values = None
self.drag_start_state = None
@@ -325,12 +326,12 @@ def set_hover_preview_rects(self, video_rect, audio_rect):
def set_project_fps(self, fps):
self.project_fps = fps if fps > 0 else 25.0
- self.max_pixels_per_second = self.project_fps * 20
- self.pixels_per_second = min(self.pixels_per_second, self.max_pixels_per_second)
+ self.max_pixels_per_ms = (self.project_fps * 20) / 1000.0
+ self.pixels_per_ms = min(self.pixels_per_ms, self.max_pixels_per_ms)
self.update()
- def sec_to_x(self, sec): return self.HEADER_WIDTH + int((sec - self.view_start_sec) * self.pixels_per_second)
- def x_to_sec(self, x): return self.view_start_sec + float(x - self.HEADER_WIDTH) / self.pixels_per_second if x > self.HEADER_WIDTH and self.pixels_per_second > 0 else self.view_start_sec
+ def ms_to_x(self, ms): return self.HEADER_WIDTH + int((ms - self.view_start_ms) * self.pixels_per_ms)
+ def x_to_ms(self, x): return self.view_start_ms + int(float(x - self.HEADER_WIDTH) / self.pixels_per_ms) if x > self.HEADER_WIDTH and self.pixels_per_ms > 0 else self.view_start_ms
def paintEvent(self, event):
painter = QPainter(self)
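
`ms_to_x` / `x_to_ms` are the only places where pixels and time meet; everything else in the widget now stays in integer milliseconds. A standalone sketch of the same mapping (the `HEADER_WIDTH` value is an assumption; the widget's real constant is defined outside this diff):

```python
HEADER_WIDTH = 100  # assumed constant for illustration

def ms_to_x(ms, view_start_ms, pixels_per_ms, header=HEADER_WIDTH):
    """Map a timeline position (int ms) to a widget x coordinate."""
    return header + int((ms - view_start_ms) * pixels_per_ms)

def x_to_ms(x, view_start_ms, pixels_per_ms, header=HEADER_WIDTH):
    """Map a widget x coordinate back to a timeline position (int ms)."""
    if x > header and pixels_per_ms > 0:
        return view_start_ms + int(float(x - header) / pixels_per_ms)
    return view_start_ms

# At the default zoom of 0.05 px/ms (the old 50 px/s), two seconds of
# timeline occupy 100 px to the right of the track header.
print(ms_to_x(2000, view_start_ms=0, pixels_per_ms=0.05))   # 200
```
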
@@ -438,11 +439,12 @@ def draw_headers(self, painter):
painter.restore()
- def _format_timecode(self, seconds):
- if abs(seconds) < 1e-9: seconds = 0
- sign = "-" if seconds < 0 else ""
- seconds = abs(seconds)
+ def _format_timecode(self, total_ms):
+ if abs(total_ms) < 1: total_ms = 0
+ sign = "-" if total_ms < 0 else ""
+ total_ms = abs(total_ms)
+ seconds = total_ms / 1000.0
h = int(seconds / 3600)
m = int((seconds % 3600) / 60)
s = seconds % 60
@@ -464,51 +466,51 @@ def draw_timescale(self, painter):
painter.fillRect(QRect(self.HEADER_WIDTH, 0, self.width() - self.HEADER_WIDTH, self.TIMESCALE_HEIGHT), QColor("#222"))
painter.drawLine(self.HEADER_WIDTH, self.TIMESCALE_HEIGHT - 1, self.width(), self.TIMESCALE_HEIGHT - 1)
- frame_dur = 1.0 / self.project_fps
- intervals = [
- frame_dur, 2*frame_dur, 5*frame_dur, 10*frame_dur,
- 0.1, 0.2, 0.5, 1, 2, 5, 10, 15, 30,
- 60, 120, 300, 600, 900, 1800,
- 3600, 2*3600, 5*3600, 10*3600
+ frame_dur_ms = 1000.0 / self.project_fps
+ intervals_ms = [
+ int(frame_dur_ms), int(2*frame_dur_ms), int(5*frame_dur_ms), int(10*frame_dur_ms),
+ 100, 200, 500, 1000, 2000, 5000, 10000, 15000, 30000,
+ 60000, 120000, 300000, 600000, 900000, 1800000,
+ 3600000, 2*3600000, 5*3600000, 10*3600000
]
min_pixel_dist = 70
- major_interval = next((i for i in intervals if i * self.pixels_per_second > min_pixel_dist), intervals[-1])
+ major_interval = next((i for i in intervals_ms if i * self.pixels_per_ms > min_pixel_dist), intervals_ms[-1])
minor_interval = 0
for divisor in [5, 4, 2]:
- if (major_interval / divisor) * self.pixels_per_second > 10:
+ if (major_interval / divisor) * self.pixels_per_ms > 10:
minor_interval = major_interval / divisor
break
- start_sec = self.x_to_sec(self.HEADER_WIDTH)
- end_sec = self.x_to_sec(self.width())
+ start_ms = self.x_to_ms(self.HEADER_WIDTH)
+ end_ms = self.x_to_ms(self.width())
- def draw_ticks(interval, height):
- if interval <= 1e-6: return
- start_tick_num = int(start_sec / interval)
- end_tick_num = int(end_sec / interval) + 1
+ def draw_ticks(interval_ms, height):
+ if interval_ms < 1: return
+ start_tick_num = int(start_ms / interval_ms)
+ end_tick_num = int(end_ms / interval_ms) + 1
for i in range(start_tick_num, end_tick_num + 1):
- t_sec = i * interval
- x = self.sec_to_x(t_sec)
+ t_ms = i * interval_ms
+ x = self.ms_to_x(t_ms)
if x > self.width(): break
if x >= self.HEADER_WIDTH:
painter.drawLine(x, self.TIMESCALE_HEIGHT - height, x, self.TIMESCALE_HEIGHT)
- if frame_dur * self.pixels_per_second > 4:
- draw_ticks(frame_dur, 3)
+ if frame_dur_ms * self.pixels_per_ms > 4:
+ draw_ticks(frame_dur_ms, 3)
if minor_interval > 0:
draw_ticks(minor_interval, 6)
- start_major_tick = int(start_sec / major_interval)
- end_major_tick = int(end_sec / major_interval) + 1
+ start_major_tick = int(start_ms / major_interval)
+ end_major_tick = int(end_ms / major_interval) + 1
for i in range(start_major_tick, end_major_tick + 1):
- t_sec = i * major_interval
- x = self.sec_to_x(t_sec)
+ t_ms = i * major_interval
+ x = self.ms_to_x(t_ms)
if x > self.width() + 50: break
if x >= self.HEADER_WIDTH - 50:
painter.drawLine(x, self.TIMESCALE_HEIGHT - 12, x, self.TIMESCALE_HEIGHT)
- label = self._format_timecode(t_sec)
+ label = self._format_timecode(t_ms)
label_width = font_metrics.horizontalAdvance(label)
label_x = x - label_width // 2
if label_x < self.HEADER_WIDTH:
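
Timescale tick spacing is chosen from a table of candidate intervals, now expressed in milliseconds: the first candidate that leaves at least 70 px between major ticks at the current zoom wins. A standalone restatement of that selection logic:

```python
def choose_major_interval_ms(pixels_per_ms, project_fps, min_pixel_dist=70):
    """Pick the major tick spacing (ms): the first candidate interval that
    still leaves at least min_pixel_dist pixels between ticks at this zoom."""
    frame_dur_ms = 1000.0 / project_fps
    candidates = [
        int(frame_dur_ms), int(2 * frame_dur_ms), int(5 * frame_dur_ms), int(10 * frame_dur_ms),
        100, 200, 500, 1000, 2000, 5000, 10000, 15000, 30000,
        60000, 120000, 300000, 600000, 900000, 1800000,
        3600000, 2 * 3600000, 5 * 3600000, 10 * 3600000,
    ]
    return next((i for i in candidates if i * pixels_per_ms > min_pixel_dist),
                candidates[-1])

print(choose_major_interval_ms(pixels_per_ms=0.05, project_fps=25))  # 2000 -> a major tick every 2 s
```
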
@@ -525,8 +527,8 @@ def get_clip_rect(self, clip):
visual_index = clip.track_index - 1
y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT
- x = self.sec_to_x(clip.timeline_start_sec)
- w = int(clip.duration_sec * self.pixels_per_second)
+ x = self.ms_to_x(clip.timeline_start_ms)
+ w = int(clip.duration_ms * self.pixels_per_ms)
clip_height = self.TRACK_HEIGHT - 10
y += (self.TRACK_HEIGHT - clip_height) / 2
return QRectF(x, y, w, clip_height)
@@ -598,16 +600,16 @@ def draw_tracks_and_clips(self, painter):
painter.restore()
def draw_selections(self, painter):
- for start_sec, end_sec in self.selection_regions:
- x = self.sec_to_x(start_sec)
- w = int((end_sec - start_sec) * self.pixels_per_second)
+ for start_ms, end_ms in self.selection_regions:
+ x = self.ms_to_x(start_ms)
+ w = int((end_ms - start_ms) * self.pixels_per_ms)
selection_rect = QRectF(x, self.TIMESCALE_HEIGHT, w, self.height() - self.TIMESCALE_HEIGHT)
painter.fillRect(selection_rect, QColor(100, 100, 255, 80))
painter.setPen(QColor(150, 150, 255, 150))
painter.drawRect(selection_rect)
def draw_playhead(self, painter):
- playhead_x = self.sec_to_x(self.playhead_pos_sec)
+ playhead_x = self.ms_to_x(self.playhead_pos_ms)
painter.setPen(QPen(QColor("red"), 2))
painter.drawLine(playhead_x, 0, playhead_x, self.height())
@@ -634,48 +636,48 @@ def y_to_track_info(self, y):
return None
- def _snap_time_if_needed(self, time_sec):
- frame_duration = 1.0 / self.project_fps
- if frame_duration > 0 and frame_duration * self.pixels_per_second > 4:
- frame_number = round(time_sec / frame_duration)
- return frame_number * frame_duration
- return time_sec
+ def _snap_time_if_needed(self, time_ms):
+ frame_duration_ms = 1000.0 / self.project_fps
+ if frame_duration_ms > 0 and frame_duration_ms * self.pixels_per_ms > 4:
+ frame_number = round(time_ms / frame_duration_ms)
+ return int(frame_number * frame_duration_ms)
+ return int(time_ms)
def get_region_at_pos(self, pos: QPoint):
if pos.y() <= self.TIMESCALE_HEIGHT or pos.x() <= self.HEADER_WIDTH:
return None
- clicked_sec = self.x_to_sec(pos.x())
+ clicked_ms = self.x_to_ms(pos.x())
for region in reversed(self.selection_regions):
- if region[0] <= clicked_sec <= region[1]:
+ if region[0] <= clicked_ms <= region[1]:
return region
return None
def wheelEvent(self, event: QMouseEvent):
delta = event.angleDelta().y()
zoom_factor = 1.15
- old_pps = self.pixels_per_second
+ old_pps = self.pixels_per_ms
if delta > 0:
new_pps = old_pps * zoom_factor
else:
new_pps = old_pps / zoom_factor
- min_pps = 1 / (3600 * 10)
- new_pps = max(min_pps, min(new_pps, self.max_pixels_per_second))
+ min_pps = 1 / (3600 * 10 * 1000)
+ new_pps = max(min_pps, min(new_pps, self.max_pixels_per_ms))
if abs(new_pps - old_pps) < 1e-9:
return
if event.position().x() < self.HEADER_WIDTH:
- new_view_start_sec = self.view_start_sec * (old_pps / new_pps)
+ new_view_start_ms = self.view_start_ms * (old_pps / new_pps)
else:
mouse_x = event.position().x()
- time_at_cursor = self.x_to_sec(mouse_x)
- new_view_start_sec = time_at_cursor - (mouse_x - self.HEADER_WIDTH) / new_pps
+ time_at_cursor = self.x_to_ms(mouse_x)
+ new_view_start_ms = time_at_cursor - (mouse_x - self.HEADER_WIDTH) / new_pps
- self.pixels_per_second = new_pps
- self.view_start_sec = max(0.0, new_view_start_sec)
+ self.pixels_per_ms = new_pps
+ self.view_start_ms = int(max(0, new_view_start_ms))
self.update()
event.accept()
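
Playhead snapping now operates on the millisecond grid: positions snap to the nearest frame boundary only when a single frame is wider than 4 px on screen. The same logic as the patched `_snap_time_if_needed`, extracted as a standalone sketch:

```python
def snap_time_if_needed(time_ms, project_fps, pixels_per_ms):
    """Snap to the nearest frame boundary, but only when frames are wide
    enough on screen (> 4 px) for frame-level snapping to be meaningful."""
    frame_duration_ms = 1000.0 / project_fps
    if frame_duration_ms > 0 and frame_duration_ms * pixels_per_ms > 4:
        frame_number = round(time_ms / frame_duration_ms)
        return int(frame_number * frame_duration_ms)
    return int(time_ms)

print(snap_time_if_needed(1013, project_fps=25, pixels_per_ms=0.2))   # 8 px frames -> snaps to 1000
print(snap_time_if_needed(1013, project_fps=25, pixels_per_ms=0.05))  # 2 px frames -> no snap, 1013
```
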
@@ -684,7 +686,7 @@ def mousePressEvent(self, event: QMouseEvent):
if event.button() == Qt.MouseButton.MiddleButton:
self.panning = True
self.pan_start_pos = event.pos()
- self.pan_start_view_sec = self.view_start_sec
+ self.pan_start_view_ms = self.view_start_ms
self.setCursor(Qt.CursorShape.ClosedHandCursor)
event.accept()
return
@@ -713,8 +715,8 @@ def mousePressEvent(self, event: QMouseEvent):
for region in self.selection_regions:
if not region: continue
- x_start = self.sec_to_x(region[0])
- x_end = self.sec_to_x(region[1])
+ x_start = self.ms_to_x(region[0])
+ x_end = self.ms_to_x(region[1])
if event.pos().y() > self.TIMESCALE_HEIGHT:
if abs(event.pos().x() - x_start) < self.RESIZE_HANDLE_WIDTH:
self.resizing_selection_region = region
@@ -768,12 +770,12 @@ def mousePressEvent(self, event: QMouseEvent):
if clicked_clip.id in self.selected_clips:
self.dragging_clip = clicked_clip
self.drag_start_state = self.window()._get_current_timeline_state()
- self.drag_original_clip_states[clicked_clip.id] = (clicked_clip.timeline_start_sec, clicked_clip.track_index)
+ self.drag_original_clip_states[clicked_clip.id] = (clicked_clip.timeline_start_ms, clicked_clip.track_index)
self.dragging_linked_clip = next((c for c in self.timeline.clips if c.group_id == clicked_clip.group_id and c.id != clicked_clip.id), None)
if self.dragging_linked_clip:
self.drag_original_clip_states[self.dragging_linked_clip.id] = \
- (self.dragging_linked_clip.timeline_start_sec, self.dragging_linked_clip.track_index)
+ (self.dragging_linked_clip.timeline_start_ms, self.dragging_linked_clip.track_index)
self.drag_start_pos = event.pos()
else:
@@ -789,16 +791,16 @@ def mousePressEvent(self, event: QMouseEvent):
if is_in_track_area:
self.creating_selection_region = True
- playhead_x = self.sec_to_x(self.playhead_pos_sec)
+ playhead_x = self.ms_to_x(self.playhead_pos_ms)
if abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS:
- self.selection_drag_start_sec = self.playhead_pos_sec
+ self.selection_drag_start_ms = self.playhead_pos_ms
else:
- self.selection_drag_start_sec = self.x_to_sec(event.pos().x())
- self.selection_regions.append([self.selection_drag_start_sec, self.selection_drag_start_sec])
+ self.selection_drag_start_ms = self.x_to_ms(event.pos().x())
+ self.selection_regions.append([self.selection_drag_start_ms, self.selection_drag_start_ms])
elif is_on_timescale:
- time_sec = max(0, self.x_to_sec(event.pos().x()))
- self.playhead_pos_sec = self._snap_time_if_needed(time_sec)
- self.playhead_moved.emit(self.playhead_pos_sec)
+ time_ms = max(0, self.x_to_ms(event.pos().x()))
+ self.playhead_pos_ms = self._snap_time_if_needed(time_ms)
+ self.playhead_moved.emit(self.playhead_pos_ms)
self.dragging_playhead = True
self.update()
@@ -806,22 +808,22 @@ def mousePressEvent(self, event: QMouseEvent):
def mouseMoveEvent(self, event: QMouseEvent):
if self.panning:
delta_x = event.pos().x() - self.pan_start_pos.x()
- time_delta = delta_x / self.pixels_per_second
- new_view_start = self.pan_start_view_sec - time_delta
- self.view_start_sec = max(0.0, new_view_start)
+ time_delta = delta_x / self.pixels_per_ms
+ new_view_start = self.pan_start_view_ms - time_delta
+ self.view_start_ms = int(max(0, new_view_start))
self.update()
return
if self.resizing_selection_region:
- current_sec = max(0, self.x_to_sec(event.pos().x()))
+ current_ms = max(0, self.x_to_ms(event.pos().x()))
original_start, original_end = self.resize_selection_start_values
if self.resize_selection_edge == 'left':
- new_start = current_sec
+ new_start = current_ms
new_end = original_end
else: # right
new_start = original_start
- new_end = current_sec
+ new_end = current_ms
self.resizing_selection_region[0] = min(new_start, new_end)
self.resizing_selection_region[1] = max(new_start, new_end)
@@ -836,60 +838,60 @@ def mouseMoveEvent(self, event: QMouseEvent):
if self.resizing_clip:
linked_clip = next((c for c in self.timeline.clips if c.group_id == self.resizing_clip.group_id and c.id != self.resizing_clip.id), None)
delta_x = event.pos().x() - self.resize_start_pos.x()
- time_delta = delta_x / self.pixels_per_second
- min_duration = 1.0 / self.project_fps
- snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_second
+ time_delta = delta_x / self.pixels_per_ms
+ min_duration_ms = int(1000 / self.project_fps)
+ snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_ms
- snap_points = [self.playhead_pos_sec]
+ snap_points = [self.playhead_pos_ms]
for clip in self.timeline.clips:
if clip.id == self.resizing_clip.id: continue
if linked_clip and clip.id == linked_clip.id: continue
- snap_points.append(clip.timeline_start_sec)
- snap_points.append(clip.timeline_end_sec)
+ snap_points.append(clip.timeline_start_ms)
+ snap_points.append(clip.timeline_end_ms)
media_props = self.window().media_properties.get(self.resizing_clip.source_path)
- source_duration = media_props['duration'] if media_props else float('inf')
+ source_duration_ms = media_props['duration_ms'] if media_props else float('inf')
if self.resize_edge == 'left':
- original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_sec
- original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec
- original_clip_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].clip_start_sec
- true_new_start_sec = original_start + time_delta
+ original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_ms
+ original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_ms
+ original_clip_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].clip_start_ms
+ true_new_start_ms = original_start + time_delta
- new_start_sec = true_new_start_sec
+ new_start_ms = true_new_start_ms
for snap_point in snap_points:
- if abs(true_new_start_sec - snap_point) < snap_time_delta:
- new_start_sec = snap_point
+ if abs(true_new_start_ms - snap_point) < snap_time_delta:
+ new_start_ms = snap_point
break
- if new_start_sec > original_start + original_duration - min_duration:
- new_start_sec = original_start + original_duration - min_duration
+ if new_start_ms > original_start + original_duration - min_duration_ms:
+ new_start_ms = original_start + original_duration - min_duration_ms
- new_start_sec = max(0, new_start_sec)
+ new_start_ms = max(0, new_start_ms)
if self.resizing_clip.media_type != 'image':
- if new_start_sec < original_start - original_clip_start:
- new_start_sec = original_start - original_clip_start
+ if new_start_ms < original_start - original_clip_start:
+ new_start_ms = original_start - original_clip_start
- new_duration = (original_start + original_duration) - new_start_sec
- new_clip_start = original_clip_start + (new_start_sec - original_start)
+ new_duration = (original_start + original_duration) - new_start_ms
+ new_clip_start = original_clip_start + (new_start_ms - original_start)
- if new_duration < min_duration:
- new_duration = min_duration
- new_start_sec = (original_start + original_duration) - new_duration
- new_clip_start = original_clip_start + (new_start_sec - original_start)
-
- self.resizing_clip.timeline_start_sec = new_start_sec
- self.resizing_clip.duration_sec = new_duration
- self.resizing_clip.clip_start_sec = new_clip_start
+ if new_duration < min_duration_ms:
+ new_duration = min_duration_ms
+ new_start_ms = (original_start + original_duration) - new_duration
+ new_clip_start = original_clip_start + (new_start_ms - original_start)
+
+ self.resizing_clip.timeline_start_ms = int(new_start_ms)
+ self.resizing_clip.duration_ms = int(new_duration)
+ self.resizing_clip.clip_start_ms = int(new_clip_start)
if linked_clip:
- linked_clip.timeline_start_sec = new_start_sec
- linked_clip.duration_sec = new_duration
- linked_clip.clip_start_sec = new_clip_start
+ linked_clip.timeline_start_ms = int(new_start_ms)
+ linked_clip.duration_ms = int(new_duration)
+ linked_clip.clip_start_ms = int(new_clip_start)
elif self.resize_edge == 'right':
- original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_sec
- original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec
+ original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_ms
+ original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_ms
true_new_duration = original_duration + time_delta
true_new_end_time = original_start + true_new_duration
@@ -902,23 +904,23 @@ def mouseMoveEvent(self, event: QMouseEvent):
new_duration = new_end_time - original_start
- if new_duration < min_duration:
- new_duration = min_duration
+ if new_duration < min_duration_ms:
+ new_duration = min_duration_ms
if self.resizing_clip.media_type != 'image':
- if self.resizing_clip.clip_start_sec + new_duration > source_duration:
- new_duration = source_duration - self.resizing_clip.clip_start_sec
+ if self.resizing_clip.clip_start_ms + new_duration > source_duration_ms:
+ new_duration = source_duration_ms - self.resizing_clip.clip_start_ms
- self.resizing_clip.duration_sec = new_duration
+ self.resizing_clip.duration_ms = int(new_duration)
if linked_clip:
- linked_clip.duration_sec = new_duration
+ linked_clip.duration_ms = int(new_duration)
self.update()
return
if not self.dragging_clip and not self.dragging_playhead and not self.creating_selection_region:
cursor_set = False
- playhead_x = self.sec_to_x(self.playhead_pos_sec)
+ playhead_x = self.ms_to_x(self.playhead_pos_ms)
is_in_track_area = event.pos().y() > self.TIMESCALE_HEIGHT and event.pos().x() > self.HEADER_WIDTH
if is_in_track_area and abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS:
self.setCursor(Qt.CursorShape.SizeHorCursor)
@@ -926,8 +928,8 @@ def mouseMoveEvent(self, event: QMouseEvent):
if not cursor_set and is_in_track_area:
for region in self.selection_regions:
- x_start = self.sec_to_x(region[0])
- x_end = self.sec_to_x(region[1])
+ x_start = self.ms_to_x(region[0])
+ x_end = self.ms_to_x(region[1])
if abs(event.pos().x() - x_start) < self.RESIZE_HANDLE_WIDTH or \
abs(event.pos().x() - x_end) < self.RESIZE_HANDLE_WIDTH:
self.setCursor(Qt.CursorShape.SizeHorCursor)
@@ -945,16 +947,16 @@ def mouseMoveEvent(self, event: QMouseEvent):
self.unsetCursor()
if self.creating_selection_region:
- current_sec = self.x_to_sec(event.pos().x())
- start = min(self.selection_drag_start_sec, current_sec)
- end = max(self.selection_drag_start_sec, current_sec)
+ current_ms = self.x_to_ms(event.pos().x())
+ start = min(self.selection_drag_start_ms, current_ms)
+ end = max(self.selection_drag_start_ms, current_ms)
self.selection_regions[-1] = [start, end]
self.update()
return
if self.dragging_selection_region:
delta_x = event.pos().x() - self.drag_start_pos.x()
- time_delta = delta_x / self.pixels_per_second
+ time_delta = int(delta_x / self.pixels_per_ms)
original_start, original_end = self.drag_selection_start_values
duration = original_end - original_start
@@ -967,16 +969,16 @@ def mouseMoveEvent(self, event: QMouseEvent):
return
if self.dragging_playhead:
- time_sec = max(0, self.x_to_sec(event.pos().x()))
- self.playhead_pos_sec = self._snap_time_if_needed(time_sec)
- self.playhead_moved.emit(self.playhead_pos_sec)
+ time_ms = max(0, self.x_to_ms(event.pos().x()))
+ self.playhead_pos_ms = self._snap_time_if_needed(time_ms)
+ self.playhead_moved.emit(self.playhead_pos_ms)
self.update()
elif self.dragging_clip:
self.highlighted_track_info = None
self.highlighted_ghost_track_info = None
new_track_info = self.y_to_track_info(event.pos().y())
- original_start_sec, _ = self.drag_original_clip_states[self.dragging_clip.id]
+ original_start_ms, _ = self.drag_original_clip_states[self.dragging_clip.id]
if new_track_info:
new_track_type, new_track_index = new_track_info
@@ -993,19 +995,19 @@ def mouseMoveEvent(self, event: QMouseEvent):
self.dragging_clip.track_index = new_track_index
delta_x = event.pos().x() - self.drag_start_pos.x()
- time_delta = delta_x / self.pixels_per_second
- true_new_start_time = original_start_sec + time_delta
+ time_delta = delta_x / self.pixels_per_ms
+ true_new_start_time = original_start_ms + time_delta
- playhead_time = self.playhead_pos_sec
- snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_second
+ playhead_time = self.playhead_pos_ms
+ snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_ms
new_start_time = true_new_start_time
- true_new_end_time = true_new_start_time + self.dragging_clip.duration_sec
+ true_new_end_time = true_new_start_time + self.dragging_clip.duration_ms
if abs(true_new_start_time - playhead_time) < snap_time_delta:
new_start_time = playhead_time
elif abs(true_new_end_time - playhead_time) < snap_time_delta:
- new_start_time = playhead_time - self.dragging_clip.duration_sec
+ new_start_time = playhead_time - self.dragging_clip.duration_ms
for other_clip in self.timeline.clips:
if other_clip.id == self.dragging_clip.id: continue
@@ -1014,21 +1016,21 @@ def mouseMoveEvent(self, event: QMouseEvent):
other_clip.track_index != self.dragging_clip.track_index):
continue
- is_overlapping = (new_start_time < other_clip.timeline_end_sec and
- new_start_time + self.dragging_clip.duration_sec > other_clip.timeline_start_sec)
+ is_overlapping = (new_start_time < other_clip.timeline_end_ms and
+ new_start_time + self.dragging_clip.duration_ms > other_clip.timeline_start_ms)
if is_overlapping:
- movement_direction = true_new_start_time - original_start_sec
+ movement_direction = true_new_start_time - original_start_ms
if movement_direction > 0:
- new_start_time = other_clip.timeline_start_sec - self.dragging_clip.duration_sec
+ new_start_time = other_clip.timeline_start_ms - self.dragging_clip.duration_ms
else:
- new_start_time = other_clip.timeline_end_sec
+ new_start_time = other_clip.timeline_end_ms
break
final_start_time = max(0, new_start_time)
- self.dragging_clip.timeline_start_sec = final_start_time
+ self.dragging_clip.timeline_start_ms = int(final_start_time)
if self.dragging_linked_clip:
- self.dragging_linked_clip.timeline_start_sec = final_start_time
+ self.dragging_linked_clip.timeline_start_ms = int(final_start_time)
self.update()
@@ -1061,7 +1063,7 @@ def mouseReleaseEvent(self, event: QMouseEvent):
self.creating_selection_region = False
if self.selection_regions:
start, end = self.selection_regions[-1]
- if (end - start) * self.pixels_per_second < 2:
+ if (end - start) * self.pixels_per_ms < 2:
self.clear_all_regions()
if self.dragging_selection_region:
@@ -1071,18 +1073,18 @@ def mouseReleaseEvent(self, event: QMouseEvent):
self.dragging_playhead = False
if self.dragging_clip:
orig_start, orig_track = self.drag_original_clip_states[self.dragging_clip.id]
- moved = (orig_start != self.dragging_clip.timeline_start_sec or
+ moved = (orig_start != self.dragging_clip.timeline_start_ms or
orig_track != self.dragging_clip.track_index)
if self.dragging_linked_clip:
orig_start_link, orig_track_link = self.drag_original_clip_states[self.dragging_linked_clip.id]
- moved = moved or (orig_start_link != self.dragging_linked_clip.timeline_start_sec or
+ moved = moved or (orig_start_link != self.dragging_linked_clip.timeline_start_ms or
orig_track_link != self.dragging_linked_clip.track_index)
if moved:
self.window().finalize_clip_drag(self.drag_start_state)
- self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
+ self.timeline.clips.sort(key=lambda c: c.timeline_start_ms)
self.highlighted_track_info = None
self.highlighted_ghost_track_info = None
self.operation_finished.emit()
@@ -1162,12 +1164,12 @@ def dragMoveEvent(self, event):
event.acceptProposedAction()
- duration = media_props['duration']
+ duration_ms = media_props['duration_ms']
media_type = media_props['media_type']
has_audio = media_props['has_audio']
pos = event.position()
- start_sec = self.x_to_sec(pos.x())
+ start_ms = self.x_to_ms(pos.x())
track_info = self.y_to_track_info(pos.y())
self.drag_over_rect = QRectF()
@@ -1187,8 +1189,8 @@ def dragMoveEvent(self, event):
else:
self.highlighted_track_info = track_info
- width = int(duration * self.pixels_per_second)
- x = self.sec_to_x(start_sec)
+ width = int(duration_ms * self.pixels_per_ms)
+ x = self.ms_to_x(start_ms)
video_y, audio_y = -1, -1
@@ -1232,19 +1234,19 @@ def dropEvent(self, event):
return
pos = event.position()
- start_sec = self.x_to_sec(pos.x())
+ start_ms = self.x_to_ms(pos.x())
track_info = self.y_to_track_info(pos.y())
if not track_info:
event.ignore()
return
- current_timeline_pos = start_sec
+ current_timeline_pos = start_ms
for file_path in added_files:
media_info = main_window.media_properties.get(file_path)
if not media_info: continue
- duration = media_info['duration']
+ duration_ms = media_info['duration_ms']
has_audio = media_info['has_audio']
media_type = media_info['media_type']
@@ -1269,14 +1271,14 @@ def dropEvent(self, event):
main_window._add_clip_to_timeline(
source_path=file_path,
- timeline_start_sec=current_timeline_pos,
- duration_sec=duration,
+ timeline_start_ms=current_timeline_pos,
+ duration_ms=duration_ms,
media_type=media_type,
- clip_start_sec=0,
+ clip_start_ms=0,
video_track_index=video_track_idx,
audio_track_index=audio_track_idx
)
- current_timeline_pos += duration
+ current_timeline_pos += duration_ms
event.acceptProposedAction()
return
@@ -1286,12 +1288,12 @@ def dropEvent(self, event):
json_data = json.loads(mime_data.data('application/x-vnd.video.filepath').data().decode('utf-8'))
file_path = json_data['path']
- duration = json_data['duration']
+ duration_ms = json_data['duration_ms']
has_audio = json_data['has_audio']
media_type = json_data['media_type']
pos = event.position()
- start_sec = self.x_to_sec(pos.x())
+ start_ms = self.x_to_ms(pos.x())
track_info = self.y_to_track_info(pos.y())
if not track_info:
@@ -1321,10 +1323,10 @@ def dropEvent(self, event):
main_window = self.window()
main_window._add_clip_to_timeline(
source_path=file_path,
- timeline_start_sec=start_sec,
- duration_sec=duration,
+ timeline_start_ms=start_ms,
+ duration_ms=duration_ms,
media_type=media_type,
- clip_start_sec=0,
+ clip_start_ms=0,
video_track_index=video_track_idx,
audio_track_index=audio_track_idx
)
@@ -1382,8 +1384,8 @@ def contextMenuEvent(self, event: 'QContextMenuEvent'):
split_action = menu.addAction("Split Clip")
delete_action = menu.addAction("Delete Clip")
- playhead_time = self.playhead_pos_sec
- is_playhead_over_clip = (clip_at_pos.timeline_start_sec < playhead_time < clip_at_pos.timeline_end_sec)
+ playhead_time = self.playhead_pos_ms
+ is_playhead_over_clip = (clip_at_pos.timeline_start_ms < playhead_time < clip_at_pos.timeline_end_ms)
split_action.setEnabled(is_playhead_over_clip)
split_action.triggered.connect(lambda: self.split_requested.emit(clip_at_pos))
delete_action.triggered.connect(lambda: self.delete_clip_requested.emit(clip_at_pos))
@@ -1672,7 +1674,7 @@ def startDrag(self, supportedActions):
payload = {
"path": path,
- "duration": media_info['duration'],
+ "duration_ms": media_info['duration_ms'],
"has_audio": media_info['has_audio'],
"media_type": media_info['media_type']
}
@@ -1930,7 +1932,7 @@ def _toggle_scale_to_fit(self, checked):
def on_timeline_changed_by_undo(self):
self.prune_empty_tracks()
self.timeline_widget.update()
- self.seek_preview(self.timeline_widget.playhead_pos_sec)
+ self.seek_preview(self.timeline_widget.playhead_pos_ms)
self.status_label.setText("Operation undone/redone.")
def update_undo_redo_actions(self):
@@ -1964,8 +1966,8 @@ def on_add_to_timeline_at_playhead(self, file_path):
self.status_label.setText(f"Error: Could not find properties for {os.path.basename(file_path)}")
return
- playhead_pos = self.timeline_widget.playhead_pos_sec
- duration = media_info['duration']
+ playhead_pos = self.timeline_widget.playhead_pos_ms
+ duration_ms = media_info['duration_ms']
has_audio = media_info['has_audio']
media_type = media_info['media_type']
@@ -1974,10 +1976,10 @@ def on_add_to_timeline_at_playhead(self, file_path):
self._add_clip_to_timeline(
source_path=file_path,
- timeline_start_sec=playhead_pos,
- duration_sec=duration,
+ timeline_start_ms=playhead_pos,
+ duration_ms=duration_ms,
media_type=media_type,
- clip_start_sec=0.0,
+ clip_start_ms=0,
video_track_index=video_track,
audio_track_index=audio_track
)
@@ -2102,25 +2104,25 @@ def _create_menu_bar(self):
data['action'] = action
self.windows_menu.addAction(action)
- def _get_topmost_video_clip_at(self, time_sec):
+ def _get_topmost_video_clip_at(self, time_ms):
"""Finds the video clip on the highest track at a specific time."""
top_clip = None
for c in self.timeline.clips:
- if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec:
+ if c.track_type == 'video' and c.timeline_start_ms <= time_ms < c.timeline_end_ms:
if top_clip is None or c.track_index > top_clip.track_index:
top_clip = c
return top_clip
- def _start_playback_stream_at(self, time_sec):
+ def _start_playback_stream_at(self, time_ms):
self._stop_playback_stream()
- clip = self._get_topmost_video_clip_at(time_sec)
+ clip = self._get_topmost_video_clip_at(time_ms)
if not clip: return
self.playback_clip = clip
- clip_time = time_sec - clip.timeline_start_sec + clip.clip_start_sec
+ clip_time_sec = (time_ms - clip.timeline_start_ms + clip.clip_start_ms) / 1000.0
w, h = self.project_width, self.project_height
try:
- args = (ffmpeg.input(self.playback_clip.source_path, ss=clip_time)
+ args = (ffmpeg.input(self.playback_clip.source_path, ss=f"{clip_time_sec:.6f}")
.filter('scale', w, h, force_original_aspect_ratio='decrease')
.filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
.output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile())
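
Playback streaming keeps timeline arithmetic in integer milliseconds and converts to seconds only when building the ffmpeg seek argument, formatted with microsecond precision. A condensed sketch of the argument construction used by the patched `_start_playback_stream_at` (function and parameter names here are simplified):

```python
import ffmpeg  # ffmpeg-python, as used elsewhere in videoeditor.py

def playback_args(source_path, timeline_ms, clip_timeline_start_ms, clip_start_ms,
                  w, h, fps):
    """Build the ffmpeg command for streaming frames from a timeline position.

    Positions stay as integer milliseconds internally; only the seek value
    handed to ffmpeg is converted to (fractional) seconds.
    """
    clip_time_sec = (timeline_ms - clip_timeline_start_ms + clip_start_ms) / 1000.0
    return (
        ffmpeg
        .input(source_path, ss=f"{clip_time_sec:.6f}")
        .filter('scale', w, h, force_original_aspect_ratio='decrease')
        .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
        .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=fps)
        .compile()
    )
```
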
@@ -2149,7 +2151,7 @@ def _get_media_properties(self, file_path):
if file_ext in ['.png', '.jpg', '.jpeg']:
img = QImage(file_path)
media_info['media_type'] = 'image'
- media_info['duration'] = 5.0
+ media_info['duration_ms'] = 5000
media_info['has_audio'] = False
media_info['width'] = img.width()
media_info['height'] = img.height()
@@ -2160,7 +2162,8 @@ def _get_media_properties(self, file_path):
if video_stream:
media_info['media_type'] = 'video'
- media_info['duration'] = float(video_stream.get('duration', probe['format'].get('duration', 0)))
+ duration_sec = float(video_stream.get('duration', probe['format'].get('duration', 0)))
+ media_info['duration_ms'] = int(duration_sec * 1000)
media_info['has_audio'] = audio_stream is not None
media_info['width'] = int(video_stream['width'])
media_info['height'] = int(video_stream['height'])
@@ -2169,7 +2172,8 @@ def _get_media_properties(self, file_path):
if den > 0: media_info['fps'] = num / den
elif audio_stream:
media_info['media_type'] = 'audio'
- media_info['duration'] = float(audio_stream.get('duration', probe['format'].get('duration', 0)))
+ duration_sec = float(audio_stream.get('duration', probe['format'].get('duration', 0)))
+ media_info['duration_ms'] = int(duration_sec * 1000)
media_info['has_audio'] = True
else:
return None
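
Media durations are converted to integer milliseconds once, at probe time, and every consumer reads `duration_ms` afterwards. A condensed sketch of that conversion, following the same `ffmpeg.probe` fields the patched `_get_media_properties` reads (the stream selection here is simplified):

```python
import ffmpeg  # ffmpeg-python

def probe_duration_ms(file_path):
    """Probe a media file and return its duration as integer milliseconds.

    ffprobe reports seconds as a string; convert once at the boundary and
    store int ms, as the patched _get_media_properties does.
    """
    probe = ffmpeg.probe(file_path)
    stream = next((s for s in probe['streams']
                   if s.get('codec_type') in ('video', 'audio')), None)
    if stream is None:
        return None
    duration_sec = float(stream.get('duration', probe['format'].get('duration', 0)))
    return int(duration_sec * 1000)
```
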
@@ -2215,8 +2219,8 @@ def _update_project_properties_from_clip(self, source_path):
def _probe_for_drag(self, file_path):
return self._get_media_properties(file_path)
- def get_frame_data_at_time(self, time_sec):
- clip_at_time = self._get_topmost_video_clip_at(time_sec)
+ def get_frame_data_at_time(self, time_ms):
+ clip_at_time = self._get_topmost_video_clip_at(time_ms)
if not clip_at_time:
return (None, 0, 0)
try:
@@ -2231,10 +2235,10 @@ def get_frame_data_at_time(self, time_sec):
.run(capture_stdout=True, quiet=True)
)
else:
- clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec
+ clip_time_sec = (time_ms - clip_at_time.timeline_start_ms + clip_at_time.clip_start_ms) / 1000.0
out, _ = (
ffmpeg
- .input(clip_at_time.source_path, ss=clip_time)
+ .input(clip_at_time.source_path, ss=f"{clip_time_sec:.6f}")
.filter('scale', w, h, force_original_aspect_ratio='decrease')
.filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
.output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
@@ -2245,8 +2249,8 @@ def get_frame_data_at_time(self, time_sec):
print(f"Error extracting frame data: {e.stderr}")
return (None, 0, 0)
- def get_frame_at_time(self, time_sec):
- clip_at_time = self._get_topmost_video_clip_at(time_sec)
+ def get_frame_at_time(self, time_ms):
+ clip_at_time = self._get_topmost_video_clip_at(time_ms)
if not clip_at_time: return None
try:
w, h = self.project_width, self.project_height
@@ -2259,8 +2263,8 @@ def get_frame_at_time(self, time_sec):
.run(capture_stdout=True, quiet=True)
)
else:
- clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec
- out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time)
+ clip_time_sec = (time_ms - clip_at_time.timeline_start_ms + clip_at_time.clip_start_ms) / 1000.0
+ out, _ = (ffmpeg.input(clip_at_time.source_path, ss=f"{clip_time_sec:.6f}")
.filter('scale', w, h, force_original_aspect_ratio='decrease')
.filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
.output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
@@ -2270,11 +2274,11 @@ def get_frame_at_time(self, time_sec):
return QPixmap.fromImage(image)
except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return None
- def seek_preview(self, time_sec):
+ def seek_preview(self, time_ms):
self._stop_playback_stream()
- self.timeline_widget.playhead_pos_sec = time_sec
+ self.timeline_widget.playhead_pos_ms = int(time_ms)
self.timeline_widget.update()
- self.current_preview_pixmap = self.get_frame_at_time(time_sec)
+ self.current_preview_pixmap = self.get_frame_at_time(time_ms)
self._update_preview_display()
def _update_preview_display(self):
@@ -2303,26 +2307,26 @@ def toggle_playback(self):
if self.playback_timer.isActive(): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play")
else:
if not self.timeline.clips: return
- if self.timeline_widget.playhead_pos_sec >= self.timeline.get_total_duration(): self.timeline_widget.playhead_pos_sec = 0.0
+ if self.timeline_widget.playhead_pos_ms >= self.timeline.get_total_duration(): self.timeline_widget.playhead_pos_ms = 0
self.playback_timer.start(int(1000 / self.project_fps)); self.play_pause_button.setText("Pause")
- def stop_playback(self): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play"); self.seek_preview(0.0)
+ def stop_playback(self): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play"); self.seek_preview(0)
def step_frame(self, direction):
if not self.timeline.clips: return
self.playback_timer.stop(); self.play_pause_button.setText("Play"); self._stop_playback_stream()
- frame_duration = 1.0 / self.project_fps
- new_time = self.timeline_widget.playhead_pos_sec + (direction * frame_duration)
- self.seek_preview(max(0, min(new_time, self.timeline.get_total_duration())))
+ frame_duration_ms = 1000.0 / self.project_fps
+ new_time = self.timeline_widget.playhead_pos_ms + (direction * frame_duration_ms)
+ self.seek_preview(int(max(0, min(new_time, self.timeline.get_total_duration()))))
def advance_playback_frame(self):
- frame_duration = 1.0 / self.project_fps
- new_time = self.timeline_widget.playhead_pos_sec + frame_duration
- if new_time > self.timeline.get_total_duration(): self.stop_playback(); return
+ frame_duration_ms = 1000.0 / self.project_fps
+ new_time_ms = self.timeline_widget.playhead_pos_ms + frame_duration_ms
+ if new_time_ms > self.timeline.get_total_duration(): self.stop_playback(); return
- self.timeline_widget.playhead_pos_sec = new_time
+ self.timeline_widget.playhead_pos_ms = round(new_time_ms)
self.timeline_widget.update()
- clip_at_new_time = self._get_topmost_video_clip_at(new_time)
+ clip_at_new_time = self._get_topmost_video_clip_at(new_time_ms)
if not clip_at_new_time:
self._stop_playback_stream()
@@ -2334,12 +2338,12 @@ def advance_playback_frame(self):
if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id:
self._stop_playback_stream()
self.playback_clip = clip_at_new_time
- self.current_preview_pixmap = self.get_frame_at_time(new_time)
+ self.current_preview_pixmap = self.get_frame_at_time(new_time_ms)
self._update_preview_display()
return
if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id:
- self._start_playback_stream_at(new_time)
+ self._start_playback_stream_at(new_time_ms)
if self.playback_process:
frame_size = self.project_width * self.project_height * 3
@@ -2421,7 +2425,7 @@ def save_project_as(self):
if not path: return
project_data = {
"media_pool": self.media_pool,
- "clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "media_type": c.media_type, "group_id": c.group_id} for c in self.timeline.clips],
+ "clips": [{"source_path": c.source_path, "timeline_start_ms": c.timeline_start_ms, "clip_start_ms": c.clip_start_ms, "duration_ms": c.duration_ms, "track_index": c.track_index, "track_type": c.track_type, "media_type": c.media_type, "group_id": c.group_id} for c in self.timeline.clips],
"selection_regions": self.timeline_widget.selection_regions,
"last_export_path": self.last_export_path,
"settings": {
@@ -2463,7 +2467,7 @@ def _load_project_from_path(self, path):
for clip_data in project_data["clips"]:
if not os.path.exists(clip_data["source_path"]):
self.status_label.setText(f"Error: Missing media file {clip_data['source_path']}"); self.new_project(); return
-
+
if 'media_type' not in clip_data:
ext = os.path.splitext(clip_data['source_path'])[1].lower()
if ext in ['.mp3', '.wav', '.m4a', '.aac']:
@@ -2552,19 +2556,19 @@ def add_media_to_timeline(self):
if not added_files:
return
- playhead_pos = self.timeline_widget.playhead_pos_sec
+ playhead_pos = self.timeline_widget.playhead_pos_ms
def add_clips_action():
for file_path in added_files:
media_info = self.media_properties.get(file_path)
if not media_info: continue
- duration = media_info['duration']
+ duration_ms = media_info['duration_ms']
media_type = media_info['media_type']
has_audio = media_info['has_audio']
clip_start_time = playhead_pos
- clip_end_time = playhead_pos + duration
+ clip_end_time = playhead_pos + duration_ms
video_track_index = None
audio_track_index = None
@@ -2572,7 +2576,7 @@ def add_clips_action():
if media_type in ['video', 'image']:
for i in range(1, self.timeline.num_video_tracks + 2):
is_occupied = any(
- c.timeline_start_sec < clip_end_time and c.timeline_end_sec > clip_start_time
+ c.timeline_start_ms < clip_end_time and c.timeline_end_ms > clip_start_time
for c in self.timeline.clips if c.track_type == 'video' and c.track_index == i
)
if not is_occupied:
@@ -2582,7 +2586,7 @@ def add_clips_action():
if has_audio:
for i in range(1, self.timeline.num_audio_tracks + 2):
is_occupied = any(
- c.timeline_start_sec < clip_end_time and c.timeline_end_sec > clip_start_time
+ c.timeline_start_ms < clip_end_time and c.timeline_end_ms > clip_start_time
for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i
)
if not is_occupied:
@@ -2593,13 +2597,13 @@ def add_clips_action():
if video_track_index is not None:
if video_track_index > self.timeline.num_video_tracks:
self.timeline.num_video_tracks = video_track_index
- video_clip = TimelineClip(file_path, clip_start_time, 0.0, duration, video_track_index, 'video', media_type, group_id)
+ video_clip = TimelineClip(file_path, clip_start_time, 0, duration_ms, video_track_index, 'video', media_type, group_id)
self.timeline.add_clip(video_clip)
if audio_track_index is not None:
if audio_track_index > self.timeline.num_audio_tracks:
self.timeline.num_audio_tracks = audio_track_index
- audio_clip = TimelineClip(file_path, clip_start_time, 0.0, duration, audio_track_index, 'audio', media_type, group_id)
+ audio_clip = TimelineClip(file_path, clip_start_time, 0, duration_ms, audio_track_index, 'audio', media_type, group_id)
self.timeline.add_clip(audio_clip)
self.status_label.setText(f"Added {len(added_files)} file(s) to timeline.")
@@ -2611,7 +2615,7 @@ def add_media_files(self):
if file_paths:
self._add_media_files_to_project(file_paths)
- def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, media_type, clip_start_sec=0.0, video_track_index=None, audio_track_index=None):
+ def _add_clip_to_timeline(self, source_path, timeline_start_ms, duration_ms, media_type, clip_start_ms=0, video_track_index=None, audio_track_index=None):
if media_type in ['video', 'image']:
self._update_project_properties_from_clip(source_path)
@@ -2621,13 +2625,13 @@ def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, m
if video_track_index is not None:
if video_track_index > self.timeline.num_video_tracks:
self.timeline.num_video_tracks = video_track_index
- video_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, video_track_index, 'video', media_type, group_id)
+ video_clip = TimelineClip(source_path, timeline_start_ms, clip_start_ms, duration_ms, video_track_index, 'video', media_type, group_id)
self.timeline.add_clip(video_clip)
if audio_track_index is not None:
if audio_track_index > self.timeline.num_audio_tracks:
self.timeline.num_audio_tracks = audio_track_index
- audio_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, audio_track_index, 'audio', media_type, group_id)
+ audio_clip = TimelineClip(source_path, timeline_start_ms, clip_start_ms, duration_ms, audio_track_index, 'audio', media_type, group_id)
self.timeline.add_clip(audio_clip)
new_state = self._get_current_timeline_state()
@@ -2635,21 +2639,21 @@ def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, m
command.undo()
self.undo_stack.push(command)
- def _split_at_time(self, clip_to_split, time_sec, new_group_id=None):
- if not (clip_to_split.timeline_start_sec < time_sec < clip_to_split.timeline_end_sec): return False
- split_point = time_sec - clip_to_split.timeline_start_sec
- orig_dur = clip_to_split.duration_sec
+ def _split_at_time(self, clip_to_split, time_ms, new_group_id=None):
+ if not (clip_to_split.timeline_start_ms < time_ms < clip_to_split.timeline_end_ms): return False
+ split_point = time_ms - clip_to_split.timeline_start_ms
+ orig_dur = clip_to_split.duration_ms
group_id_for_new_clip = new_group_id if new_group_id is not None else clip_to_split.group_id
- new_clip = TimelineClip(clip_to_split.source_path, time_sec, clip_to_split.clip_start_sec + split_point, orig_dur - split_point, clip_to_split.track_index, clip_to_split.track_type, clip_to_split.media_type, group_id_for_new_clip)
- clip_to_split.duration_sec = split_point
+ new_clip = TimelineClip(clip_to_split.source_path, time_ms, clip_to_split.clip_start_ms + split_point, orig_dur - split_point, clip_to_split.track_index, clip_to_split.track_type, clip_to_split.media_type, group_id_for_new_clip)
+ clip_to_split.duration_ms = split_point
self.timeline.add_clip(new_clip)
return True
def split_clip_at_playhead(self, clip_to_split=None):
- playhead_time = self.timeline_widget.playhead_pos_sec
+ playhead_time = self.timeline_widget.playhead_pos_ms
if not clip_to_split:
- clips_at_playhead = [c for c in self.timeline.clips if c.timeline_start_sec < playhead_time < c.timeline_end_sec]
+ clips_at_playhead = [c for c in self.timeline.clips if c.timeline_start_ms < playhead_time < c.timeline_end_ms]
if not clips_at_playhead:
self.status_label.setText("Playhead is not over a clip to split.")
return
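
Splitting now happens on integer-millisecond boundaries, so a split followed by a join reproduces the original durations exactly. A standalone sketch of the patched `_split_at_time` arithmetic, using a simplified stand-in for `TimelineClip`:

```python
from dataclasses import dataclass

@dataclass
class ClipSketch:                      # simplified stand-in for TimelineClip
    timeline_start_ms: int
    clip_start_ms: int
    duration_ms: int

    @property
    def timeline_end_ms(self):
        return self.timeline_start_ms + self.duration_ms

def split_at(clip, time_ms):
    """Split `clip` at an absolute timeline position (int ms).

    Returns the new right-hand clip, or None if time_ms is not strictly
    inside the clip -- the same guard the patched _split_at_time uses.
    """
    if not (clip.timeline_start_ms < time_ms < clip.timeline_end_ms):
        return None
    split_point = time_ms - clip.timeline_start_ms
    right = ClipSketch(time_ms,
                       clip.clip_start_ms + split_point,
                       clip.duration_ms - split_point)
    clip.duration_ms = split_point     # left-hand piece keeps its original start
    return right

c = ClipSketch(timeline_start_ms=1000, clip_start_ms=0, duration_ms=4000)
r = split_at(c, 2500)
print(c.duration_ms, r.timeline_start_ms, r.clip_start_ms, r.duration_ms)
# 1500 2500 1500 2500
```
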
@@ -2723,20 +2727,20 @@ def action():
return
target_audio_track = 1
- new_audio_start = video_clip.timeline_start_sec
- new_audio_end = video_clip.timeline_end_sec
+ new_audio_start = video_clip.timeline_start_ms
+ new_audio_end = video_clip.timeline_end_ms
conflicting_clips = [
c for c in self.timeline.clips
if c.track_type == 'audio' and c.track_index == target_audio_track and
- c.timeline_start_sec < new_audio_end and c.timeline_end_sec > new_audio_start
+ c.timeline_start_ms < new_audio_end and c.timeline_end_ms > new_audio_start
]
for conflict_clip in conflicting_clips:
found_spot = False
for check_track_idx in range(target_audio_track + 1, self.timeline.num_audio_tracks + 2):
is_occupied = any(
- other.timeline_start_sec < conflict_clip.timeline_end_sec and other.timeline_end_sec > conflict_clip.timeline_start_sec
+ other.timeline_start_ms < conflict_clip.timeline_end_ms and other.timeline_end_ms > conflict_clip.timeline_start_ms
for other in self.timeline.clips
if other.id != conflict_clip.id and other.track_type == 'audio' and other.track_index == check_track_idx
)
@@ -2751,9 +2755,9 @@ def action():
new_audio_clip = TimelineClip(
source_path=video_clip.source_path,
- timeline_start_sec=video_clip.timeline_start_sec,
- clip_start_sec=video_clip.clip_start_sec,
- duration_sec=video_clip.duration_sec,
+ timeline_start_ms=video_clip.timeline_start_ms,
+ clip_start_ms=video_clip.clip_start_ms,
+ duration_ms=video_clip.duration_ms,
track_index=target_audio_track,
track_type='audio',
media_type=video_clip.media_type,
@@ -2778,10 +2782,10 @@ def _perform_complex_timeline_change(self, description, change_function):
def on_split_region(self, region):
def action():
- start_sec, end_sec = region
+ start_ms, end_ms = region
clips = list(self.timeline.clips)
- for clip in clips: self._split_at_time(clip, end_sec)
- for clip in clips: self._split_at_time(clip, start_sec)
+ for clip in clips: self._split_at_time(clip, end_ms)
+ for clip in clips: self._split_at_time(clip, start_ms)
self.timeline_widget.clear_region(region)
self._perform_complex_timeline_change("Split Region", action)
@@ -2793,7 +2797,7 @@ def action():
split_points.add(end)
for point in sorted(list(split_points)):
- group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
+ group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_ms < point < c.timeline_end_ms}
new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
for clip in list(self.timeline.clips):
if clip.group_id in new_group_ids:
@@ -2803,98 +2807,98 @@ def action():
def on_join_region(self, region):
def action():
- start_sec, end_sec = region
- duration_to_remove = end_sec - start_sec
- if duration_to_remove <= 0.01: return
+ start_ms, end_ms = region
+ duration_to_remove = end_ms - start_ms
+ if duration_to_remove <= 10: return
- for point in [start_sec, end_sec]:
- group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
+ for point in [start_ms, end_ms]:
+ group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_ms < point < c.timeline_end_ms}
new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
for clip in list(self.timeline.clips):
if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id])
- clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
+ clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_ms >= start_ms and c.timeline_start_ms < end_ms]
for clip in clips_to_remove: self.timeline.clips.remove(clip)
for clip in self.timeline.clips:
- if clip.timeline_start_sec >= end_sec:
- clip.timeline_start_sec -= duration_to_remove
+ if clip.timeline_start_ms >= end_ms:
+ clip.timeline_start_ms -= duration_to_remove
- self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
+ self.timeline.clips.sort(key=lambda c: c.timeline_start_ms)
self.timeline_widget.clear_region(region)
self._perform_complex_timeline_change("Join Region", action)
def on_join_all_regions(self, regions):
def action():
for region in sorted(regions, key=lambda r: r[0], reverse=True):
- start_sec, end_sec = region
- duration_to_remove = end_sec - start_sec
- if duration_to_remove <= 0.01: continue
+ start_ms, end_ms = region
+ duration_to_remove = end_ms - start_ms
+ if duration_to_remove <= 10: continue
- for point in [start_sec, end_sec]:
- group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
+ for point in [start_ms, end_ms]:
+ group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_ms < point < c.timeline_end_ms}
new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
for clip in list(self.timeline.clips):
if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id])
- clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
+ clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_ms >= start_ms and c.timeline_start_ms < end_ms]
for clip in clips_to_remove:
try: self.timeline.clips.remove(clip)
except ValueError: pass
for clip in self.timeline.clips:
- if clip.timeline_start_sec >= end_sec:
- clip.timeline_start_sec -= duration_to_remove
+ if clip.timeline_start_ms >= end_ms:
+ clip.timeline_start_ms -= duration_to_remove
- self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
+ self.timeline.clips.sort(key=lambda c: c.timeline_start_ms)
self.timeline_widget.clear_all_regions()
self._perform_complex_timeline_change("Join All Regions", action)
def on_delete_region(self, region):
def action():
- start_sec, end_sec = region
- duration_to_remove = end_sec - start_sec
- if duration_to_remove <= 0.01: return
+ start_ms, end_ms = region
+ duration_to_remove = end_ms - start_ms
+ if duration_to_remove <= 10: return
- for point in [start_sec, end_sec]:
- group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
+ for point in [start_ms, end_ms]:
+ group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_ms < point < c.timeline_end_ms}
new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
for clip in list(self.timeline.clips):
if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id])
- clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
+ clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_ms >= start_ms and c.timeline_start_ms < end_ms]
for clip in clips_to_remove: self.timeline.clips.remove(clip)
for clip in self.timeline.clips:
- if clip.timeline_start_sec >= end_sec:
- clip.timeline_start_sec -= duration_to_remove
+ if clip.timeline_start_ms >= end_ms:
+ clip.timeline_start_ms -= duration_to_remove
- self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
+ self.timeline.clips.sort(key=lambda c: c.timeline_start_ms)
self.timeline_widget.clear_region(region)
self._perform_complex_timeline_change("Delete Region", action)
def on_delete_all_regions(self, regions):
def action():
for region in sorted(regions, key=lambda r: r[0], reverse=True):
- start_sec, end_sec = region
- duration_to_remove = end_sec - start_sec
- if duration_to_remove <= 0.01: continue
+ start_ms, end_ms = region
+ duration_to_remove = end_ms - start_ms
+ if duration_to_remove <= 10: continue
- for point in [start_sec, end_sec]:
- group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec}
+ for point in [start_ms, end_ms]:
+ group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_ms < point < c.timeline_end_ms}
new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point}
for clip in list(self.timeline.clips):
if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id])
- clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec]
+ clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_ms >= start_ms and c.timeline_start_ms < end_ms]
for clip in clips_to_remove:
try: self.timeline.clips.remove(clip)
except ValueError: pass
for clip in self.timeline.clips:
- if clip.timeline_start_sec >= end_sec:
- clip.timeline_start_sec -= duration_to_remove
+ if clip.timeline_start_ms >= end_ms:
+ clip.timeline_start_ms -= duration_to_remove
- self.timeline.clips.sort(key=lambda c: c.timeline_start_sec)
+ self.timeline.clips.sort(key=lambda c: c.timeline_start_ms)
self.timeline_widget.clear_all_regions()
self._perform_complex_timeline_change("Delete All Regions", action)
@@ -2936,10 +2940,12 @@ def export_video(self):
self.last_export_path = output_path
- w, h, fr_str, total_dur = self.project_width, self.project_height, str(self.project_fps), self.timeline.get_total_duration()
+ total_dur_ms = self.timeline.get_total_duration()
+ total_dur_sec = total_dur_ms / 1000.0
+ w, h, fr_str = self.project_width, self.project_height, str(self.project_fps)
sample_rate, channel_layout = '44100', 'stereo'
- video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur}', f='lavfi')
+ video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur_sec}', f='lavfi')
all_video_clips = sorted(
[c for c in self.timeline.clips if c.track_type == 'video'],
@@ -2954,37 +2960,45 @@ def export_video(self):
input_nodes[clip.source_path] = ffmpeg.input(clip.source_path)
clip_source_node = input_nodes[clip.source_path]
+ clip_duration_sec = clip.duration_ms / 1000.0
+ clip_start_sec = clip.clip_start_ms / 1000.0
+ timeline_start_sec = clip.timeline_start_ms / 1000.0
+ timeline_end_sec = clip.timeline_end_ms / 1000.0
+
if clip.media_type == 'image':
- segment_stream = (clip_source_node.video.trim(duration=clip.duration_sec).setpts('PTS-STARTPTS'))
+ segment_stream = (clip_source_node.video.trim(duration=clip_duration_sec).setpts('PTS-STARTPTS'))
else:
- segment_stream = (clip_source_node.video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS'))
+ segment_stream = (clip_source_node.video.trim(start=clip_start_sec, duration=clip_duration_sec).setpts('PTS-STARTPTS'))
processed_segment = (segment_stream.filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black'))
- video_stream = ffmpeg.overlay(video_stream, processed_segment, enable=f'between(t,{clip.timeline_start_sec},{clip.timeline_end_sec})')
+ video_stream = ffmpeg.overlay(video_stream, processed_segment, enable=f'between(t,{timeline_start_sec},{timeline_end_sec})')
final_video = video_stream.filter('format', pix_fmts='yuv420p').filter('fps', fps=self.project_fps)
track_audio_streams = []
for i in range(self.timeline.num_audio_tracks):
- track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i + 1], key=lambda c: c.timeline_start_sec)
+ track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i + 1], key=lambda c: c.timeline_start_ms)
if not track_clips: continue
track_segments = []
- last_end = track_clips[0].timeline_start_sec
+ last_end_ms = track_clips[0].timeline_start_ms
for clip in track_clips:
if clip.source_path not in input_nodes:
input_nodes[clip.source_path] = ffmpeg.input(clip.source_path)
- gap = clip.timeline_start_sec - last_end
- if gap > 0.01:
- track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap}', f='lavfi'))
- a_seg = input_nodes[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS')
+ gap_ms = clip.timeline_start_ms - last_end_ms
+ if gap_ms > 10:
+ track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap_ms/1000.0}', f='lavfi'))
+
+ clip_start_sec = clip.clip_start_ms / 1000.0
+ clip_duration_sec = clip.duration_ms / 1000.0
+ a_seg = input_nodes[clip.source_path].audio.filter('atrim', start=clip_start_sec, duration=clip_duration_sec).filter('asetpts', 'PTS-STARTPTS')
track_segments.append(a_seg)
- last_end = clip.timeline_end_sec
- track_audio_streams.append(ffmpeg.concat(*track_segments, v=0, a=1).filter('adelay', f'{int(track_clips[0].timeline_start_sec * 1000)}ms', all=True))
+ last_end_ms = clip.timeline_end_ms
+ track_audio_streams.append(ffmpeg.concat(*track_segments, v=0, a=1).filter('adelay', f'{int(track_clips[0].timeline_start_ms)}ms', all=True))
if track_audio_streams:
final_audio = ffmpeg.filter(track_audio_streams, 'amix', inputs=len(track_audio_streams), duration='longest')
else:
- final_audio = ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={total_dur}', f='lavfi')
+ final_audio = ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={total_dur_sec}', f='lavfi')
output_args = {'vcodec': settings['vcodec'], 'acodec': settings['acodec'], 'pix_fmt': 'yuv420p'}
if settings['v_bitrate']: output_args['b:v'] = settings['v_bitrate']
@@ -2994,7 +3008,7 @@ def export_video(self):
ffmpeg_cmd = ffmpeg.output(final_video, final_audio, output_path, **output_args).overwrite_output().compile()
self.progress_bar.setVisible(True); self.progress_bar.setValue(0); self.status_label.setText("Exporting...")
self.export_thread = QThread()
- self.export_worker = ExportWorker(ffmpeg_cmd, total_dur)
+ self.export_worker = ExportWorker(ffmpeg_cmd, total_dur_ms)
self.export_worker.moveToThread(self.export_thread)
self.export_thread.started.connect(self.export_worker.run_export)
self.export_worker.finished.connect(self.on_export_finished)
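These hunks settle the unit convention for export: the timeline stays in integer milliseconds, and values are converted to float seconds only at the ffmpeg boundary (trim/atrim/overlay expect seconds, while adelay takes milliseconds). A small illustration of that convention, not code from the patch:

```python
def ms_to_sec(ms: int) -> float:
    """Convert integer timeline milliseconds to the float seconds ffmpeg filters expect."""
    return ms / 1000.0

# the overlay 'enable' expression is written in seconds...
timeline_start_ms, timeline_end_ms = 1500, 4500
enable_expr = f"between(t,{ms_to_sec(timeline_start_ms)},{ms_to_sec(timeline_end_ms)})"
assert enable_expr == "between(t,1.5,4.5)"

# ...while adelay takes milliseconds directly, so no conversion is needed there
adelay_arg = f"{int(timeline_start_ms)}ms"
assert adelay_arg == "1500ms"
```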
From 75d29f9b48bec16b0e3468af28d37d52676a98be Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 21:35:16 +1100
Subject: [PATCH 142/155] fix timescale shifts
---
videoeditor.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/videoeditor.py b/videoeditor.py
index 5126b79b5..632ca4069 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -468,7 +468,7 @@ def draw_timescale(self, painter):
frame_dur_ms = 1000.0 / self.project_fps
intervals_ms = [
- int(frame_dur_ms), int(2*frame_dur_ms), int(5*frame_dur_ms), int(10*frame_dur_ms),
+ frame_dur_ms, 2*frame_dur_ms, 5*frame_dur_ms, 10*frame_dur_ms,
100, 200, 500, 1000, 2000, 5000, 10000, 15000, 30000,
60000, 120000, 300000, 600000, 900000, 1800000,
3600000, 2*3600000, 5*3600000, 10*3600000
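The shift being fixed here comes from truncation: with a non-integer frame rate, int(frame_dur_ms) rounds each frame interval down, so tick positions drift away from real frame boundaries as you move along the timescale. A rough illustration at 29.97 fps (approximate numbers, not taken from the patch):

```python
fps = 29.97
frame_dur_ms = 1000.0 / fps          # ~33.367 ms
truncated = int(frame_dur_ms)        # 33 ms

# after 100 frames the truncated interval is off by more than a whole frame
exact_pos = 100 * frame_dur_ms       # ~3336.7 ms
tick_pos = 100 * truncated           # 3300 ms
print(exact_pos - tick_pos)          # ~36.7 ms of accumulated shift
```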
From fde093f128f408b0b3827f0f0a562df8dcd9272c Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 21:50:44 +1100
Subject: [PATCH 143/155] add buttons to snap to nearest clip collision
---
videoeditor.py | 32 ++++++++++++++++++++++++++++++++
1 file changed, 32 insertions(+)
diff --git a/videoeditor.py b/videoeditor.py
index 632ca4069..98d908477 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -1853,11 +1853,15 @@ def _setup_ui(self):
self.stop_button = QPushButton("Stop")
self.frame_back_button = QPushButton("<")
self.frame_forward_button = QPushButton(">")
+ self.snap_back_button = QPushButton("|<")
+ self.snap_forward_button = QPushButton(">|")
controls_layout.addStretch()
+ controls_layout.addWidget(self.snap_back_button)
controls_layout.addWidget(self.frame_back_button)
controls_layout.addWidget(self.play_pause_button)
controls_layout.addWidget(self.stop_button)
controls_layout.addWidget(self.frame_forward_button)
+ controls_layout.addWidget(self.snap_forward_button)
controls_layout.addStretch()
main_layout.addWidget(controls_widget)
@@ -1908,6 +1912,8 @@ def _connect_signals(self):
self.stop_button.clicked.connect(self.stop_playback)
self.frame_back_button.clicked.connect(lambda: self.step_frame(-1))
self.frame_forward_button.clicked.connect(lambda: self.step_frame(1))
+ self.snap_back_button.clicked.connect(lambda: self.snap_playhead(-1))
+ self.snap_forward_button.clicked.connect(lambda: self.snap_playhead(1))
self.playback_timer.timeout.connect(self.advance_playback_frame)
self.project_media_widget.add_media_requested.connect(self.add_media_files)
@@ -2318,6 +2324,32 @@ def step_frame(self, direction):
new_time = self.timeline_widget.playhead_pos_ms + (direction * frame_duration_ms)
self.seek_preview(int(max(0, min(new_time, self.timeline.get_total_duration()))))
+ def snap_playhead(self, direction):
+ if not self.timeline.clips:
+ return
+
+ current_time_ms = self.timeline_widget.playhead_pos_ms
+
+ snap_points = set()
+ for clip in self.timeline.clips:
+ snap_points.add(clip.timeline_start_ms)
+ snap_points.add(clip.timeline_end_ms)
+
+ sorted_points = sorted(list(snap_points))
+
+ TOLERANCE_MS = 1
+
+ if direction == 1: # Forward
+ next_points = [p for p in sorted_points if p > current_time_ms + TOLERANCE_MS]
+ if next_points:
+ self.seek_preview(next_points[0])
+ elif direction == -1: # Backward
+ prev_points = [p for p in sorted_points if p < current_time_ms - TOLERANCE_MS]
+ if prev_points:
+ self.seek_preview(prev_points[-1])
+ elif current_time_ms > 0:
+ self.seek_preview(0)
+
def advance_playback_frame(self):
frame_duration_ms = 1000.0 / self.project_fps
new_time_ms = self.timeline_widget.playhead_pos_ms + frame_duration_ms
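snap_playhead() treats every clip start and end as a snap point and jumps to the closest one past (or before) the playhead, ignoring points within 1 ms of the current position. The same lookup can be expressed with bisect; an equivalent sketch, not the patch's implementation:

```python
import bisect

def next_snap_point(points, current_ms, direction, tolerance_ms=1):
    """points: sorted list of clip start/end times in ms."""
    if direction == 1:
        i = bisect.bisect_right(points, current_ms + tolerance_ms)
        return points[i] if i < len(points) else None
    i = bisect.bisect_left(points, current_ms - tolerance_ms)
    return points[i - 1] if i > 0 else 0   # fall back to the timeline start

points = sorted({0, 2000, 5000, 8000})
assert next_snap_point(points, 2000, 1) == 5000
assert next_snap_point(points, 2000, -1) == 0
```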
From 78f2a959ab12d56a24bcd1a51f916ee5057b01bc Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Sun, 12 Oct 2025 22:13:19 +1100
Subject: [PATCH 144/155] hold Shift for universal frame-by-frame snapping on
 all operations
---
videoeditor.py | 91 ++++++++++++++++++++++++++++++++++----------------
1 file changed, 63 insertions(+), 28 deletions(-)
diff --git a/videoeditor.py b/videoeditor.py
index 98d908477..53f3baa85 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -636,11 +636,17 @@ def y_to_track_info(self, y):
return None
+ def _snap_to_frame(self, time_ms):
+ frame_duration_ms = 1000.0 / self.project_fps
+ if frame_duration_ms <= 0:
+ return int(time_ms)
+ frame_number = round(time_ms / frame_duration_ms)
+ return int(frame_number * frame_duration_ms)
+
def _snap_time_if_needed(self, time_ms):
frame_duration_ms = 1000.0 / self.project_fps
if frame_duration_ms > 0 and frame_duration_ms * self.pixels_per_ms > 4:
- frame_number = round(time_ms / frame_duration_ms)
- return int(frame_number * frame_duration_ms)
+ return self._snap_to_frame(time_ms)
return int(time_ms)
def get_region_at_pos(self, pos: QPoint):
@@ -791,15 +797,24 @@ def mousePressEvent(self, event: QMouseEvent):
if is_in_track_area:
self.creating_selection_region = True
- playhead_x = self.ms_to_x(self.playhead_pos_ms)
- if abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS:
- self.selection_drag_start_ms = self.playhead_pos_ms
+ is_shift_pressed = bool(event.modifiers() & Qt.KeyboardModifier.ShiftModifier)
+ start_ms = self.x_to_ms(event.pos().x())
+
+ if is_shift_pressed:
+ self.selection_drag_start_ms = self._snap_to_frame(start_ms)
else:
- self.selection_drag_start_ms = self.x_to_ms(event.pos().x())
+ playhead_x = self.ms_to_x(self.playhead_pos_ms)
+ if abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS:
+ self.selection_drag_start_ms = self.playhead_pos_ms
+ else:
+ self.selection_drag_start_ms = start_ms
self.selection_regions.append([self.selection_drag_start_ms, self.selection_drag_start_ms])
elif is_on_timescale:
time_ms = max(0, self.x_to_ms(event.pos().x()))
- self.playhead_pos_ms = self._snap_time_if_needed(time_ms)
+ if event.modifiers() & Qt.KeyboardModifier.ShiftModifier:
+ self.playhead_pos_ms = self._snap_to_frame(time_ms)
+ else:
+ self.playhead_pos_ms = self._snap_time_if_needed(time_ms)
self.playhead_moved.emit(self.playhead_pos_ms)
self.dragging_playhead = True
@@ -816,6 +831,8 @@ def mouseMoveEvent(self, event: QMouseEvent):
if self.resizing_selection_region:
current_ms = max(0, self.x_to_ms(event.pos().x()))
+ if event.modifiers() & Qt.KeyboardModifier.ShiftModifier:
+ current_ms = self._snap_to_frame(current_ms)
original_start, original_end = self.resize_selection_start_values
if self.resize_selection_edge == 'left':
@@ -836,6 +853,7 @@ def mouseMoveEvent(self, event: QMouseEvent):
self.update()
return
if self.resizing_clip:
+ is_shift_pressed = bool(event.modifiers() & Qt.KeyboardModifier.ShiftModifier)
linked_clip = next((c for c in self.timeline.clips if c.group_id == self.resizing_clip.group_id and c.id != self.resizing_clip.id), None)
delta_x = event.pos().x() - self.resize_start_pos.x()
time_delta = delta_x / self.pixels_per_ms
@@ -858,11 +876,14 @@ def mouseMoveEvent(self, event: QMouseEvent):
original_clip_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].clip_start_ms
true_new_start_ms = original_start + time_delta
- new_start_ms = true_new_start_ms
- for snap_point in snap_points:
- if abs(true_new_start_ms - snap_point) < snap_time_delta:
- new_start_ms = snap_point
- break
+ if is_shift_pressed:
+ new_start_ms = self._snap_to_frame(true_new_start_ms)
+ else:
+ new_start_ms = true_new_start_ms
+ for snap_point in snap_points:
+ if abs(true_new_start_ms - snap_point) < snap_time_delta:
+ new_start_ms = snap_point
+ break
if new_start_ms > original_start + original_duration - min_duration_ms:
new_start_ms = original_start + original_duration - min_duration_ms
@@ -896,11 +917,14 @@ def mouseMoveEvent(self, event: QMouseEvent):
true_new_duration = original_duration + time_delta
true_new_end_time = original_start + true_new_duration
- new_end_time = true_new_end_time
- for snap_point in snap_points:
- if abs(true_new_end_time - snap_point) < snap_time_delta:
- new_end_time = snap_point
- break
+ if is_shift_pressed:
+ new_end_time = self._snap_to_frame(true_new_end_time)
+ else:
+ new_end_time = true_new_end_time
+ for snap_point in snap_points:
+ if abs(true_new_end_time - snap_point) < snap_time_delta:
+ new_end_time = snap_point
+ break
new_duration = new_end_time - original_start
@@ -948,6 +972,8 @@ def mouseMoveEvent(self, event: QMouseEvent):
if self.creating_selection_region:
current_ms = self.x_to_ms(event.pos().x())
+ if event.modifiers() & Qt.KeyboardModifier.ShiftModifier:
+ current_ms = self._snap_to_frame(current_ms)
start = min(self.selection_drag_start_ms, current_ms)
end = max(self.selection_drag_start_ms, current_ms)
self.selection_regions[-1] = [start, end]
@@ -962,6 +988,9 @@ def mouseMoveEvent(self, event: QMouseEvent):
duration = original_end - original_start
new_start = max(0, original_start + time_delta)
+ if event.modifiers() & Qt.KeyboardModifier.ShiftModifier:
+ new_start = self._snap_to_frame(new_start)
+
self.dragging_selection_region[0] = new_start
self.dragging_selection_region[1] = new_start + duration
@@ -970,7 +999,10 @@ def mouseMoveEvent(self, event: QMouseEvent):
if self.dragging_playhead:
time_ms = max(0, self.x_to_ms(event.pos().x()))
- self.playhead_pos_ms = self._snap_time_if_needed(time_ms)
+ if event.modifiers() & Qt.KeyboardModifier.ShiftModifier:
+ self.playhead_pos_ms = self._snap_to_frame(time_ms)
+ else:
+ self.playhead_pos_ms = self._snap_time_if_needed(time_ms)
self.playhead_moved.emit(self.playhead_pos_ms)
self.update()
elif self.dragging_clip:
@@ -997,17 +1029,20 @@ def mouseMoveEvent(self, event: QMouseEvent):
delta_x = event.pos().x() - self.drag_start_pos.x()
time_delta = delta_x / self.pixels_per_ms
true_new_start_time = original_start_ms + time_delta
-
- playhead_time = self.playhead_pos_ms
- snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_ms
- new_start_time = true_new_start_time
- true_new_end_time = true_new_start_time + self.dragging_clip.duration_ms
-
- if abs(true_new_start_time - playhead_time) < snap_time_delta:
- new_start_time = playhead_time
- elif abs(true_new_end_time - playhead_time) < snap_time_delta:
- new_start_time = playhead_time - self.dragging_clip.duration_ms
+ if event.modifiers() & Qt.KeyboardModifier.ShiftModifier:
+ new_start_time = self._snap_to_frame(true_new_start_time)
+ else:
+ playhead_time = self.playhead_pos_ms
+ snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_ms
+
+ new_start_time = true_new_start_time
+ true_new_end_time = true_new_start_time + self.dragging_clip.duration_ms
+
+ if abs(true_new_start_time - playhead_time) < snap_time_delta:
+ new_start_time = playhead_time
+ elif abs(true_new_end_time - playhead_time) < snap_time_delta:
+ new_start_time = playhead_time - self.dragging_clip.duration_ms
for other_clip in self.timeline.clips:
if other_clip.id == self.dragging_clip.id: continue
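With Shift held, every drag, resize, selection edge and playhead position is quantised to the nearest frame boundary by _snap_to_frame(), i.e. round(time_ms / frame_duration_ms) * frame_duration_ms. A worked example at 25 fps (40 ms per frame), written as a standalone function:

```python
def snap_to_frame(time_ms: float, fps: float) -> int:
    frame_duration_ms = 1000.0 / fps
    if frame_duration_ms <= 0:
        return int(time_ms)
    return int(round(time_ms / frame_duration_ms) * frame_duration_ms)

assert snap_to_frame(1015, 25) == 1000   # snaps back to frame 25
assert snap_to_frame(1021, 25) == 1040   # rounds forward to frame 26
```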
From 19b6ae344e4c15c7d08eccbc3745ead094bba080 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Mon, 13 Oct 2025 00:09:11 +1100
Subject: [PATCH 145/155] improve timescale
---
videoeditor.py | 64 +++++++++++++++++++++++++++++++++++++++++++-------
1 file changed, 55 insertions(+), 9 deletions(-)
diff --git a/videoeditor.py b/videoeditor.py
index 53f3baa85..e6176bec9 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -439,23 +439,69 @@ def draw_headers(self, painter):
painter.restore()
- def _format_timecode(self, total_ms):
+
+ def _format_timecode(self, total_ms, interval_ms):
if abs(total_ms) < 1: total_ms = 0
sign = "-" if total_ms < 0 else ""
total_ms = abs(total_ms)
seconds = total_ms / 1000.0
+
+ # Frame-based formatting for high zoom
+ is_frame_based = interval_ms < (1000.0 / self.project_fps) * 5
+ if is_frame_based:
+ total_frames = int(round(seconds * self.project_fps))
+ fps_int = int(round(self.project_fps))
+ if fps_int == 0: fps_int = 25 # Avoid division by zero
+ s_frames = total_frames % fps_int
+ total_seconds_from_frames = total_frames // fps_int
+ h_fr = total_seconds_from_frames // 3600
+ m_fr = (total_seconds_from_frames % 3600) // 60
+ s_fr = total_seconds_from_frames % 60
+
+ if h_fr > 0: return f"{sign}{h_fr}:{m_fr:02d}:{s_fr:02d}:{s_frames:02d}"
+ if m_fr > 0: return f"{sign}{m_fr}:{s_fr:02d}:{s_frames:02d}"
+ return f"{sign}{s_fr}:{s_frames:02d}"
+
+ # Sub-second formatting
+ if interval_ms < 1000:
+ precision = 2 if interval_ms < 100 else 1
+ s_float = seconds % 60
+ m = int((seconds % 3600) / 60)
+ h = int(seconds / 3600)
+
+ if h > 0: return f"{sign}{h}:{m:02d}:{s_float:0{4+precision}.{precision}f}"
+ if m > 0: return f"{sign}{m}:{s_float:0{4+precision}.{precision}f}"
+
+ val = f"{s_float:.{precision}f}"
+ if '.' in val: val = val.rstrip('0').rstrip('.')
+ return f"{sign}{val}s"
+
+ # Seconds and M:SS formatting
+ if interval_ms < 60000:
+ rounded_seconds = int(round(seconds))
+ h = rounded_seconds // 3600
+ m = (rounded_seconds % 3600) // 60
+ s = rounded_seconds % 60
+
+ if h > 0:
+ return f"{sign}{h}:{m:02d}:{s:02d}"
+
+ # Use "Xs" format for times under a minute
+ if rounded_seconds < 60:
+ return f"{sign}{rounded_seconds}s"
+
+ # Use "M:SS" format for times over a minute
+ return f"{sign}{m}:{s:02d}"
+
+ # Minute and Hh:MMm formatting for low zoom
h = int(seconds / 3600)
m = int((seconds % 3600) / 60)
- s = seconds % 60
-
+
if h > 0: return f"{sign}{h}h:{m:02d}m"
- if m > 0 or seconds >= 59.99: return f"{sign}{m}m:{int(round(s)):02d}s"
- precision = 2 if s < 1 else 1 if s < 10 else 0
- val = f"{s:.{precision}f}"
- if '.' in val: val = val.rstrip('0').rstrip('.')
- return f"{sign}{val}s"
+ total_minutes = int(seconds / 60)
+ return f"{sign}{total_minutes}m"
def draw_timescale(self, painter):
painter.save()
@@ -510,7 +556,7 @@ def draw_ticks(interval_ms, height):
if x > self.width() + 50: break
if x >= self.HEADER_WIDTH - 50:
painter.drawLine(x, self.TIMESCALE_HEIGHT - 12, x, self.TIMESCALE_HEIGHT)
- label = self._format_timecode(t_ms)
+ label = self._format_timecode(t_ms, major_interval)
label_width = font_metrics.horizontalAdvance(label)
label_x = x - label_width // 2
if label_x < self.HEADER_WIDTH:
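_format_timecode() now picks a notation from the tick interval: frame-accurate H:MM:SS:FF when the interval is under roughly five frames, fractional seconds below one-second intervals, plain seconds and M:SS up to a minute, and minute/hour forms beyond that. A rough worked example of the frame-based branch at 25 fps (illustrative numbers, not output captured from the widget):

```python
fps = 25
total_ms = 61_480                               # 61.48 s into the timeline
total_frames = round(total_ms / 1000 * fps)     # 1537 frames
m = total_frames // fps // 60
s = (total_frames // fps) % 60
f = total_frames % fps
print(f"{m}:{s:02d}:{f:02d}")                   # -> 1:01:12
```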
From 4452582d6eb8d22b3b21ca91a5bda4d7ff01ec3e Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Mon, 13 Oct 2025 01:23:09 +1100
Subject: [PATCH 146/155] clamp timescale conversions
---
videoeditor.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/videoeditor.py b/videoeditor.py
index e6176bec9..d6e744ebe 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -330,7 +330,7 @@ def set_project_fps(self, fps):
self.pixels_per_ms = min(self.pixels_per_ms, self.max_pixels_per_ms)
self.update()
- def ms_to_x(self, ms): return self.HEADER_WIDTH + int((ms - self.view_start_ms) * self.pixels_per_ms)
+ def ms_to_x(self, ms): return self.HEADER_WIDTH + int(max(-500_000_000, min((ms - self.view_start_ms) * self.pixels_per_ms, 500_000_000)))
def x_to_ms(self, x): return self.view_start_ms + int(float(x - self.HEADER_WIDTH) / self.pixels_per_ms) if x > self.HEADER_WIDTH and self.pixels_per_ms > 0 else self.view_start_ms
def paintEvent(self, event):
@@ -1861,7 +1861,7 @@ def __init__(self, project_to_load=None):
self.plugin_manager.discover_and_load_plugins()
# Project settings with defaults
- self.project_fps = 25.0
+ self.project_fps = 50.0
self.project_width = 1280
self.project_height = 720
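The clamp in ms_to_x() bounds the pixel offset to plus or minus 500,000,000 before it is handed to Qt; at high zoom levels, (ms - view_start_ms) * pixels_per_ms can otherwise grow past the 32-bit coordinate range the paint code works with and wrap to nonsense positions. A standalone sketch of the guarded conversion, with illustrative constants:

```python
HEADER_WIDTH = 120          # placeholder; the widget defines its own value
COORD_LIMIT = 500_000_000   # stays well inside 32-bit painter coordinates

def ms_to_x(ms, view_start_ms, pixels_per_ms):
    offset = (ms - view_start_ms) * pixels_per_ms
    return HEADER_WIDTH + int(max(-COORD_LIMIT, min(offset, COORD_LIMIT)))

# one hour from the view start at an extreme zoom would otherwise overflow int32
print(ms_to_x(3_600_000, 0, 10_000))   # clamped to 500000120 instead of 36000000120
```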
From baad78fd629f4bb5cce39625c5b5d1f1051fb0bc Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Tue, 14 Oct 2025 19:00:24 +1100
Subject: [PATCH 147/155] add audio playback, overhaul the encoder and playback
 system to use multithreading and better layering, and move the wgp import out
 of the video player and into its plugin
---
encoding.py | 210 +++++++++++++++
playback.py | 531 +++++++++++++++++++++++++++++++++++++
plugins.py | 53 ++--
plugins/wan2gp/main.py | 179 ++++++++-----
plugins/wan2gp/plugin.json | 5 +
videoeditor.py | 380 +++++++++-----------------
6 files changed, 1022 insertions(+), 336 deletions(-)
create mode 100644 encoding.py
create mode 100644 playback.py
create mode 100644 plugins/wan2gp/plugin.json
diff --git a/encoding.py b/encoding.py
new file mode 100644
index 000000000..22f67489c
--- /dev/null
+++ b/encoding.py
@@ -0,0 +1,210 @@
+import ffmpeg
+import subprocess
+import re
+from PyQt6.QtCore import QObject, pyqtSignal, QThread
+
+class _ExportRunner(QObject):
+ progress = pyqtSignal(int)
+ finished = pyqtSignal(bool, str)
+
+ def __init__(self, ffmpeg_cmd, total_duration_ms, parent=None):
+ super().__init__(parent)
+ self.ffmpeg_cmd = ffmpeg_cmd
+ self.total_duration_ms = total_duration_ms
+ self.process = None
+
+ def run(self):
+ try:
+ startupinfo = None
+ if hasattr(subprocess, 'STARTUPINFO'):
+ startupinfo = subprocess.STARTUPINFO()
+ startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
+
+ self.process = subprocess.Popen(
+ self.ffmpeg_cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ universal_newlines=True,
+ encoding="utf-8",
+ errors='ignore',
+ startupinfo=startupinfo
+ )
+
+ full_output = []
+ time_pattern = re.compile(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})")
+
+ for line in iter(self.process.stdout.readline, ""):
+ full_output.append(line)
+ match = time_pattern.search(line)
+ if match:
+ h, m, s, cs = [int(g) for g in match.groups()]
+ processed_ms = (h * 3600 + m * 60 + s) * 1000 + cs * 10
+ if self.total_duration_ms > 0:
+ percentage = int((processed_ms / self.total_duration_ms) * 100)
+ self.progress.emit(min(100, percentage))
+
+ self.process.stdout.close()
+ return_code = self.process.wait()
+
+ if return_code == 0:
+ self.progress.emit(100)
+ self.finished.emit(True, "Export completed successfully!")
+ else:
+ print("--- FFmpeg Export FAILED ---")
+ print("Command: " + " ".join(self.ffmpeg_cmd))
+ print("".join(full_output))
+ self.finished.emit(False, f"Export failed with code {return_code}. Check console.")
+
+ except FileNotFoundError:
+ self.finished.emit(False, "Export failed: ffmpeg.exe not found in your system's PATH.")
+ except Exception as e:
+ self.finished.emit(False, f"An exception occurred during export: {e}")
+
+ def get_process(self):
+ return self.process
+
+class Encoder(QObject):
+ progress = pyqtSignal(int)
+ finished = pyqtSignal(bool, str)
+
+ def __init__(self, parent=None):
+ super().__init__(parent)
+ self.worker_thread = None
+ self.worker = None
+ self._is_running = False
+
+ def start_export(self, timeline, project_settings, export_settings):
+ if self._is_running:
+ self.finished.emit(False, "An export is already in progress.")
+ return
+
+ self._is_running = True
+
+ try:
+ total_dur_ms = timeline.get_total_duration()
+ total_dur_sec = total_dur_ms / 1000.0
+ w, h, fps = project_settings['width'], project_settings['height'], project_settings['fps']
+ sample_rate, channel_layout = '44100', 'stereo'
+
+ # --- VIDEO GRAPH CONSTRUCTION (overlay compositing) ---
+
+ all_video_clips = sorted(
+ [c for c in timeline.clips if c.track_type == 'video'],
+ key=lambda c: c.track_index
+ )
+
+ # 1. Start with a base black canvas that defines the project's duration.
+ # This is our master clock and bottom layer.
+ final_video = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fps}:d={total_dur_sec}', f='lavfi')
+
+ # 2. Process each clip individually and overlay it.
+ for clip in all_video_clips:
+ # A new, separate input for every single clip guarantees no conflicts.
+ if clip.media_type == 'image':
+ clip_input = ffmpeg.input(clip.source_path, loop=1, framerate=fps)
+ else:
+ clip_input = ffmpeg.input(clip.source_path)
+
+ # a) Calculate the time shift to align the clip's content with the timeline.
+ # This ensures the correct frame of the source is shown at the start of the clip.
+ timeline_start_sec = clip.timeline_start_ms / 1000.0
+ clip_start_sec = clip.clip_start_ms / 1000.0
+ time_shift_sec = timeline_start_sec - clip_start_sec
+
+ # b) Prepare the layer: apply the time shift, then scale and pad.
+ timed_layer = (
+ clip_input.video
+ .setpts(f'PTS+{time_shift_sec}/TB')
+ .filter('scale', w, h, force_original_aspect_ratio='decrease')
+ .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
+ )
+
+ # c) Define the visibility window for the overlay on the master timeline.
+ timeline_end_sec = (clip.timeline_start_ms + clip.duration_ms) / 1000.0
+ enable_expression = f'between(t,{timeline_start_sec:.6f},{timeline_end_sec:.6f})'
+
+ # d) Overlay the prepared layer onto the composition, enabling it only during its time window.
+ # eof_action='pass' handles finite streams gracefully.
+ final_video = ffmpeg.overlay(final_video, timed_layer, enable=enable_expression, eof_action='pass')
+
+ # 3. Set final output format and framerate.
+ final_video = final_video.filter('format', pix_fmts='yuv420p').filter('fps', fps=fps)
+
+
+ # --- AUDIO GRAPH CONSTRUCTION (same approach as the previous export path) ---
+ track_audio_streams = []
+ for i in range(1, timeline.num_audio_tracks + 1):
+ track_clips = sorted([c for c in timeline.clips if c.track_type == 'audio' and c.track_index == i], key=lambda c: c.timeline_start_ms)
+ if not track_clips:
+ continue
+
+ track_segments = []
+ last_end_ms = 0
+ for clip in track_clips:
+ gap_ms = clip.timeline_start_ms - last_end_ms
+ if gap_ms > 10:
+ track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap_ms/1000.0}', f='lavfi'))
+
+ clip_start_sec = clip.clip_start_ms / 1000.0
+ clip_duration_sec = clip.duration_ms / 1000.0
+ audio_source_node = ffmpeg.input(clip.source_path)
+ a_seg = audio_source_node.audio.filter('atrim', start=clip_start_sec, duration=clip_duration_sec).filter('asetpts', 'PTS-STARTPTS')
+ track_segments.append(a_seg)
+ last_end_ms = clip.timeline_start_ms + clip.duration_ms
+
+ if track_segments:
+ track_audio_streams.append(ffmpeg.concat(*track_segments, v=0, a=1))
+
+ # --- FINAL OUTPUT ASSEMBLY ---
+ output_args = {}
+ stream_args = []
+ has_audio = bool(track_audio_streams) and export_settings.get('acodec')
+
+ if export_settings.get('vcodec'):
+ stream_args.append(final_video)
+ output_args['vcodec'] = export_settings['vcodec']
+ if export_settings.get('v_bitrate'): output_args['b:v'] = export_settings['v_bitrate']
+
+ if has_audio:
+ final_audio = ffmpeg.filter(track_audio_streams, 'amix', inputs=len(track_audio_streams), duration='longest')
+ stream_args.append(final_audio)
+ output_args['acodec'] = export_settings['acodec']
+ if export_settings.get('a_bitrate'): output_args['b:a'] = export_settings['a_bitrate']
+
+ if not has_audio:
+ output_args['an'] = None
+
+ if not stream_args:
+ raise ValueError("No streams to output. Check export settings.")
+
+ ffmpeg_cmd = ffmpeg.output(*stream_args, export_settings['output_path'], **output_args).overwrite_output().compile()
+
+ except Exception as e:
+ self.finished.emit(False, f"Error building FFmpeg command: {e}")
+ self._is_running = False
+ return
+
+ self.worker_thread = QThread()
+ self.worker = _ExportRunner(ffmpeg_cmd, total_dur_ms)
+ self.worker.moveToThread(self.worker_thread)
+
+ self.worker.progress.connect(self.progress.emit)
+ self.worker.finished.connect(self._on_export_runner_finished)
+
+ self.worker_thread.started.connect(self.worker.run)
+ self.worker_thread.start()
+
+ def _on_export_runner_finished(self, success, message):
+ self._is_running = False
+ self.finished.emit(success, message)
+
+ if self.worker_thread:
+ self.worker_thread.quit()
+ self.worker_thread.wait()
+ self.worker_thread = None
+ self.worker = None
+
+ def cancel_export(self):
+ if self.worker and self.worker.get_process() and self.worker.get_process().poll() is None:
+ self.worker.get_process().terminate()
+ print("Export cancelled by user.")
\ No newline at end of file
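Encoder builds the ffmpeg graph on the caller's thread, then runs the compiled command in a QThread via _ExportRunner and reports back through the progress and finished signals. A minimal usage sketch; timeline, progress_bar and status_label are assumed to be existing objects in the host window, and the settings values are placeholders:

```python
from encoding import Encoder

encoder = Encoder()
encoder.progress.connect(progress_bar.setValue)                 # progress_bar: an existing QProgressBar
encoder.finished.connect(lambda ok, msg: status_label.setText(msg))

project_settings = {'width': 1280, 'height': 720, 'fps': 25.0}
export_settings = {
    'output_path': 'out.mp4',
    'vcodec': 'libx264', 'v_bitrate': '8M',
    'acodec': 'aac',     'a_bitrate': '192k',
}
encoder.start_export(timeline, project_settings, export_settings)
```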
diff --git a/playback.py b/playback.py
new file mode 100644
index 000000000..d8da682e8
--- /dev/null
+++ b/playback.py
@@ -0,0 +1,531 @@
+import ffmpeg
+import numpy as np
+import sounddevice as sd
+import threading
+import time
+from queue import Queue, Empty
+import subprocess
+from PyQt6.QtCore import QObject, pyqtSignal, QTimer
+from PyQt6.QtGui import QImage, QPixmap, QColor
+
+AUDIO_BUFFER_SECONDS = 1.0
+VIDEO_BUFFER_SECONDS = 1.0
+AUDIO_CHUNK_SAMPLES = 1024
+DEFAULT_SAMPLE_RATE = 44100
+DEFAULT_CHANNELS = 2
+
+class PlaybackManager(QObject):
+ new_frame = pyqtSignal(QPixmap)
+ playback_pos_changed = pyqtSignal(int)
+ stopped = pyqtSignal()
+ started = pyqtSignal()
+ paused = pyqtSignal()
+ stats_updated = pyqtSignal(str)
+
+ def __init__(self, get_timeline_data_func, parent=None):
+ super().__init__(parent)
+ self.get_timeline_data = get_timeline_data_func
+
+ self.is_playing = False
+ self.is_muted = False
+ self.volume = 1.0
+ self.playback_start_time_ms = 0
+ self.pause_time_ms = 0
+ self.is_paused = False
+ self.seeking = False
+ self.stop_flag = threading.Event()
+ self._seek_thread = None
+ self._seek_request_ms = -1
+ self._seek_lock = threading.Lock()
+
+ self.video_process = None
+ self.audio_process = None
+ self.video_reader_thread = None
+ self.audio_reader_thread = None
+ self.audio_stream = None
+
+ self.video_queue = None
+ self.audio_queue = None
+
+ self.stream_start_time_monotonic = 0
+ self.last_emitted_pos = -1
+
+ self.total_samples_played = 0
+ self.audio_underruns = 0
+ self.last_video_pts_ms = 0
+ self.debug = False
+
+ self.sync_lock = threading.Lock()
+ self.audio_clock_sec = 0.0
+ self.audio_clock_update_time = 0.0
+
+ self.update_timer = QTimer(self)
+ self.update_timer.timeout.connect(self._update_loop)
+ self.update_timer.setInterval(16)
+
+ def _emit_playhead_pos(self, time_ms, source):
+ if time_ms != self.last_emitted_pos:
+ if self.debug:
+ print(f"[PLAYHEAD DEBUG @ {time.monotonic():.4f}] pos={time_ms}ms, source='{source}'")
+ self.playback_pos_changed.emit(time_ms)
+ self.last_emitted_pos = time_ms
+
+ def _cleanup_resources(self):
+ self.stop_flag.set()
+
+ if self.update_timer.isActive():
+ self.update_timer.stop()
+
+ if self.audio_stream:
+ try:
+ self.audio_stream.stop()
+ self.audio_stream.close(ignore_errors=True)
+ except Exception as e:
+ print(f"Error closing audio stream: {e}")
+ self.audio_stream = None
+
+ for p in [self.video_process, self.audio_process]:
+ if p and p.poll() is None:
+ try:
+ p.terminate()
+ p.wait(timeout=1.0) # Wait a bit for graceful termination
+ except Exception as e:
+ print(f"Error terminating process: {e}")
+ try:
+ p.kill() # Force kill if terminate fails
+ except Exception as ke:
+ print(f"Error killing process: {ke}")
+
+
+ self.video_reader_thread = None
+ self.audio_reader_thread = None
+ self.video_process = None
+ self.audio_process = None
+ self.video_queue = None
+ self.audio_queue = None
+
+ def _video_reader_thread(self, process, width, height, fps):
+ frame_size = width * height * 3
+ frame_duration_ms = 1000.0 / fps
+ frame_pts_ms = self.playback_start_time_ms
+
+ while not self.stop_flag.is_set():
+ try:
+ if self.video_queue and self.video_queue.full():
+ time.sleep(0.01)
+ continue
+
+ chunk = process.stdout.read(frame_size)
+ if len(chunk) < frame_size:
+ break
+
+ if self.video_queue: self.video_queue.put((chunk, frame_pts_ms))
+ frame_pts_ms += frame_duration_ms
+
+ except (IOError, ValueError):
+ break
+ except Exception as e:
+ print(f"Video reader error: {e}")
+ break
+ if self.debug: print("Video reader thread finished.")
+ if self.video_queue: self.video_queue.put(None)
+
+ def _audio_reader_thread(self, process):
+ chunk_size = AUDIO_CHUNK_SAMPLES * DEFAULT_CHANNELS * 4
+ while not self.stop_flag.is_set():
+ try:
+ if self.audio_queue and self.audio_queue.full():
+ time.sleep(0.01)
+ continue
+
+ chunk = process.stdout.read(chunk_size)
+ if not chunk:
+ break
+
+ np_chunk = np.frombuffer(chunk, dtype=np.float32)
+
+ if self.audio_queue:
+ self.audio_queue.put(np_chunk)
+
+ except (IOError, ValueError):
+ break
+ except Exception as e:
+ print(f"Audio reader error: {e}")
+ break
+ if self.debug: print("Audio reader thread finished.")
+ if self.audio_queue: self.audio_queue.put(None)
+
+ def _audio_callback(self, outdata, frames, time_info, status):
+ if status:
+ self.audio_underruns += 1
+ try:
+ if self.audio_queue is None: raise Empty
+ chunk = self.audio_queue.get_nowait()
+ if chunk is None:
+ raise sd.CallbackStop
+
+ chunk_len = len(chunk)
+ outdata_len = outdata.shape[0] * outdata.shape[1]
+
+ if chunk_len < outdata_len:
+ outdata.fill(0)
+ outdata.flat[:chunk_len] = chunk
+ else:
+ outdata[:] = chunk[:outdata.size].reshape(outdata.shape)
+
+ if self.is_muted:
+ outdata.fill(0)
+ else:
+ outdata *= self.volume
+
+ with self.sync_lock:
+ self.audio_clock_sec = self.total_samples_played / DEFAULT_SAMPLE_RATE
+ self.audio_clock_update_time = time_info.outputBufferDacTime
+
+ self.total_samples_played += frames
+ except Empty:
+ outdata.fill(0)
+
+ def _seek_worker(self):
+ while True:
+ with self._seek_lock:
+ time_ms = self._seek_request_ms
+ self._seek_request_ms = -1
+
+ timeline, clips, proj_settings = self.get_timeline_data()
+ w, h, fps = proj_settings['width'], proj_settings['height'], proj_settings['fps']
+
+ top_clip = next((c for c in sorted(clips, key=lambda x: x.track_index, reverse=True)
+ if c.track_type == 'video' and c.timeline_start_ms <= time_ms < (c.timeline_start_ms + c.duration_ms)), None)
+
+ pixmap = QPixmap(w, h)
+ pixmap.fill(QColor("black"))
+
+ if top_clip:
+ try:
+ if top_clip.media_type == 'image':
+ out, _ = (ffmpeg.input(top_clip.source_path)
+ .filter('scale', w, h, force_original_aspect_ratio='decrease')
+ .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
+ .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
+ .run(capture_stdout=True, quiet=True))
+ else:
+ clip_time_sec = (time_ms - top_clip.timeline_start_ms + top_clip.clip_start_ms) / 1000.0
+ out, _ = (ffmpeg.input(top_clip.source_path, ss=f"{clip_time_sec:.6f}")
+ .filter('scale', w, h, force_original_aspect_ratio='decrease')
+ .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
+ .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
+ .run(capture_stdout=True, quiet=True))
+
+ if out:
+ image = QImage(out, w, h, QImage.Format.Format_RGB888)
+ pixmap = QPixmap.fromImage(image)
+
+ except ffmpeg.Error as e:
+ print(f"Error seeking to frame: {e.stderr.decode('utf-8') if e.stderr else str(e)}")
+
+ self.new_frame.emit(pixmap)
+
+ with self._seek_lock:
+ if self._seek_request_ms == -1:
+ self.seeking = False
+ break
+
+ def seek_to_frame(self, time_ms):
+ self.stop()
+ self._emit_playhead_pos(time_ms, "seek_to_frame")
+
+ with self._seek_lock:
+ self._seek_request_ms = time_ms
+ if self.seeking:
+ return
+ self.seeking = True
+ self._seek_thread = threading.Thread(target=self._seek_worker, daemon=True)
+ self._seek_thread.start()
+
+ def _build_video_graph(self, start_ms, timeline, clips, proj_settings):
+ w, h, fps = proj_settings['width'], proj_settings['height'], proj_settings['fps']
+
+ video_clips = sorted(
+ [c for c in clips if c.track_type == 'video' and c.timeline_end_ms > start_ms],
+ key=lambda c: c.track_index,
+ reverse=True
+ )
+ if not video_clips:
+ return None
+
+ class VSegment:
+ def __init__(self, c, t_start, t_end):
+ self.source_path = c.source_path
+ self.duration_ms = t_end - t_start
+ self.timeline_start_ms = t_start
+ self.media_type = c.media_type
+ self.clip_start_ms = c.clip_start_ms + (t_start - c.timeline_start_ms)
+
+ @property
+ def timeline_end_ms(self):
+ return self.timeline_start_ms + self.duration_ms
+
+ event_points = {start_ms}
+ for c in video_clips:
+ event_points.update({c.timeline_start_ms, c.timeline_end_ms})
+
+ total_duration = timeline.get_total_duration()
+ if total_duration > start_ms:
+ event_points.add(total_duration)
+
+ sorted_points = sorted([p for p in list(event_points) if p >= start_ms])
+
+ visible_segments = []
+ for t_start, t_end in zip(sorted_points[:-1], sorted_points[1:]):
+ if t_end <= t_start: continue
+
+ midpoint = t_start + 1
+ top_clip = next((c for c in video_clips if c.timeline_start_ms <= midpoint < c.timeline_end_ms), None)
+
+ if top_clip:
+ visible_segments.append(VSegment(top_clip, t_start, t_end))
+
+ top_track_clips = visible_segments
+
+ concat_inputs = []
+ last_end_time_ms = start_ms
+ for clip in top_track_clips:
+ clip_read_start_ms = clip.clip_start_ms + max(0, start_ms - clip.timeline_start_ms)
+ clip_play_start_ms = max(start_ms, clip.timeline_start_ms)
+ clip_remaining_duration_ms = clip.timeline_end_ms - clip_play_start_ms
+ if clip_remaining_duration_ms <= 0:
+ last_end_time_ms = max(last_end_time_ms, clip.timeline_end_ms)
+ continue
+ input_stream = ffmpeg.input(clip.source_path, ss=clip_read_start_ms / 1000.0, t=clip_remaining_duration_ms / 1000.0, re=None)
+ if clip.media_type == 'image':
+ segment = input_stream.video.filter('loop', loop=-1, size=1, start=0).filter('setpts', 'N/(FRAME_RATE*TB)').filter('trim', duration=clip_remaining_duration_ms / 1000.0)
+ else:
+ segment = input_stream.video
+ gap_start_ms = max(start_ms, last_end_time_ms)
+ if clip.timeline_start_ms > gap_start_ms:
+ gap_duration_sec = (clip.timeline_start_ms - gap_start_ms) / 1000.0
+ segment = segment.filter('tpad', start_duration=f'{gap_duration_sec:.6f}', color='black')
+ concat_inputs.append(segment)
+ last_end_time_ms = clip.timeline_end_ms
+ if not concat_inputs:
+ return None
+ return ffmpeg.concat(*concat_inputs, v=1, a=0)
+
+ def _build_audio_graph(self, start_ms, timeline, clips, proj_settings):
+ active_clips = [c for c in clips if c.track_type == 'audio' and c.timeline_end_ms > start_ms]
+ if not active_clips:
+ return None
+
+ tracks = {}
+ for clip in active_clips:
+ tracks.setdefault(clip.track_index, []).append(clip)
+
+ track_streams = []
+ for track_index, track_clips in sorted(tracks.items()):
+ track_clips.sort(key=lambda c: c.timeline_start_ms)
+
+ concat_inputs = []
+ last_end_time_ms = start_ms
+
+ for clip in track_clips:
+ gap_start_ms = max(start_ms, last_end_time_ms)
+ if clip.timeline_start_ms > gap_start_ms:
+ gap_duration_sec = (clip.timeline_start_ms - gap_start_ms) / 1000.0
+ concat_inputs.append(ffmpeg.input(f'anullsrc=r={DEFAULT_SAMPLE_RATE}:cl=stereo:d={gap_duration_sec}', f='lavfi'))
+
+ clip_read_start_ms = clip.clip_start_ms + max(0, start_ms - clip.timeline_start_ms)
+ clip_play_start_ms = max(start_ms, clip.timeline_start_ms)
+ clip_remaining_duration_ms = clip.timeline_end_ms - clip_play_start_ms
+
+ if clip_remaining_duration_ms > 0:
+ segment = ffmpeg.input(clip.source_path, ss=clip_read_start_ms/1000.0, t=clip_remaining_duration_ms/1000.0, re=None).audio
+ concat_inputs.append(segment)
+
+ last_end_time_ms = clip.timeline_end_ms
+
+ if concat_inputs:
+ track_streams.append(ffmpeg.concat(*concat_inputs, v=0, a=1))
+
+ if not track_streams:
+ return None
+
+ return ffmpeg.filter(track_streams, 'amix', inputs=len(track_streams), duration='longest')
+
+ def play(self, time_ms):
+ if self.is_playing:
+ self.stop()
+
+ if self.debug: print(f"Play requested from {time_ms}ms.")
+
+ self.stop_flag.clear()
+ self.playback_start_time_ms = time_ms
+ self.pause_time_ms = time_ms
+ self.is_paused = False
+
+ self.total_samples_played = 0
+ self.audio_underruns = 0
+ self.last_video_pts_ms = time_ms
+ with self.sync_lock:
+ self.audio_clock_sec = 0.0
+ self.audio_clock_update_time = 0.0
+
+ timeline, clips, proj_settings = self.get_timeline_data()
+ w, h, fps = proj_settings['width'], proj_settings['height'], proj_settings['fps']
+
+ video_buffer_size = int(fps * VIDEO_BUFFER_SECONDS)
+ audio_buffer_size = int((DEFAULT_SAMPLE_RATE / AUDIO_CHUNK_SAMPLES) * AUDIO_BUFFER_SECONDS)
+ self.video_queue = Queue(maxsize=video_buffer_size)
+ self.audio_queue = Queue(maxsize=audio_buffer_size)
+
+ video_graph = self._build_video_graph(time_ms, timeline, clips, proj_settings)
+ if video_graph:
+ try:
+ args = (video_graph
+ .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=fps).compile())
+ self.video_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
+ except Exception as e:
+ print(f"Failed to start video process: {e}")
+ self.video_process = None
+
+ audio_graph = self._build_audio_graph(time_ms, timeline, clips, proj_settings)
+ if audio_graph:
+ try:
+ args = ffmpeg.output(audio_graph, 'pipe:', format='f32le', ac=DEFAULT_CHANNELS, ar=DEFAULT_SAMPLE_RATE).compile()
+ self.audio_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
+ except Exception as e:
+ print(f"Failed to start audio process: {e}")
+ self.audio_process = None
+
+ if self.video_process:
+ self.video_reader_thread = threading.Thread(target=self._video_reader_thread, args=(self.video_process, w, h, fps), daemon=True)
+ self.video_reader_thread.start()
+ if self.audio_process:
+ self.audio_reader_thread = threading.Thread(target=self._audio_reader_thread, args=(self.audio_process,), daemon=True)
+ self.audio_reader_thread.start()
+ self.audio_stream = sd.OutputStream(samplerate=DEFAULT_SAMPLE_RATE, channels=DEFAULT_CHANNELS, callback=self._audio_callback, blocksize=AUDIO_CHUNK_SAMPLES)
+ self.audio_stream.start()
+
+ self.stream_start_time_monotonic = time.monotonic()
+ self.is_playing = True
+ self.update_timer.start()
+ self.started.emit()
+
+ def pause(self):
+ if not self.is_playing or self.is_paused:
+ return
+ if self.debug: print("Playback paused.")
+ self.is_paused = True
+ if self.audio_stream:
+ self.audio_stream.stop()
+ self.pause_time_ms = self._get_current_time_ms()
+ self.update_timer.stop()
+ self.paused.emit()
+
+ def resume(self):
+ if not self.is_playing or not self.is_paused:
+ return
+ if self.debug: print("Playback resumed.")
+ self.is_paused = False
+
+ time_since_start_sec = (self.pause_time_ms - self.playback_start_time_ms) / 1000.0
+ self.stream_start_time_monotonic = time.monotonic() - time_since_start_sec
+
+ if self.audio_stream:
+ self.audio_stream.start()
+
+ self.update_timer.start()
+ self.started.emit()
+
+ def stop(self):
+ was_playing = self.is_playing
+ if was_playing and self.debug: print("Playback stopped.")
+ self._cleanup_resources()
+ self.is_playing = False
+ self.is_paused = False
+ if was_playing:
+ self.stopped.emit()
+
+ def set_volume(self, value):
+ self.volume = max(0.0, min(1.0, value))
+
+ def set_muted(self, muted):
+ self.is_muted = bool(muted)
+
+ def _get_current_time_ms(self):
+ with self.sync_lock:
+ audio_clk_sec = self.audio_clock_sec
+ audio_clk_update_time = self.audio_clock_update_time
+
+ if self.audio_stream and self.audio_stream.active and audio_clk_update_time > 0:
+ time_since_dac_start = max(0, time.monotonic() - audio_clk_update_time)
+ current_audio_time_sec = audio_clk_sec + time_since_dac_start
+ current_pos_ms = self.playback_start_time_ms + int(current_audio_time_sec * 1000.0)
+ else:
+ elapsed_sec = time.monotonic() - self.stream_start_time_monotonic
+ current_pos_ms = self.playback_start_time_ms + int(elapsed_sec * 1000)
+
+ return current_pos_ms
+
+ def _update_loop(self):
+ if not self.is_playing or self.is_paused:
+ return
+
+ current_pos_ms = self._get_current_time_ms()
+
+ clock_source = "SYSTEM"
+ with self.sync_lock:
+ if self.audio_stream and self.audio_stream.active and self.audio_clock_update_time > 0:
+ clock_source = "AUDIO"
+
+ if abs(current_pos_ms - self.last_emitted_pos) >= 20:
+ source_str = f"_update_loop (clock:{clock_source})"
+ self._emit_playhead_pos(current_pos_ms, source_str)
+
+ if self.video_queue:
+ while not self.video_queue.empty():
+ try:
+ # Peek at the next frame
+ if self.video_queue.queue[0] is None:
+ self.stop() # End of stream
+ break
+ _, frame_pts = self.video_queue.queue[0]
+
+ if frame_pts <= current_pos_ms:
+ frame_bytes, frame_pts = self.video_queue.get_nowait()
+ if frame_bytes is None: # Should be caught by peek, but for safety
+ self.stop()
+ break
+ self.last_video_pts_ms = frame_pts
+ _, _, proj_settings = self.get_timeline_data()
+ w, h = proj_settings['width'], proj_settings['height']
+ img = QImage(frame_bytes, w, h, QImage.Format.Format_RGB888)
+ self.new_frame.emit(QPixmap.fromImage(img))
+ else:
+ # Frame is in the future, wait for next loop
+ break
+ except Empty:
+ break
+ except IndexError:
+ break # Queue might be empty between check and access
+
+ # --- Update Stats ---
+ vq_size = self.video_queue.qsize() if self.video_queue else 0
+ vq_max = self.video_queue.maxsize if self.video_queue else 0
+ aq_size = self.audio_queue.qsize() if self.audio_queue else 0
+ aq_max = self.audio_queue.maxsize if self.audio_queue else 0
+
+ video_audio_sync_ms = int(self.last_video_pts_ms - current_pos_ms)
+
+ stats_str = (f"AQ: {aq_size}/{aq_max} | VQ: {vq_size}/{vq_max} | "
+ f"V-A Δ: {video_audio_sync_ms}ms | Clock: {clock_source} | "
+ f"Underruns: {self.audio_underruns}")
+ self.stats_updated.emit(stats_str)
+
+ total_duration = self.get_timeline_data()[0].get_total_duration()
+ if total_duration > 0 and current_pos_ms >= total_duration:
+ self.stop()
+ self._emit_playhead_pos(int(total_duration), "_update_loop.end_of_timeline")
\ No newline at end of file
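PlaybackManager takes a callable returning (timeline, clips, project_settings), decodes video and audio through separate ffmpeg subprocesses feeding bounded queues, and talks to the UI only through signals. A minimal wiring sketch; everything outside the class (preview_label, timeline_widget and the slot names) is an assumption:

```python
from playback import PlaybackManager

def get_timeline_data():
    # the real app returns its live timeline objects here
    return timeline, timeline.clips, {'width': 1280, 'height': 720, 'fps': 25.0}

player = PlaybackManager(get_timeline_data)
player.new_frame.connect(preview_label.setPixmap)                  # preview_label: a QLabel
player.playback_pos_changed.connect(timeline_widget.set_playhead)  # assumed slot name
player.stats_updated.connect(status_label.setText)

player.play(0)                 # start decoding from 0 ms
player.pause()
player.resume()
player.seek_to_frame(2_000)    # render a single frame at 2 s while stopped
```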
diff --git a/plugins.py b/plugins.py
index f40f6adb1..a4e0699dd 100644
--- a/plugins.py
+++ b/plugins.py
@@ -5,47 +5,59 @@
import shutil
import git
import json
+import importlib
from PyQt6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QListWidget, QListWidgetItem,
QPushButton, QLabel, QLineEdit, QMessageBox, QProgressBar,
QDialogButtonBox, QWidget, QCheckBox)
from PyQt6.QtCore import Qt, QObject, pyqtSignal, QThread
class VideoEditorPlugin:
- """
- Base class for all plugins.
- Plugins should inherit from this class and be located in 'plugins/plugin_name/main.py'.
- The main class in main.py must be named 'Plugin'.
- """
- def __init__(self, app_instance, wgp_module=None):
+ def __init__(self, app_instance):
self.app = app_instance
- self.wgp = wgp_module
self.name = "Unnamed Plugin"
self.description = "No description provided."
def initialize(self):
- """Called once when the plugin is loaded by the PluginManager."""
pass
def enable(self):
- """Called when the plugin is enabled by the user (e.g., checking the box in the menu)."""
pass
def disable(self):
- """Called when the plugin is disabled by the user."""
pass
class PluginManager:
- """Manages the discovery, loading, and lifecycle of plugins."""
- def __init__(self, main_app, wgp_module=None):
+ def __init__(self, main_app):
self.app = main_app
- self.wgp = wgp_module
self.plugins_dir = "plugins"
- self.plugins = {} # { 'plugin_name': {'instance': plugin_instance, 'enabled': False, 'module_path': path} }
+ self.plugins = {}
if not os.path.exists(self.plugins_dir):
os.makedirs(self.plugins_dir)
+ @staticmethod
+ def run_preloads(plugins_dir="plugins"):
+ print("Running plugin pre-load scan...")
+ if not os.path.isdir(plugins_dir):
+ return
+
+ for plugin_name in os.listdir(plugins_dir):
+ plugin_path = os.path.join(plugins_dir, plugin_name)
+ manifest_path = os.path.join(plugin_path, 'plugin.json')
+ if os.path.isfile(manifest_path):
+ try:
+ with open(manifest_path, 'r') as f:
+ manifest = json.load(f)
+
+ preload_modules = manifest.get("preload_modules", [])
+ if preload_modules:
+ for module_name in preload_modules:
+ try:
+ importlib.import_module(module_name)
+ except ImportError as e:
+ print(f" - WARNING: Could not pre-load '{module_name}'. Error: {e}")
+ except Exception as e:
+ print(f" - WARNING: Could not read or parse manifest for plugin '{plugin_name}': {e}")
+
def discover_and_load_plugins(self):
- """Scans the plugins directory, loads valid plugins, and calls their initialize method."""
for plugin_name in os.listdir(self.plugins_dir):
plugin_path = os.path.join(self.plugins_dir, plugin_name)
main_py_path = os.path.join(plugin_path, 'main.py')
@@ -60,7 +72,7 @@ def discover_and_load_plugins(self):
if hasattr(module, 'Plugin'):
plugin_class = getattr(module, 'Plugin')
- instance = plugin_class(self.app, self.wgp)
+ instance = plugin_class(self.app)
instance.initialize()
self.plugins[instance.name] = {
'instance': instance,
@@ -76,17 +88,14 @@ def discover_and_load_plugins(self):
traceback.print_exc()
def load_enabled_plugins_from_settings(self, enabled_plugins_list):
- """Enables plugins based on the loaded settings."""
for name in enabled_plugins_list:
if name in self.plugins:
self.enable_plugin(name)
def get_enabled_plugin_names(self):
- """Returns a list of names of all enabled plugins."""
return [name for name, data in self.plugins.items() if data['enabled']]
def enable_plugin(self, name):
- """Enables a specific plugin by name and updates the UI."""
if name in self.plugins and not self.plugins[name]['enabled']:
self.plugins[name]['instance'].enable()
self.plugins[name]['enabled'] = True
@@ -95,7 +104,6 @@ def enable_plugin(self, name):
self.app.update_plugin_ui_visibility(name, True)
def disable_plugin(self, name):
- """Disables a specific plugin by name and updates the UI."""
if name in self.plugins and self.plugins[name]['enabled']:
self.plugins[name]['instance'].disable()
self.plugins[name]['enabled'] = False
@@ -104,7 +112,6 @@ def disable_plugin(self, name):
self.app.update_plugin_ui_visibility(name, False)
def uninstall_plugin(self, name):
- """Uninstalls (deletes) a plugin by name."""
if name in self.plugins:
path = self.plugins[name]['module_path']
if self.plugins[name]['enabled']:
@@ -174,12 +181,10 @@ def __init__(self, plugin_manager, parent=None):
self.status_label = QLabel("Ready.")
layout.addWidget(self.status_label)
-
- # Dialog buttons
+
self.button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Save | QDialogButtonBox.StandardButton.Cancel)
layout.addWidget(self.button_box)
- # Connections
self.install_btn.clicked.connect(self.install_plugin)
self.button_box.accepted.connect(self.save_changes)
self.button_box.rejected.connect(self.reject)
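run_preloads() scans each plugin folder for a plugin.json manifest and imports any modules listed under "preload_modules" before the plugins themselves are constructed, which is how a heavyweight dependency such as wgp can be pulled in early. A sketch of how this might be used; the manifest contents shown are illustrative, since the actual plugins/wan2gp/plugin.json added by this patch is not included in this excerpt:

```python
# Hypothetical plugins/<name>/plugin.json contents:
# {
#     "preload_modules": ["some_heavy_module"]
# }
from plugins import PluginManager

# Import any declared preload modules before the main window and plugins exist.
PluginManager.run_preloads()
manager = PluginManager(main_app)   # main_app: the editor's main window instance
manager.discover_and_load_plugins()
```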
diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py
index 473baf2e4..2a38230dd 100644
--- a/plugins/wan2gp/main.py
+++ b/plugins/wan2gp/main.py
@@ -50,10 +50,11 @@ def __getattr__(self, name):
sys.modules['shared.gradio.gallery'] = MockGradioModule()
# --- End of Gradio Hijacking ---
+wgp = None
import ffmpeg
from PyQt6.QtWidgets import (
- QWidget, QVBoxLayout, QHBoxLayout, QTabWidget,
+ QApplication, QWidget, QVBoxLayout, QHBoxLayout, QTabWidget,
QPushButton, QLabel, QLineEdit, QTextEdit, QSlider, QCheckBox, QComboBox,
QFileDialog, QGroupBox, QFormLayout, QTableWidget, QTableWidgetItem,
QHeaderView, QProgressBar, QScrollArea, QListWidget, QListWidgetItem,
@@ -213,7 +214,6 @@ class Worker(QObject):
def __init__(self, plugin, state):
super().__init__()
self.plugin = plugin
- self.wgp = plugin.wgp
self.state = state
self._is_running = True
self._last_progress_phase = None
@@ -222,7 +222,7 @@ def __init__(self, plugin, state):
def run(self):
def generation_target():
try:
- for _ in self.wgp.process_tasks(self.state):
+ for _ in wgp.process_tasks(self.state):
if self._is_running: self.output.emit()
else: break
except Exception as e:
@@ -243,7 +243,7 @@ def generation_target():
phase_name, step = current_phase
total_steps = gen.get("num_inference_steps", 1)
high_level_status = gen.get("progress_status", "")
- status_msg = self.wgp.merge_status_context(high_level_status, phase_name)
+ status_msg = wgp.merge_status_context(high_level_status, phase_name)
progress_args = [(step, total_steps), status_msg]
self.progress.emit(progress_args)
preview_img = gen.get('preview')
@@ -260,7 +260,6 @@ class WgpDesktopPluginWidget(QWidget):
def __init__(self, plugin):
super().__init__()
self.plugin = plugin
- self.wgp = plugin.wgp
self.widgets = {}
self.state = {}
self.worker = None
@@ -739,7 +738,7 @@ def _setup_adv_tab_misc(self, tabs):
layout.addRow("Force FPS:", fps_combo)
profile_combo = self.create_widget(QComboBox, 'override_profile')
profile_combo.addItem("Default Profile", -1)
- for text, val in self.wgp.memory_profile_choices: profile_combo.addItem(text.split(':')[0], val)
+ for text, val in wgp.memory_profile_choices: profile_combo.addItem(text.split(':')[0], val)
layout.addRow("Override Memory Profile:", profile_combo)
combo = self.create_widget(QComboBox, 'multi_prompts_gen_type')
combo.addItem("Generate new Video per line", 0)
@@ -777,7 +776,7 @@ def _create_scrollable_form_tab(self):
def _create_config_combo(self, form_layout, label, key, choices, default_value):
combo = QComboBox()
for text, data in choices: combo.addItem(text, data)
- index = combo.findData(self.wgp.server_config.get(key, default_value))
+ index = combo.findData(wgp.server_config.get(key, default_value))
if index != -1: combo.setCurrentIndex(index)
self.widgets[f'config_{key}'] = combo
form_layout.addRow(label, combo)
@@ -789,7 +788,7 @@ def _create_config_slider(self, form_layout, label, key, min_val, max_val, defau
slider = QSlider(Qt.Orientation.Horizontal)
slider.setRange(min_val, max_val)
slider.setSingleStep(step)
- slider.setValue(self.wgp.server_config.get(key, default_value))
+ slider.setValue(wgp.server_config.get(key, default_value))
value_label = QLabel(str(slider.value()))
value_label.setMinimumWidth(40)
slider.valueChanged.connect(lambda v, lbl=value_label: lbl.setText(str(v)))
@@ -801,7 +800,7 @@ def _create_config_slider(self, form_layout, label, key, min_val, max_val, defau
def _create_config_checklist(self, form_layout, label, key, choices, default_value):
list_widget = QListWidget()
list_widget.setMinimumHeight(100)
- current_values = self.wgp.server_config.get(key, default_value)
+ current_values = wgp.server_config.get(key, default_value)
for text, data in choices:
item = QListWidgetItem(text)
item.setData(Qt.ItemDataRole.UserRole, data)
@@ -822,8 +821,8 @@ def _create_config_textbox(self, form_layout, label, key, default_value, multi_l
def _create_general_config_tab(self):
tab, form = self._create_scrollable_form_tab()
- _, _, dropdown_choices = self.wgp.get_sorted_dropdown(self.wgp.displayed_model_types, None, None, False)
- self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, self.wgp.transformer_types)
+ _, _, dropdown_choices = wgp.get_sorted_dropdown(wgp.displayed_model_types, None, None, False)
+ self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, wgp.transformer_types)
self._create_config_combo(form, "Model Hierarchy:", "model_hierarchy_type", [("Two Levels (Family > Model)", 0), ("Three Levels (Family > Base > Finetune)", 1)], 1)
self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1), ("Dimensions are Output Width/Height (Cropped)", 2)], 0)
self._create_config_combo(form, "Attention Type:", "attention_mode", [("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"), ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto")
@@ -832,7 +831,7 @@ def _create_general_config_tab(self):
self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10), ("Keep last 20", 20), ("Keep last 30", 30)], 5)
self._create_config_combo(form, "Display RAM/VRAM Stats:", "display_stats", [("Disabled", 0), ("Enabled", 1)], 0)
self._create_config_combo(form, "Max Frames Multiplier:", "max_frames_multiplier", [(f"x{i}", i) for i in range(1, 8)], 1)
- checkpoints_paths_text = "\n".join(self.wgp.server_config.get("checkpoints_paths", self.wgp.fl.default_checkpoints_paths))
+ checkpoints_paths_text = "\n".join(wgp.server_config.get("checkpoints_paths", wgp.fl.default_checkpoints_paths))
checkpoints_textbox = QTextEdit()
checkpoints_textbox.setPlainText(checkpoints_paths_text)
checkpoints_textbox.setAcceptRichText(False)
@@ -853,7 +852,7 @@ def _create_performance_config_tab(self):
self._create_config_combo(form, "DepthAnything v2 Variant:", "depth_anything_v2_variant", [("Large (more precise)", "vitl"), ("Big (faster)", "vitb")], "vitl")
self._create_config_combo(form, "VAE Tiling:", "vae_config", [("Auto", 0), ("Disabled", 1), ("256x256 (~8GB VRAM)", 2), ("128x128 (~6GB VRAM)", 3)], 0)
self._create_config_combo(form, "Boost:", "boost", [("On", 1), ("Off", 2)], 1)
- self._create_config_combo(form, "Memory Profile:", "profile", self.wgp.memory_profile_choices, self.wgp.profile_type.LowRAM_LowVRAM)
+ self._create_config_combo(form, "Memory Profile:", "profile", wgp.memory_profile_choices, wgp.profile_type.LowRAM_LowVRAM)
self._create_config_slider(form, "Preload in VRAM (MB):", "preload_in_VRAM", 0, 40000, 0, 100)
release_ram_btn = QPushButton("Force Release Models from RAM")
release_ram_btn.clicked.connect(self._on_release_ram)
@@ -882,7 +881,6 @@ def _create_notifications_config_tab(self):
return tab
def init_wgp_state(self):
- wgp = self.wgp
initial_model = wgp.server_config.get("last_model_type", wgp.transformer_type)
dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types
_, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False)
@@ -903,14 +901,16 @@ def init_wgp_state(self):
self._update_input_visibility()
def update_model_dropdowns(self, current_model_type):
- wgp = self.wgp
family_mock, base_type_mock, choice_mock = wgp.generate_dropdown_model_list(current_model_type)
for combo_name, mock in [('model_family', family_mock), ('model_base_type_choice', base_type_mock), ('model_choice', choice_mock)]:
combo = self.widgets[combo_name]
combo.blockSignals(True)
combo.clear()
if mock.choices:
- for display_name, internal_key in mock.choices: combo.addItem(display_name, internal_key)
+ for item in mock.choices:
+ if isinstance(item, (list, tuple)) and len(item) >= 2:
+ display_name, internal_key = item[0], item[1]
+ combo.addItem(display_name, internal_key)
index = combo.findData(mock.value)
if index != -1: combo.setCurrentIndex(index)
@@ -925,7 +925,6 @@ def update_model_dropdowns(self, current_model_type):
def refresh_ui_from_model_change(self, model_type):
"""Update UI controls with default settings when the model is changed."""
- wgp = self.wgp
self.header_info.setText(wgp.generate_header(model_type, wgp.compile, wgp.attention_mode))
ui_defaults = wgp.get_default_settings(model_type)
wgp.set_model_settings(self.state, model_type, ui_defaults)
@@ -1286,7 +1285,7 @@ def _on_preview_toggled(self, checked):
def _on_family_changed(self):
family = self.widgets['model_family'].currentData()
if not family or not self.state: return
- base_type_mock, choice_mock = self.wgp.change_model_family(self.state, family)
+ base_type_mock, choice_mock = wgp.change_model_family(self.state, family)
if hasattr(base_type_mock, 'kwargs') and isinstance(base_type_mock.kwargs, dict):
is_visible_base = base_type_mock.kwargs.get('visible', True)
@@ -1324,7 +1323,7 @@ def _on_base_type_changed(self):
family = self.widgets['model_family'].currentData()
base_type = self.widgets['model_base_type_choice'].currentData()
if not family or not base_type or not self.state: return
- base_type_mock, choice_mock = self.wgp.change_model_base_types(self.state, family, base_type)
+ base_type_mock, choice_mock = wgp.change_model_base_types(self.state, family, base_type)
if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict):
is_visible_choice = choice_mock.kwargs.get('visible', True)
@@ -1345,18 +1344,18 @@ def _on_base_type_changed(self):
def _on_model_changed(self):
model_type = self.widgets['model_choice'].currentData()
if not model_type or model_type == self.state.get('model_type'): return
- self.wgp.change_model(self.state, model_type)
+ wgp.change_model(self.state, model_type)
self.refresh_ui_from_model_change(model_type)
def _on_resolution_group_changed(self):
selected_group = self.widgets['resolution_group'].currentText()
if not selected_group or not hasattr(self, 'full_resolution_choices'): return
model_type = self.state['model_type']
- model_def = self.wgp.get_model_def(model_type)
+ model_def = wgp.get_model_def(model_type)
model_resolutions = model_def.get("resolutions", None)
group_resolution_choices = []
if model_resolutions is None:
- group_resolution_choices = [res for res in self.full_resolution_choices if self.wgp.categorize_resolution(res[1]) == selected_group]
+ group_resolution_choices = [res for res in self.full_resolution_choices if wgp.categorize_resolution(res[1]) == selected_group]
else: return
last_resolution = self.state.get("last_resolution_per_group", {}).get(selected_group, "")
if not any(last_resolution == res[1] for res in group_resolution_choices) and group_resolution_choices:
@@ -1398,7 +1397,7 @@ def set_resolution_from_target(self, target_w, target_h):
best_res_value = res_value
if best_res_value:
- best_group = self.wgp.categorize_resolution(best_res_value)
+ best_group = wgp.categorize_resolution(best_res_value)
group_combo = self.widgets['resolution_group']
group_combo.blockSignals(True)
@@ -1417,7 +1416,7 @@ def set_resolution_from_target(self, target_w, target_h):
print(f"Warning: Could not find resolution '{best_res_value}' in dropdown after group change.")
def collect_inputs(self):
- full_inputs = self.wgp.get_current_model_settings(self.state).copy()
+ full_inputs = wgp.get_current_model_settings(self.state).copy()
full_inputs['lset_name'] = ""
full_inputs['image_mode'] = 0
expected_keys = { "audio_guide": None, "audio_guide2": None, "image_guide": None, "image_mask": None, "speakers_locations": "", "frames_positions": "", "keep_frames_video_guide": "", "keep_frames_video_source": "", "video_guide_outpainting": "", "switch_threshold2": 0, "model_switch_phase": 1, "batch_size": 1, "control_net_weight_alt": 1.0, "image_refs_relative_size": 50, }
@@ -1512,9 +1511,9 @@ def _add_task_to_queue(self):
all_inputs = self.collect_inputs()
for key in ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality', 'base_model_type']: all_inputs.pop(key, None)
all_inputs['state'] = self.state
- self.wgp.set_model_settings(self.state, self.state['model_type'], all_inputs)
+ wgp.set_model_settings(self.state, self.state['model_type'], all_inputs)
self.state["validate_success"] = 1
- self.wgp.process_prompt_and_add_tasks(self.state, self.state['model_type'])
+ wgp.process_prompt_and_add_tasks(self.state, self.state['model_type'])
self.update_queue_table()
def start_generation(self):
@@ -1582,11 +1581,11 @@ def add_result_item(self, video_path):
self.results_list.setItemWidget(list_item, item_widget)
def update_queue_table(self):
- with self.wgp.lock:
+ with wgp.lock:
queue = self.state.get('gen', {}).get('queue', [])
is_running = self.thread and self.thread.isRunning()
queue_to_display = queue if is_running else [None] + queue
- table_data = self.wgp.get_queue_table(queue_to_display)
+ table_data = wgp.get_queue_table(queue_to_display)
self.queue_table.setRowCount(0)
self.queue_table.setRowCount(len(table_data))
self.queue_table.setColumnCount(4)
@@ -1601,7 +1600,7 @@ def update_queue_table(self):
def _on_remove_selected_from_queue(self):
selected_row = self.queue_table.currentRow()
if selected_row < 0: return
- with self.wgp.lock:
+ with wgp.lock:
is_running = self.thread and self.thread.isRunning()
offset = 1 if is_running else 0
queue = self.state.get('gen', {}).get('queue', [])
@@ -1609,7 +1608,7 @@ def _on_remove_selected_from_queue(self):
self.update_queue_table()
def _on_queue_rows_moved(self, source_row, dest_row):
- with self.wgp.lock:
+ with wgp.lock:
queue = self.state.get('gen', {}).get('queue', [])
is_running = self.thread and self.thread.isRunning()
offset = 1 if is_running else 0
@@ -1620,17 +1619,17 @@ def _on_queue_rows_moved(self, source_row, dest_row):
self.update_queue_table()
def _on_clear_queue(self):
- self.wgp.clear_queue_action(self.state)
+ wgp.clear_queue_action(self.state)
self.update_queue_table()
def _on_abort(self):
if self.worker:
- self.wgp.abort_generation(self.state)
+ wgp.abort_generation(self.state)
self.status_label.setText("Aborting...")
self.worker._is_running = False
def _on_release_ram(self):
- self.wgp.release_RAM()
+ wgp.release_RAM()
QMessageBox.information(self, "RAM Released", "Models stored in RAM have been released.")
def _on_apply_config_changes(self):
@@ -1655,12 +1654,12 @@ def _on_apply_config_changes(self):
changes['notification_sound_volume_choice'] = self.widgets['config_notification_sound_volume'].value()
changes['last_resolution_choice'] = self.widgets['resolution'].currentData()
try:
- msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = self.wgp.apply_changes(self.state, **changes)
+ msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = wgp.apply_changes(self.state, **changes)
self.config_status_label.setText("Changes applied successfully. Some settings may require a restart.")
- self.header_info.setText(self.wgp.generate_header(self.state['model_type'], self.wgp.compile, self.wgp.attention_mode))
+ self.header_info.setText(wgp.generate_header(self.state['model_type'], wgp.compile, wgp.attention_mode))
if family_mock.choices is not None or choice_mock.choices is not None:
- self.update_model_dropdowns(self.wgp.transformer_type)
- self.refresh_ui_from_model_change(self.wgp.transformer_type)
+ self.update_model_dropdowns(wgp.transformer_type)
+ self.refresh_ui_from_model_change(wgp.transformer_type)
except Exception as e:
self.config_status_label.setText(f"Error applying changes: {e}")
import traceback; traceback.print_exc()
@@ -1669,23 +1668,82 @@ class Plugin(VideoEditorPlugin):
def initialize(self):
self.name = "AI Generator"
self.description = "Uses the integrated Wan2GP library to generate video clips."
- self.client_widget = WgpDesktopPluginWidget(self)
+ self.client_widget = None
self.dock_widget = None
- self.active_region = None; self.temp_dir = None
- self.insert_on_new_track = False; self.start_frame_path = None; self.end_frame_path = None
+ self._heavy_content_loaded = False
+
+ self.active_region = None
+ self.temp_dir = None
+ self.insert_on_new_track = False
+ self.start_frame_path = None
+ self.end_frame_path = None
def enable(self):
- if not self.dock_widget: self.dock_widget = self.app.add_dock_widget(self, self.client_widget, self.name)
+ if not self.dock_widget:
+ placeholder = QLabel("Loading AI Generator...")
+ placeholder.setAlignment(Qt.AlignmentFlag.AlignCenter)
+ self.dock_widget = self.app.add_dock_widget(self, placeholder, self.name)
+ self.dock_widget.visibilityChanged.connect(self._on_visibility_changed)
+
self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu)
self.app.status_label.setText(f"{self.name}: Enabled.")
+ def _on_visibility_changed(self, visible):
+ if visible and not self._heavy_content_loaded:
+ self._load_heavy_ui()
+
+ def _load_heavy_ui(self):
+ if self._heavy_content_loaded:
+ return True
+
+ self.app.status_label.setText("Loading AI Generator...")
+ QApplication.setOverrideCursor(Qt.CursorShape.WaitCursor)
+ QApplication.processEvents()
+ try:
+ global wgp
+ if wgp is None:
+ import wgp as wgp_module
+ wgp = wgp_module
+
+ self.client_widget = WgpDesktopPluginWidget(self)
+ self.dock_widget.setWidget(self.client_widget)
+ self._heavy_content_loaded = True
+ self.app.status_label.setText("AI Generator loaded.")
+ return True
+ except Exception as e:
+ print(f"Failed to load AI Generator plugin backend: {e}")
+ import traceback
+ traceback.print_exc()
+ error_label = QLabel(f"Failed to load AI Generator:\n\n{e}\n\nPlease see console for details.")
+ error_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
+ error_label.setWordWrap(True)
+ if self.dock_widget:
+ self.dock_widget.setWidget(error_label)
+ QMessageBox.critical(self.app, "Plugin Load Error", f"Failed to load the AI Generator backend.\n\n{e}")
+ return False
+ finally:
+ QApplication.restoreOverrideCursor()
+
def disable(self):
try: self.app.timeline_widget.context_menu_requested.disconnect(self.on_timeline_context_menu)
except TypeError: pass
+
+ if self.dock_widget:
+ try: self.dock_widget.visibilityChanged.disconnect(self._on_visibility_changed)
+ except TypeError: pass
+
self._cleanup_temp_dir()
- if self.client_widget.worker: self.client_widget._on_abort()
+ if self.client_widget and self.client_widget.worker:
+ self.client_widget._on_abort()
+
self.app.status_label.setText(f"{self.name}: Disabled.")
+ def _ensure_ui_loaded(self):
+ if not self._heavy_content_loaded:
+ if not self._load_heavy_ui():
+ return False
+ return True
+
def _cleanup_temp_dir(self):
if self.temp_dir and os.path.exists(self.temp_dir):
shutil.rmtree(self.temp_dir)
@@ -1694,11 +1752,12 @@ def _cleanup_temp_dir(self):
def _reset_state(self):
self.active_region = None; self.insert_on_new_track = False
self.start_frame_path = None; self.end_frame_path = None
- self.client_widget.processed_files.clear()
- self.client_widget.results_list.clear()
- self.client_widget.widgets['image_start'].clear()
- self.client_widget.widgets['image_end'].clear()
- self.client_widget.widgets['video_source'].clear()
+ if self._heavy_content_loaded:
+ self.client_widget.processed_files.clear()
+ self.client_widget.results_list.clear()
+ self.client_widget.widgets['image_start'].clear()
+ self.client_widget.widgets['image_end'].clear()
+ self.client_widget.widgets['video_source'].clear()
self._cleanup_temp_dir()
self.app.status_label.setText(f"{self.name}: Ready.")
@@ -1732,13 +1791,14 @@ def on_timeline_context_menu(self, menu, event):
create_action_new_track.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=True))
def setup_generator_for_region(self, region, on_new_track=False):
+ if not self._ensure_ui_loaded(): return
self._reset_state()
self.active_region = region
self.insert_on_new_track = on_new_track
model_to_set = 'i2v_2_2'
- dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types
- _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False)
+ dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types
+ _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False)
if any(model_to_set == m[1] for m in all_models):
if self.client_widget.state.get('model_type') != model_to_set:
self.client_widget.update_model_dropdowns(model_to_set)
@@ -1763,7 +1823,6 @@ def setup_generator_for_region(self, region, on_new_track=False):
QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path)
duration_ms = end_ms - start_ms
- wgp = self.client_widget.wgp
model_type = self.client_widget.state['model_type']
fps = wgp.get_model_fps(model_type)
video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16)
@@ -1792,13 +1851,14 @@ def setup_generator_for_region(self, region, on_new_track=False):
self.dock_widget.raise_()
def setup_generator_from_start(self, region, on_new_track=False):
+ if not self._ensure_ui_loaded(): return
self._reset_state()
self.active_region = region
self.insert_on_new_track = on_new_track
model_to_set = 'i2v_2_2'
- dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types
- _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False)
+ dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types
+ _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False)
if any(model_to_set == m[1] for m in all_models):
if self.client_widget.state.get('model_type') != model_to_set:
self.client_widget.update_model_dropdowns(model_to_set)
@@ -1820,7 +1880,6 @@ def setup_generator_from_start(self, region, on_new_track=False):
QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path)
duration_ms = end_ms - start_ms
- wgp = self.client_widget.wgp
model_type = self.client_widget.state['model_type']
fps = wgp.get_model_fps(model_type)
video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16)
@@ -1851,13 +1910,14 @@ def setup_generator_from_start(self, region, on_new_track=False):
self.dock_widget.raise_()
def setup_generator_to_end(self, region, on_new_track=False):
+ if not self._ensure_ui_loaded(): return
self._reset_state()
self.active_region = region
self.insert_on_new_track = on_new_track
model_to_set = 'i2v_2_2'
- dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types
- _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False)
+ dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types
+ _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False)
if any(model_to_set == m[1] for m in all_models):
if self.client_widget.state.get('model_type') != model_to_set:
self.client_widget.update_model_dropdowns(model_to_set)
@@ -1879,7 +1939,6 @@ def setup_generator_to_end(self, region, on_new_track=False):
QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path)
duration_ms = end_ms - start_ms
- wgp = self.client_widget.wgp
model_type = self.client_widget.state['model_type']
fps = wgp.get_model_fps(model_type)
video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16)
@@ -1887,7 +1946,7 @@ def setup_generator_to_end(self, region, on_new_track=False):
widgets['video_length'].setValue(video_length_frames)
- model_def = self.client_widget.wgp.get_model_def(self.client_widget.state['model_type'])
+ model_def = wgp.get_model_def(self.client_widget.state['model_type'])
allowed_modes = model_def.get("image_prompt_types_allowed", "")
if "E" not in allowed_modes:
@@ -1921,13 +1980,14 @@ def setup_generator_to_end(self, region, on_new_track=False):
self.dock_widget.raise_()
def setup_creator_for_region(self, region, on_new_track=False):
+ if not self._ensure_ui_loaded(): return
self._reset_state()
self.active_region = region
self.insert_on_new_track = on_new_track
model_to_set = 't2v_2_2'
- dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types
- _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False)
+ dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types
+ _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False)
if any(model_to_set == m[1] for m in all_models):
if self.client_widget.state.get('model_type') != model_to_set:
self.client_widget.update_model_dropdowns(model_to_set)
@@ -1941,7 +2001,6 @@ def setup_creator_for_region(self, region, on_new_track=False):
start_ms, end_ms = region
duration_ms = end_ms - start_ms
- wgp = self.client_widget.wgp
model_type = self.client_widget.state['model_type']
fps = wgp.get_model_fps(model_type)
video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16)
diff --git a/plugins/wan2gp/plugin.json b/plugins/wan2gp/plugin.json
new file mode 100644
index 000000000..47e592c50
--- /dev/null
+++ b/plugins/wan2gp/plugin.json
@@ -0,0 +1,5 @@
+{
+ "preload_modules": [
+ "onnxruntime"
+ ]
+}
\ No newline at end of file
diff --git a/videoeditor.py b/videoeditor.py
index d6e744ebe..21df9e63d 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -1,4 +1,4 @@
-import wgp
+import onnxruntime
import sys
import os
import uuid
@@ -7,19 +7,21 @@
import json
import ffmpeg
import copy
+from plugins import PluginManager, ManagePluginsDialog
from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QHBoxLayout, QPushButton, QFileDialog, QLabel,
QScrollArea, QFrame, QProgressBar, QDialog,
QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget,
QListWidget, QListWidgetItem, QMessageBox, QComboBox,
- QFormLayout, QGroupBox, QLineEdit)
+ QFormLayout, QGroupBox, QLineEdit, QSlider)
from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction,
QPixmap, QImage, QDrag, QCursor, QKeyEvent)
from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread,
pyqtSignal, QTimer, QByteArray, QMimeData)
-from plugins import PluginManager, ManagePluginsDialog
from undo import UndoStack, TimelineStateChangeCommand, MoveClipsCommand
+from playback import PlaybackManager
+from encoding import Encoder
CONTAINER_PRESETS = {
'mp4': {
@@ -210,36 +212,6 @@ def get_total_duration(self):
if not self.clips: return 0
return max(c.timeline_end_ms for c in self.clips)
-class ExportWorker(QObject):
- progress = pyqtSignal(int)
- finished = pyqtSignal(str)
- def __init__(self, ffmpeg_cmd, total_duration_ms):
- super().__init__()
- self.ffmpeg_cmd = ffmpeg_cmd
- self.total_duration_ms = total_duration_ms
-
- def run_export(self):
- try:
- process = subprocess.Popen(self.ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True, encoding="utf-8")
- time_pattern = re.compile(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})")
- for line in iter(process.stdout.readline, ""):
- match = time_pattern.search(line)
- if match:
- h, m, s, cs = [int(g) for g in match.groups()]
- processed_ms = (h * 3600 + m * 60 + s) * 1000 + cs * 10
- if self.total_duration_ms > 0:
- percentage = int((processed_ms / self.total_duration_ms) * 100)
- self.progress.emit(percentage)
- process.stdout.close()
- return_code = process.wait()
- if return_code == 0:
- self.progress.emit(100)
- self.finished.emit("Export completed successfully!")
- else:
- self.finished.emit(f"Export failed! Check console for FFmpeg errors.")
- except Exception as e:
- self.finished.emit(f"An exception occurred during export: {e}")
-
class TimelineWidget(QWidget):
TIMESCALE_HEIGHT = 30
HEADER_WIDTH = 120
@@ -333,6 +305,10 @@ def set_project_fps(self, fps):
def ms_to_x(self, ms): return self.HEADER_WIDTH + int(max(-500_000_000, min((ms - self.view_start_ms) * self.pixels_per_ms, 500_000_000)))
def x_to_ms(self, x): return self.view_start_ms + int(float(x - self.HEADER_WIDTH) / self.pixels_per_ms) if x > self.HEADER_WIDTH and self.pixels_per_ms > 0 else self.view_start_ms
+ def set_playhead_pos(self, time_ms):
+ self.playhead_pos_ms = time_ms
+ self.update()
+
def paintEvent(self, event):
painter = QPainter(self)
painter.setRenderHint(QPainter.RenderHint.Antialiasing)
@@ -1848,8 +1824,6 @@ def __init__(self, project_to_load=None):
self.undo_stack = UndoStack()
self.media_pool = []
self.media_properties = {}
- self.export_thread = None
- self.export_worker = None
self.current_project_path = None
self.last_export_path = None
self.settings = {}
@@ -1857,7 +1831,10 @@ def __init__(self, project_to_load=None):
self.is_shutting_down = False
self._load_settings()
- self.plugin_manager = PluginManager(self, wgp)
+ self.playback_manager = PlaybackManager(self._get_playback_data)
+ self.encoder = Encoder()
+
+ self.plugin_manager = PluginManager(self)
self.plugin_manager.discover_and_load_plugins()
# Project settings with defaults
@@ -1869,16 +1846,12 @@ def __init__(self, project_to_load=None):
self.scale_to_fit = True
self.current_preview_pixmap = None
- self.playback_timer = QTimer(self)
- self.playback_process = None
- self.playback_clip = None
-
self._setup_ui()
self._connect_signals()
self.plugin_manager.load_enabled_plugins_from_settings(self.settings.get("enabled_plugins", []))
self._apply_loaded_settings()
- self.seek_preview(0)
+ self.playback_manager.seek_to_frame(0)
if not self.settings_file_was_loaded: self._save_settings()
if project_to_load: QTimer.singleShot(100, lambda: self._load_project_from_path(project_to_load))
@@ -1893,6 +1866,17 @@ def _get_current_timeline_state(self):
self.timeline.num_video_tracks,
self.timeline.num_audio_tracks
)
+
+ def _get_playback_data(self):
+ return (
+ self.timeline,
+ self.timeline.clips,
+ {
+ 'width': self.project_width,
+ 'height': self.project_height,
+ 'fps': self.project_fps
+ }
+ )
def _setup_ui(self):
self.media_dock = QDockWidget("Project Media", self)
@@ -1946,15 +1930,33 @@ def _setup_ui(self):
controls_layout.addStretch()
main_layout.addWidget(controls_widget)
- status_layout = QHBoxLayout()
+ status_bar_widget = QWidget()
+ status_layout = QHBoxLayout(status_bar_widget)
+ status_layout.setContentsMargins(5,2,5,2)
+
self.status_label = QLabel("Ready. Create or open a project from the File menu.")
+ self.stats_label = QLabel("AQ: 0/0 | VQ: 0/0")
+ self.stats_label.setMinimumWidth(120)
self.progress_bar = QProgressBar()
self.progress_bar.setVisible(False)
self.progress_bar.setTextVisible(True)
self.progress_bar.setRange(0, 100)
+
+ self.mute_button = QPushButton("Mute")
+ self.mute_button.setCheckable(True)
+ self.volume_slider = QSlider(Qt.Orientation.Horizontal)
+ self.volume_slider.setRange(0, 100)
+ self.volume_slider.setValue(100)
+ self.volume_slider.setMaximumWidth(100)
+
status_layout.addWidget(self.status_label, 1)
+ status_layout.addWidget(self.stats_label)
status_layout.addWidget(self.progress_bar, 1)
- main_layout.addLayout(status_layout)
+ status_layout.addStretch()
+ status_layout.addWidget(self.mute_button)
+ status_layout.addWidget(self.volume_slider)
+
+ main_layout.addWidget(status_bar_widget)
self.setCentralWidget(container_widget)
@@ -1978,7 +1980,7 @@ def _connect_signals(self):
self.timeline_widget.split_requested.connect(self.split_clip_at_playhead)
self.timeline_widget.delete_clip_requested.connect(self.delete_clip)
self.timeline_widget.delete_clips_requested.connect(self.delete_clips)
- self.timeline_widget.playhead_moved.connect(self.seek_preview)
+ self.timeline_widget.playhead_moved.connect(self.playback_manager.seek_to_frame)
self.timeline_widget.split_region_requested.connect(self.on_split_region)
self.timeline_widget.split_all_regions_requested.connect(self.on_split_all_regions)
self.timeline_widget.join_region_requested.connect(self.on_join_region)
@@ -1995,7 +1997,6 @@ def _connect_signals(self):
self.frame_forward_button.clicked.connect(lambda: self.step_frame(1))
self.snap_back_button.clicked.connect(lambda: self.snap_playhead(-1))
self.snap_forward_button.clicked.connect(lambda: self.snap_playhead(1))
- self.playback_timer.timeout.connect(self.advance_playback_frame)
self.project_media_widget.add_media_requested.connect(self.add_media_files)
self.project_media_widget.media_removed.connect(self.on_media_removed_from_pool)
@@ -2004,6 +2005,19 @@ def _connect_signals(self):
self.undo_stack.history_changed.connect(self.update_undo_redo_actions)
self.undo_stack.timeline_changed.connect(self.on_timeline_changed_by_undo)
+ self.playback_manager.new_frame.connect(self._on_new_frame)
+ self.playback_manager.playback_pos_changed.connect(self._on_playback_pos_changed)
+ self.playback_manager.stopped.connect(self._on_playback_stopped)
+ self.playback_manager.started.connect(self._on_playback_started)
+ self.playback_manager.paused.connect(self._on_playback_paused)
+ self.playback_manager.stats_updated.connect(self.stats_label.setText)
+
+ self.encoder.progress.connect(self.progress_bar.setValue)
+ self.encoder.finished.connect(self.on_export_finished)
+
+ self.mute_button.toggled.connect(self._on_mute_toggled)
+ self.volume_slider.valueChanged.connect(self._on_volume_changed)
+
def _show_preview_context_menu(self, pos):
menu = QMenu(self)
scale_action = QAction("Scale to Fit", self, checkable=True)
@@ -2019,7 +2033,7 @@ def _toggle_scale_to_fit(self, checked):
def on_timeline_changed_by_undo(self):
self.prune_empty_tracks()
self.timeline_widget.update()
- self.seek_preview(self.timeline_widget.playhead_pos_ms)
+ self.playback_manager.seek_to_frame(self.timeline_widget.playhead_pos_ms)
self.status_label.setText("Operation undone/redone.")
def update_undo_redo_actions(self):
@@ -2191,41 +2205,6 @@ def _create_menu_bar(self):
data['action'] = action
self.windows_menu.addAction(action)
- def _get_topmost_video_clip_at(self, time_ms):
- """Finds the video clip on the highest track at a specific time."""
- top_clip = None
- for c in self.timeline.clips:
- if c.track_type == 'video' and c.timeline_start_ms <= time_ms < c.timeline_end_ms:
- if top_clip is None or c.track_index > top_clip.track_index:
- top_clip = c
- return top_clip
-
- def _start_playback_stream_at(self, time_ms):
- self._stop_playback_stream()
- clip = self._get_topmost_video_clip_at(time_ms)
- if not clip: return
-
- self.playback_clip = clip
- clip_time_sec = (time_ms - clip.timeline_start_ms + clip.clip_start_ms) / 1000.0
- w, h = self.project_width, self.project_height
- try:
- args = (ffmpeg.input(self.playback_clip.source_path, ss=f"{clip_time_sec:.6f}")
- .filter('scale', w, h, force_original_aspect_ratio='decrease')
- .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
- .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile())
- self.playback_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL)
- except Exception as e:
- print(f"Failed to start playback stream: {e}"); self._stop_playback_stream()
-
- def _stop_playback_stream(self):
- if self.playback_process:
- if self.playback_process.poll() is None:
- self.playback_process.terminate()
- try: self.playback_process.wait(timeout=0.5)
- except subprocess.TimeoutExpired: self.playback_process.kill(); self.playback_process.wait()
- self.playback_process = None
- self.playback_clip = None
-
def _get_media_properties(self, file_path):
"""Probes a file to get its media properties. Returns a dict or None."""
if file_path in self.media_properties:
@@ -2307,11 +2286,16 @@ def _probe_for_drag(self, file_path):
return self._get_media_properties(file_path)
def get_frame_data_at_time(self, time_ms):
- clip_at_time = self._get_topmost_video_clip_at(time_ms)
+ """Blocking frame grab for plugin compatibility."""
+ _, clips, proj_settings = self._get_playback_data()
+ w, h = proj_settings['width'], proj_settings['height']
+
+ clip_at_time = next((c for c in sorted(clips, key=lambda x: x.track_index, reverse=True)
+ if c.track_type == 'video' and c.timeline_start_ms <= time_ms < c.timeline_end_ms), None)
+
if not clip_at_time:
return (None, 0, 0)
try:
- w, h = self.project_width, self.project_height
if clip_at_time.media_type == 'image':
out, _ = (
ffmpeg
@@ -2333,41 +2317,16 @@ def get_frame_data_at_time(self, time_ms):
)
return (out, self.project_width, self.project_height)
except ffmpeg.Error as e:
- print(f"Error extracting frame data: {e.stderr}")
+ print(f"Error extracting frame data for plugin: {e.stderr}")
return (None, 0, 0)
- def get_frame_at_time(self, time_ms):
- clip_at_time = self._get_topmost_video_clip_at(time_ms)
- if not clip_at_time: return None
- try:
- w, h = self.project_width, self.project_height
- if clip_at_time.media_type == 'image':
- out, _ = (
- ffmpeg.input(clip_at_time.source_path)
- .filter('scale', w, h, force_original_aspect_ratio='decrease')
- .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
- .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
- .run(capture_stdout=True, quiet=True)
- )
- else:
- clip_time_sec = (time_ms - clip_at_time.timeline_start_ms + clip_at_time.clip_start_ms) / 1000.0
- out, _ = (ffmpeg.input(clip_at_time.source_path, ss=f"{clip_time_sec:.6f}")
- .filter('scale', w, h, force_original_aspect_ratio='decrease')
- .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')
- .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24')
- .run(capture_stdout=True, quiet=True))
-
- image = QImage(out, self.project_width, self.project_height, QImage.Format.Format_RGB888)
- return QPixmap.fromImage(image)
- except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return None
-
- def seek_preview(self, time_ms):
- self._stop_playback_stream()
- self.timeline_widget.playhead_pos_ms = int(time_ms)
- self.timeline_widget.update()
- self.current_preview_pixmap = self.get_frame_at_time(time_ms)
+ def _on_new_frame(self, pixmap):
+ self.current_preview_pixmap = pixmap
self._update_preview_display()
+ def _on_playback_pos_changed(self, time_ms):
+ self.timeline_widget.set_playhead_pos(time_ms)
+
def _update_preview_display(self):
pixmap_to_show = self.current_preview_pixmap
if not pixmap_to_show:
@@ -2388,22 +2347,42 @@ def _update_preview_display(self):
self.preview_scroll_area.setWidgetResizable(False)
self.preview_widget.setPixmap(pixmap_to_show)
self.preview_widget.adjustSize()
-
def toggle_playback(self):
- if self.playback_timer.isActive(): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play")
+ if not self.timeline.clips:
+ return
+
+ if self.playback_manager.is_playing:
+ if self.playback_manager.is_paused:
+ self.playback_manager.resume()
+ else:
+ self.playback_manager.pause()
else:
- if not self.timeline.clips: return
- if self.timeline_widget.playhead_pos_ms >= self.timeline.get_total_duration(): self.timeline_widget.playhead_pos_ms = 0
- self.playback_timer.start(int(1000 / self.project_fps)); self.play_pause_button.setText("Pause")
+ current_pos = self.timeline_widget.playhead_pos_ms
+ if current_pos >= self.timeline.get_total_duration():
+ current_pos = 0
+ self.playback_manager.play(current_pos)
+
+ def _on_playback_started(self):
+ self.play_pause_button.setText("Pause")
+
+ def _on_playback_paused(self):
+ self.play_pause_button.setText("Play")
+
+ def _on_playback_stopped(self):
+ self.play_pause_button.setText("Play")
+
+ def stop_playback(self):
+ self.playback_manager.stop()
+ self.playback_manager.seek_to_frame(0)
- def stop_playback(self): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play"); self.seek_preview(0)
def step_frame(self, direction):
if not self.timeline.clips: return
- self.playback_timer.stop(); self.play_pause_button.setText("Play"); self._stop_playback_stream()
+ self.playback_manager.pause()
frame_duration_ms = 1000.0 / self.project_fps
new_time = self.timeline_widget.playhead_pos_ms + (direction * frame_duration_ms)
- self.seek_preview(int(max(0, min(new_time, self.timeline.get_total_duration()))))
+ final_time = int(max(0, min(new_time, self.timeline.get_total_duration())))
+ self.playback_manager.seek_to_frame(final_time)
def snap_playhead(self, direction):
if not self.timeline.clips:
@@ -2423,50 +2402,20 @@ def snap_playhead(self, direction):
if direction == 1: # Forward
next_points = [p for p in sorted_points if p > current_time_ms + TOLERANCE_MS]
if next_points:
- self.seek_preview(next_points[0])
+ self.playback_manager.seek_to_frame(next_points[0])
elif direction == -1: # Backward
prev_points = [p for p in sorted_points if p < current_time_ms - TOLERANCE_MS]
if prev_points:
- self.seek_preview(prev_points[-1])
+ self.playback_manager.seek_to_frame(prev_points[-1])
elif current_time_ms > 0:
- self.seek_preview(0)
-
- def advance_playback_frame(self):
- frame_duration_ms = 1000.0 / self.project_fps
- new_time_ms = self.timeline_widget.playhead_pos_ms + frame_duration_ms
- if new_time_ms > self.timeline.get_total_duration(): self.stop_playback(); return
-
- self.timeline_widget.playhead_pos_ms = round(new_time_ms)
- self.timeline_widget.update()
-
- clip_at_new_time = self._get_topmost_video_clip_at(new_time_ms)
-
- if not clip_at_new_time:
- self._stop_playback_stream()
- self.current_preview_pixmap = None
- self._update_preview_display()
- return
+ self.playback_manager.seek_to_frame(0)
- if clip_at_new_time.media_type == 'image':
- if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id:
- self._stop_playback_stream()
- self.playback_clip = clip_at_new_time
- self.current_preview_pixmap = self.get_frame_at_time(new_time_ms)
- self._update_preview_display()
- return
+ def _on_volume_changed(self, value):
+ self.playback_manager.set_volume(value / 100.0)
- if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id:
- self._start_playback_stream_at(new_time_ms)
-
- if self.playback_process:
- frame_size = self.project_width * self.project_height * 3
- frame_bytes = self.playback_process.stdout.read(frame_size)
- if len(frame_bytes) == frame_size:
- image = QImage(frame_bytes, self.project_width, self.project_height, QImage.Format.Format_RGB888)
- self.current_preview_pixmap = QPixmap.fromImage(image)
- self._update_preview_display()
- else:
- self._stop_playback_stream()
+ def _on_mute_toggled(self, checked):
+ self.playback_manager.set_muted(checked)
+ self.mute_button.setText("Unmute" if checked else "Mute")
def _load_settings(self):
self.settings_file_was_loaded = False
@@ -2518,9 +2467,10 @@ def open_settings_dialog(self):
self.settings.update(dialog.get_settings()); self._save_settings(); self.status_label.setText("Settings updated.")
def new_project(self):
+ self.playback_manager.stop()
self.timeline.clips.clear(); self.timeline.num_video_tracks = 1; self.timeline.num_audio_tracks = 1
self.media_pool.clear(); self.media_properties.clear(); self.project_media_widget.clear_list()
- self.current_project_path = None; self.stop_playback()
+ self.current_project_path = None
self.last_export_path = None
self.project_fps = 25.0
self.project_width = 1280
@@ -2532,6 +2482,7 @@ def new_project(self):
self.undo_stack.history_changed.connect(self.update_undo_redo_actions)
self.update_undo_redo_actions()
self.status_label.setText("New project created. Add media to begin.")
+ self.playback_manager.seek_to_frame(0)
def save_project_as(self):
path, _ = QFileDialog.getSaveFileName(self, "Save Project", "", "JSON Project Files (*.json)")
@@ -2592,7 +2543,8 @@ def _load_project_from_path(self, path):
self.current_project_path = path
self.prune_empty_tracks()
- self.timeline_widget.update(); self.stop_playback()
+ self.timeline_widget.update()
+ self.playback_manager.seek_to_frame(0)
self.status_label.setText(f"Project '{os.path.basename(path)}' loaded.")
self._add_to_recent_files(path)
except Exception as e: self.status_label.setText(f"Error opening project: {e}")
@@ -3045,103 +2997,29 @@ def export_video(self):
self.status_label.setText("Export canceled.")
return
- settings = dialog.get_export_settings()
- output_path = settings["output_path"]
+ export_settings = dialog.get_export_settings()
+ output_path = export_settings["output_path"]
if not output_path:
self.status_label.setText("Export failed: No output path specified.")
return
self.last_export_path = output_path
- total_dur_ms = self.timeline.get_total_duration()
- total_dur_sec = total_dur_ms / 1000.0
- w, h, fr_str = self.project_width, self.project_height, str(self.project_fps)
- sample_rate, channel_layout = '44100', 'stereo'
-
- video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur_sec}', f='lavfi')
-
- all_video_clips = sorted(
- [c for c in self.timeline.clips if c.track_type == 'video'],
- key=lambda c: c.track_index
- )
- input_nodes = {}
- for clip in all_video_clips:
- if clip.source_path not in input_nodes:
- if clip.media_type == 'image':
- input_nodes[clip.source_path] = ffmpeg.input(clip.source_path, loop=1, framerate=self.project_fps)
- else:
- input_nodes[clip.source_path] = ffmpeg.input(clip.source_path)
-
- clip_source_node = input_nodes[clip.source_path]
- clip_duration_sec = clip.duration_ms / 1000.0
- clip_start_sec = clip.clip_start_ms / 1000.0
- timeline_start_sec = clip.timeline_start_ms / 1000.0
- timeline_end_sec = clip.timeline_end_ms / 1000.0
-
- if clip.media_type == 'image':
- segment_stream = (clip_source_node.video.trim(duration=clip_duration_sec).setpts('PTS-STARTPTS'))
- else:
- segment_stream = (clip_source_node.video.trim(start=clip_start_sec, duration=clip_duration_sec).setpts('PTS-STARTPTS'))
-
- processed_segment = (segment_stream.filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black'))
- video_stream = ffmpeg.overlay(video_stream, processed_segment, enable=f'between(t,{timeline_start_sec},{timeline_end_sec})')
-
- final_video = video_stream.filter('format', pix_fmts='yuv420p').filter('fps', fps=self.project_fps)
+ project_settings = {
+ 'width': self.project_width,
+ 'height': self.project_height,
+ 'fps': self.project_fps
+ }
- track_audio_streams = []
- for i in range(self.timeline.num_audio_tracks):
- track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i + 1], key=lambda c: c.timeline_start_ms)
- if not track_clips: continue
- track_segments = []
- last_end_ms = track_clips[0].timeline_start_ms
- for clip in track_clips:
- if clip.source_path not in input_nodes:
- input_nodes[clip.source_path] = ffmpeg.input(clip.source_path)
- gap_ms = clip.timeline_start_ms - last_end_ms
- if gap_ms > 10:
- track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap_ms/1000.0}', f='lavfi'))
-
- clip_start_sec = clip.clip_start_ms / 1000.0
- clip_duration_sec = clip.duration_ms / 1000.0
- a_seg = input_nodes[clip.source_path].audio.filter('atrim', start=clip_start_sec, duration=clip_duration_sec).filter('asetpts', 'PTS-STARTPTS')
- track_segments.append(a_seg)
- last_end_ms = clip.timeline_end_ms
- track_audio_streams.append(ffmpeg.concat(*track_segments, v=0, a=1).filter('adelay', f'{int(track_clips[0].timeline_start_ms)}ms', all=True))
-
- if track_audio_streams:
- final_audio = ffmpeg.filter(track_audio_streams, 'amix', inputs=len(track_audio_streams), duration='longest')
- else:
- final_audio = ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={total_dur_sec}', f='lavfi')
+ self.progress_bar.setVisible(True)
+ self.progress_bar.setValue(0)
+ self.status_label.setText("Exporting...")
- output_args = {'vcodec': settings['vcodec'], 'acodec': settings['acodec'], 'pix_fmt': 'yuv420p'}
- if settings['v_bitrate']: output_args['b:v'] = settings['v_bitrate']
- if settings['a_bitrate']: output_args['b:a'] = settings['a_bitrate']
-
- try:
- ffmpeg_cmd = ffmpeg.output(final_video, final_audio, output_path, **output_args).overwrite_output().compile()
- self.progress_bar.setVisible(True); self.progress_bar.setValue(0); self.status_label.setText("Exporting...")
- self.export_thread = QThread()
- self.export_worker = ExportWorker(ffmpeg_cmd, total_dur_ms)
- self.export_worker.moveToThread(self.export_thread)
- self.export_thread.started.connect(self.export_worker.run_export)
- self.export_worker.finished.connect(self.on_export_finished)
- self.export_worker.progress.connect(self.progress_bar.setValue)
- self.export_worker.finished.connect(self.export_thread.quit)
- self.export_worker.finished.connect(self.export_worker.deleteLater)
- self.export_thread.finished.connect(self.export_thread.deleteLater)
- self.export_thread.finished.connect(self.on_thread_finished_cleanup)
- self.export_thread.start()
- except ffmpeg.Error as e:
- self.status_label.setText(f"FFmpeg error: {e.stderr}")
- print(e.stderr)
+ self.encoder.start_export(self.timeline, project_settings, export_settings)
- def on_export_finished(self, message):
+ def on_export_finished(self, success, message):
self.status_label.setText(message)
self.progress_bar.setVisible(False)
-
- def on_thread_finished_cleanup(self):
- self.export_thread = None
- self.export_worker = None
def add_dock_widget(self, plugin_instance, widget, title, area=Qt.DockWidgetArea.RightDockWidgetArea, show_on_creation=True):
widget_key = f"plugin_{plugin_instance.name}_{title}".replace(' ', '_').lower()
@@ -3205,14 +3083,12 @@ def closeEvent(self, event):
return
self.is_shutting_down = True
+ self.playback_manager.stop()
self._save_settings()
- self._stop_playback_stream()
- if self.export_thread and self.export_thread.isRunning():
- self.export_thread.quit()
- self.export_thread.wait()
event.accept()
if __name__ == '__main__':
+ PluginManager.run_preloads()
app = QApplication(sys.argv)
project_to_load_on_startup = None
if len(sys.argv) > 1:
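The patch above replaces the eager `wgp` import and up-front widget construction with a placeholder dock that only builds the real plugin UI the first time it becomes visible. A minimal, self-contained sketch of that lazy-dock pattern, assuming PyQt6; the `LazyDock` class and the stand-in "heavy" import are illustrative names, not part of the codebase:

```python
# Minimal sketch of the lazy-loading dock pattern from the patch above,
# assuming PyQt6. The heavy import and widget construction are deferred
# until the dock is first shown, via the visibilityChanged signal.
import sys

from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import QApplication, QDockWidget, QLabel, QMainWindow


class LazyDock(QDockWidget):
    def __init__(self, title, parent=None):
        super().__init__(title, parent)
        self._loaded = False
        placeholder = QLabel("Loading AI Generator...")
        placeholder.setAlignment(Qt.AlignmentFlag.AlignCenter)
        self.setWidget(placeholder)
        self.visibilityChanged.connect(self._on_visibility_changed)

    def _on_visibility_changed(self, visible):
        if visible and not self._loaded:
            self._load_heavy_content()

    def _load_heavy_content(self):
        try:
            import json as heavy_backend  # stand-in for a slow import such as wgp
            content = QLabel(f"Backend ready: {heavy_backend.__name__}")
        except Exception as exc:
            # Surface the failure in the dock instead of crashing the editor.
            content = QLabel(f"Failed to load backend:\n{exc}")
        content.setAlignment(Qt.AlignmentFlag.AlignCenter)
        content.setWordWrap(True)
        self.setWidget(content)  # swap the placeholder for the real content
        self._loaded = True


if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = QMainWindow()
    window.addDockWidget(Qt.DockWidgetArea.RightDockWidgetArea, LazyDock("AI Generator"))
    window.show()
    sys.exit(app.exec())
```

The `_ensure_ui_loaded()` guard in the plugin serves the same purpose for code paths (the timeline context-menu actions) that can fire before the dock has ever been shown.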
From 47a64b5bd11f73ca3fdb51eb8d7256ffbb929e25 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Tue, 14 Oct 2025 19:18:06 +1100
Subject: [PATCH 148/155] remove failed attempt to make onnxruntime import
plugin-based
---
plugins.py | 24 ------------------------
plugins/wan2gp/plugin.json | 5 -----
videoeditor.py | 1 -
3 files changed, 30 deletions(-)
delete mode 100644 plugins/wan2gp/plugin.json
diff --git a/plugins.py b/plugins.py
index a4e0699dd..cde026750 100644
--- a/plugins.py
+++ b/plugins.py
@@ -5,7 +5,6 @@
import shutil
import git
import json
-import importlib
from PyQt6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QListWidget, QListWidgetItem,
QPushButton, QLabel, QLineEdit, QMessageBox, QProgressBar,
QDialogButtonBox, QWidget, QCheckBox)
@@ -34,29 +33,6 @@ def __init__(self, main_app):
if not os.path.exists(self.plugins_dir):
os.makedirs(self.plugins_dir)
- def run_preloads(plugins_dir="plugins"):
- print("Running plugin pre-load scan...")
- if not os.path.isdir(plugins_dir):
- return
-
- for plugin_name in os.listdir(plugins_dir):
- plugin_path = os.path.join(plugins_dir, plugin_name)
- manifest_path = os.path.join(plugin_path, 'plugin.json')
- if os.path.isfile(manifest_path):
- try:
- with open(manifest_path, 'r') as f:
- manifest = json.load(f)
-
- preload_modules = manifest.get("preload_modules", [])
- if preload_modules:
- for module_name in preload_modules:
- try:
- importlib.import_module(module_name)
- except ImportError as e:
- print(f" - WARNING: Could not pre-load '{module_name}'. Error: {e}")
- except Exception as e:
- print(f" - WARNING: Could not read or parse manifest for plugin '{plugin_name}': {e}")
-
def discover_and_load_plugins(self):
for plugin_name in os.listdir(self.plugins_dir):
plugin_path = os.path.join(self.plugins_dir, plugin_name)
diff --git a/plugins/wan2gp/plugin.json b/plugins/wan2gp/plugin.json
deleted file mode 100644
index 47e592c50..000000000
--- a/plugins/wan2gp/plugin.json
+++ /dev/null
@@ -1,5 +0,0 @@
-{
- "preload_modules": [
- "onnxruntime"
- ]
-}
\ No newline at end of file
diff --git a/videoeditor.py b/videoeditor.py
index 21df9e63d..101b2059a 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -3088,7 +3088,6 @@ def closeEvent(self, event):
event.accept()
if __name__ == '__main__':
- PluginManager.run_preloads()
app = QApplication(sys.argv)
project_to_load_on_startup = None
if len(sys.argv) > 1:
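For reference, the mechanism this patch deletes scanned each `plugins/<name>/plugin.json` for a `preload_modules` list and imported those modules before the Qt application started. The commit message calls it a failed attempt without recording why, so the sketch below merely restates the removed behaviour as a standalone helper for documentation; it is not a suggestion to reinstate it:

```python
# Sketch of the manifest-driven preload removed by this patch: walk the
# plugins directory, read each plugin.json, and eagerly import anything
# listed under "preload_modules". Shown only to document the removed idea.
import importlib
import json
import os


def run_preloads(plugins_dir="plugins"):
    if not os.path.isdir(plugins_dir):
        return
    for plugin_name in os.listdir(plugins_dir):
        manifest_path = os.path.join(plugins_dir, plugin_name, "plugin.json")
        if not os.path.isfile(manifest_path):
            continue
        try:
            with open(manifest_path, "r", encoding="utf-8") as f:
                manifest = json.load(f)
        except (OSError, json.JSONDecodeError) as exc:
            print(f"WARNING: could not parse manifest for '{plugin_name}': {exc}")
            continue
        for module_name in manifest.get("preload_modules", []):
            try:
                importlib.import_module(module_name)
            except ImportError as exc:
                print(f"WARNING: could not pre-load '{module_name}': {exc}")


if __name__ == "__main__":
    run_preloads()
```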
From f151887f7f7ebe46fe4b1ad1e0a98bb4525a37e5 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Thu, 16 Oct 2025 07:19:30 +1100
Subject: [PATCH 149/155] Update README.md
---
README.md | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index 81c93be78..ca7c0161e 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# Inline AI Video Editor
+# Inline AI Video Editor
A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. It provides a multi-track timeline, a media preview window, and basic clip manipulation capabilities, all wrapped in a dockable user interface. The editor is designed to be extensible through a plugin system.
@@ -58,9 +58,11 @@ python videoeditor.py
```
## Screenshots
-
-
-
+
+
+
+
+
From 7029b660f743765effd8ced523219f80f636fd3e Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Thu, 16 Oct 2025 07:25:05 +1100
Subject: [PATCH 150/155] Update README.md
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index ca7c0161e..a37096feb 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# Inline AI Video Editor
+# Inline AI Video Editor
A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. It provides a multi-track timeline, a media preview window, and basic clip manipulation capabilities, all wrapped in a dockable user interface. The editor is designed to be extensible through a plugin system.
From ad19ab47f5ed3ac15f076435c6592119a13e01e8 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Fri, 17 Oct 2025 12:58:57 +1100
Subject: [PATCH 151/155] use icons for multimedia controls
---
icons/pause.svg | 8 +++++
icons/play.svg | 7 +++++
icons/previous_frame.svg | 7 +++++
icons/snap_to_start.svg | 8 +++++
icons/stop.svg | 5 +++
videoeditor.py | 66 ++++++++++++++++++++++++++++++++--------
6 files changed, 89 insertions(+), 12 deletions(-)
create mode 100644 icons/pause.svg
create mode 100644 icons/play.svg
create mode 100644 icons/previous_frame.svg
create mode 100644 icons/snap_to_start.svg
create mode 100644 icons/stop.svg
diff --git a/icons/pause.svg b/icons/pause.svg
new file mode 100644
index 000000000..df6b56823
--- /dev/null
+++ b/icons/pause.svg
@@ -0,0 +1,8 @@
+
\ No newline at end of file
diff --git a/icons/play.svg b/icons/play.svg
new file mode 100644
index 000000000..bb3a02a8f
--- /dev/null
+++ b/icons/play.svg
@@ -0,0 +1,7 @@
+
\ No newline at end of file
diff --git a/icons/previous_frame.svg b/icons/previous_frame.svg
new file mode 100644
index 000000000..63bae024e
--- /dev/null
+++ b/icons/previous_frame.svg
@@ -0,0 +1,7 @@
+
\ No newline at end of file
diff --git a/icons/snap_to_start.svg b/icons/snap_to_start.svg
new file mode 100644
index 000000000..3624c85cc
--- /dev/null
+++ b/icons/snap_to_start.svg
@@ -0,0 +1,8 @@
+
\ No newline at end of file
diff --git a/icons/stop.svg b/icons/stop.svg
new file mode 100644
index 000000000..7f51b9433
--- /dev/null
+++ b/icons/stop.svg
@@ -0,0 +1,5 @@
+
\ No newline at end of file
diff --git a/videoeditor.py b/videoeditor.py
index 101b2059a..4473d9cc0 100644
--- a/videoeditor.py
+++ b/videoeditor.py
@@ -15,7 +15,7 @@
QListWidget, QListWidgetItem, QMessageBox, QComboBox,
QFormLayout, QGroupBox, QLineEdit, QSlider)
from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction,
- QPixmap, QImage, QDrag, QCursor, QKeyEvent)
+ QPixmap, QImage, QDrag, QCursor, QKeyEvent, QIcon, QTransform)
from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread,
pyqtSignal, QTimer, QByteArray, QMimeData)
@@ -1886,9 +1886,8 @@ def _setup_ui(self):
self.splitter = QSplitter(Qt.Orientation.Vertical)
- # --- PREVIEW WIDGET SETUP (CHANGED) ---
self.preview_scroll_area = QScrollArea()
- self.preview_scroll_area.setWidgetResizable(False) # Important for 1:1 scaling
+ self.preview_scroll_area.setWidgetResizable(False)
self.preview_scroll_area.setStyleSheet("background-color: black; border: 0px;")
self.preview_widget = QLabel()
@@ -1914,12 +1913,55 @@ def _setup_ui(self):
controls_widget = QWidget()
controls_layout = QHBoxLayout(controls_widget)
controls_layout.setContentsMargins(0, 5, 0, 5)
- self.play_pause_button = QPushButton("Play")
- self.stop_button = QPushButton("Stop")
- self.frame_back_button = QPushButton("<")
- self.frame_forward_button = QPushButton(">")
- self.snap_back_button = QPushButton("|<")
- self.snap_forward_button = QPushButton(">|")
+
+ icon_dir = "icons"
+ icon_size = QSize(32, 32)
+ button_size = QSize(40, 40)
+
+ self.play_icon = QIcon(os.path.join(icon_dir, "play.svg"))
+ self.pause_icon = QIcon(os.path.join(icon_dir, "pause.svg"))
+ stop_icon = QIcon(os.path.join(icon_dir, "stop.svg"))
+ back_pixmap = QPixmap(os.path.join(icon_dir, "previous_frame.svg"))
+ snap_back_pixmap = QPixmap(os.path.join(icon_dir, "snap_to_start.svg"))
+
+ transform = QTransform().rotate(180)
+
+ frame_forward_icon = QIcon(back_pixmap.transformed(transform))
+ snap_forward_icon = QIcon(snap_back_pixmap.transformed(transform))
+ frame_back_icon = QIcon(back_pixmap)
+ snap_back_icon = QIcon(snap_back_pixmap)
+
+ self.play_pause_button = QPushButton()
+ self.play_pause_button.setIcon(self.play_icon)
+ self.play_pause_button.setToolTip("Play/Pause")
+
+ self.stop_button = QPushButton()
+ self.stop_button.setIcon(stop_icon)
+ self.stop_button.setToolTip("Stop")
+
+ self.frame_back_button = QPushButton()
+ self.frame_back_button.setIcon(frame_back_icon)
+ self.frame_back_button.setToolTip("Previous Frame (Left Arrow)")
+
+ self.frame_forward_button = QPushButton()
+ self.frame_forward_button.setIcon(frame_forward_icon)
+ self.frame_forward_button.setToolTip("Next Frame (Right Arrow)")
+
+ self.snap_back_button = QPushButton()
+ self.snap_back_button.setIcon(snap_back_icon)
+ self.snap_back_button.setToolTip("Snap to Previous Clip Edge")
+
+ self.snap_forward_button = QPushButton()
+ self.snap_forward_button.setIcon(snap_forward_icon)
+ self.snap_forward_button.setToolTip("Snap to Next Clip Edge")
+
+ button_list = [self.snap_back_button, self.frame_back_button, self.play_pause_button,
+ self.stop_button, self.frame_forward_button, self.snap_forward_button]
+ for btn in button_list:
+ btn.setIconSize(icon_size)
+ btn.setFixedSize(button_size)
+ btn.setStyleSheet("QPushButton { border: none; background-color: transparent; }")
+
controls_layout.addStretch()
controls_layout.addWidget(self.snap_back_button)
controls_layout.addWidget(self.frame_back_button)
@@ -2364,13 +2406,13 @@ def toggle_playback(self):
self.playback_manager.play(current_pos)
def _on_playback_started(self):
- self.play_pause_button.setText("Pause")
+ self.play_pause_button.setIcon(self.pause_icon)
def _on_playback_paused(self):
- self.play_pause_button.setText("Play")
+ self.play_pause_button.setIcon(self.play_icon)
def _on_playback_stopped(self):
- self.play_pause_button.setText("Play")
+ self.play_pause_button.setIcon(self.play_icon)
def stop_playback(self):
self.playback_manager.stop()
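The forward transport icons above are not separate SVG files; they are produced at runtime by rotating the backward pixmaps 180°. That only looks right when the glyph is symmetric about its horizontal axis, since the rotation also flips it vertically; a horizontal mirror is the more general derivation. A minimal sketch, assuming PyQt6 and the icon paths added by this patch:

```python
# Minimal sketch, assuming PyQt6: derive a forward icon from the backward SVG
# either by rotation (as in the patch) or by a horizontal mirror.
import sys

from PyQt6.QtGui import QGuiApplication, QIcon, QPixmap, QTransform

app = QGuiApplication(sys.argv)  # QPixmap requires a (Gui)Application instance

back_pixmap = QPixmap("icons/previous_frame.svg")

# 180-degree rotation, as used above: also flips the glyph vertically, which
# is invisible only if the artwork is symmetric about its horizontal axis.
forward_rotated = QIcon(back_pixmap.transformed(QTransform().rotate(180)))

# Horizontal mirror: preserves the vertical orientation of the glyph.
forward_mirrored = QIcon(back_pixmap.transformed(QTransform().scale(-1, 1)))
```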
From 08a7e00ce3a2336205ec03f835e25f594f09564d Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Fri, 17 Oct 2025 13:00:12 +1100
Subject: [PATCH 152/155] Update README.md
---
README.md | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/README.md b/README.md
index a37096feb..a97c6b402 100644
--- a/README.md
+++ b/README.md
@@ -58,11 +58,9 @@ python videoeditor.py
```
## Screenshots
-
-
-
-
-
+
+
+
From b5c7a740b71c344383e1003b0446150aa86671ed Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Fri, 17 Oct 2025 20:01:23 +1100
Subject: [PATCH 153/155] fix configuration
---
plugins/wan2gp/main.py | 120 +++++++++++++++++++++++++++++++----------
1 file changed, 92 insertions(+), 28 deletions(-)
diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py
index 2a38230dd..95d0beeb3 100644
--- a/plugins/wan2gp/main.py
+++ b/plugins/wan2gp/main.py
@@ -827,6 +827,10 @@ def _create_general_config_tab(self):
self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1), ("Dimensions are Output Width/Height (Cropped)", 2)], 0)
self._create_config_combo(form, "Attention Type:", "attention_mode", [("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"), ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto")
self._create_config_combo(form, "Metadata Handling:", "metadata_type", [("Embed in file (Exif/Comment)", "metadata"), ("Export separate JSON", "json"), ("None", "none")], "metadata")
+ checkbox = QCheckBox()
+ checkbox.setChecked(wgp.server_config.get("embed_source_images", False))
+ self.widgets['config_embed_source_images'] = checkbox
+ form.addRow("Embed Source Images in MP4:", checkbox)
self._create_config_checklist(form, "RAM Loading Policy:", "preload_model_policy", [("Preload on App Launch", "P"), ("Preload on Model Switch", "S"), ("Unload when Queue is Done", "U")], [])
self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10), ("Keep last 20", 20), ("Keep last 30", 30)], 5)
self._create_config_combo(form, "Display RAM/VRAM Stats:", "display_stats", [("Disabled", 0), ("Enabled", 1)], 0)
@@ -839,6 +843,7 @@ def _create_general_config_tab(self):
self.widgets['config_checkpoints_paths'] = checkpoints_textbox
form.addRow("Checkpoints Paths:", checkpoints_textbox)
self._create_config_combo(form, "UI Theme (requires restart):", "UI_theme", [("Blue Sky", "default"), ("Classic Gradio", "gradio")], "default")
+ self._create_config_combo(form, "Queue Color Scheme:", "queue_color_scheme", [("Pastel (Unique color per item)", "pastel"), ("Alternating Grey Shades", "alternating_grey")], "pastel")
return tab
def _create_performance_config_tab(self):
@@ -870,8 +875,9 @@ def _create_outputs_config_tab(self):
tab, form = self._create_scrollable_form_tab()
self._create_config_combo(form, "Video Codec:", "video_output_codec", [("x265 Balanced", 'libx265_28'), ("x264 Balanced", 'libx264_8'), ("x265 High Quality", 'libx265_8'), ("x264 High Quality", 'libx264_10'), ("x264 Lossless", 'libx264_lossless')], 'libx264_8')
self._create_config_combo(form, "Image Codec:", "image_output_codec", [("JPEG Q85", 'jpeg_85'), ("WEBP Q85", 'webp_85'), ("JPEG Q95", 'jpeg_95'), ("WEBP Q95", 'webp_95'), ("WEBP Lossless", 'webp_lossless'), ("PNG Lossless", 'png')], 'jpeg_95')
- self._create_config_textbox(form, "Video Output Folder:", "save_path", "outputs")
- self._create_config_textbox(form, "Image Output Folder:", "image_save_path", "outputs")
+ self._create_config_combo(form, "Audio Codec:", "audio_output_codec", [("AAC 128 kbit", 'aac_128')], 'aac_128')
+ self._create_config_textbox(form, "Video Output Folder:", "save_path", wgp.server_config.get("save_path", "outputs"))
+ self._create_config_textbox(form, "Image Output Folder:", "image_save_path", wgp.server_config.get("image_save_path", "outputs"))
return tab
def _create_notifications_config_tab(self):
@@ -1419,7 +1425,8 @@ def collect_inputs(self):
full_inputs = wgp.get_current_model_settings(self.state).copy()
full_inputs['lset_name'] = ""
full_inputs['image_mode'] = 0
- expected_keys = { "audio_guide": None, "audio_guide2": None, "image_guide": None, "image_mask": None, "speakers_locations": "", "frames_positions": "", "keep_frames_video_guide": "", "keep_frames_video_source": "", "video_guide_outpainting": "", "switch_threshold2": 0, "model_switch_phase": 1, "batch_size": 1, "control_net_weight_alt": 1.0, "image_refs_relative_size": 50, }
+ full_inputs['mode'] = ""
+ expected_keys = { "audio_guide": None, "audio_guide2": None, "image_guide": None, "image_mask": None, "speakers_locations": "", "frames_positions": "", "keep_frames_video_guide": "", "keep_frames_video_source": "", "video_guide_outpainting": "", "switch_threshold2": 0, "model_switch_phase": 1, "batch_size": 1, "control_net_weight_alt": 1.0, "image_refs_relative_size": 50, "embedded_guidance_scale": None, "model_mode": None, "control_net_weight": 1.0, "control_net_weight2": 1.0, "mask_expand": 0, "remove_background_images_ref": 0, "prompt_enhancer": ""}
for key, default_value in expected_keys.items():
if key not in full_inputs: full_inputs[key] = default_value
full_inputs['prompt'] = self.widgets['prompt'].toPlainText()
@@ -1633,33 +1640,90 @@ def _on_release_ram(self):
QMessageBox.information(self, "RAM Released", "Models stored in RAM have been released.")
def _on_apply_config_changes(self):
- changes = {}
- list_widget = self.widgets['config_transformer_types']
- changes['transformer_types_choices'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked]
- list_widget = self.widgets['config_preload_model_policy']
- changes['preload_model_policy_choice'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked]
- changes['model_hierarchy_type_choice'] = self.widgets['config_model_hierarchy_type'].currentData()
- changes['checkpoints_paths'] = self.widgets['config_checkpoints_paths'].toPlainText()
- for key in ["fit_canvas", "attention_mode", "metadata_type", "clear_file_list", "display_stats", "max_frames_multiplier", "UI_theme"]:
- changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData()
- for key in ["transformer_quantization", "transformer_dtype_policy", "mixed_precision", "text_encoder_quantization", "vae_precision", "compile", "depth_anything_v2_variant", "vae_config", "boost", "profile"]:
- changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData()
- changes['preload_in_VRAM_choice'] = self.widgets['config_preload_in_VRAM'].value()
- for key in ["enhancer_enabled", "enhancer_mode", "mmaudio_enabled"]:
- changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData()
- for key in ["video_output_codec", "image_output_codec", "save_path", "image_save_path"]:
- widget = self.widgets[f'config_{key}']
- changes[f'{key}_choice'] = widget.currentData() if isinstance(widget, QComboBox) else widget.text()
- changes['notification_sound_enabled_choice'] = self.widgets['config_notification_sound_enabled'].currentData()
- changes['notification_sound_volume_choice'] = self.widgets['config_notification_sound_volume'].value()
- changes['last_resolution_choice'] = self.widgets['resolution'].currentData()
+ if wgp.args.lock_config:
+ self.config_status_label.setText("Configuration is locked by command-line arguments.")
+ return
+ if self.thread and self.thread.isRunning():
+ self.config_status_label.setText("Cannot change config while a generation is in progress.")
+ return
+
try:
- msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = wgp.apply_changes(self.state, **changes)
- self.config_status_label.setText("Changes applied successfully. Some settings may require a restart.")
+ ui_settings = {}
+ list_widget = self.widgets['config_transformer_types']
+ ui_settings['transformer_types'] = [list_widget.item(i).data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked]
+ list_widget = self.widgets['config_preload_model_policy']
+ ui_settings['preload_model_policy'] = [list_widget.item(i).data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked]
+
+ ui_settings['model_hierarchy_type'] = self.widgets['config_model_hierarchy_type'].currentData()
+ ui_settings['fit_canvas'] = self.widgets['config_fit_canvas'].currentData()
+ ui_settings['attention_mode'] = self.widgets['config_attention_mode'].currentData()
+ ui_settings['metadata_type'] = self.widgets['config_metadata_type'].currentData()
+ ui_settings['clear_file_list'] = self.widgets['config_clear_file_list'].currentData()
+ ui_settings['display_stats'] = self.widgets['config_display_stats'].currentData()
+ ui_settings['max_frames_multiplier'] = self.widgets['config_max_frames_multiplier'].currentData()
+ ui_settings['checkpoints_paths'] = [p.strip() for p in self.widgets['config_checkpoints_paths'].toPlainText().replace("\r", "").split("\n") if p.strip()]
+ ui_settings['UI_theme'] = self.widgets['config_UI_theme'].currentData()
+ ui_settings['queue_color_scheme'] = self.widgets['config_queue_color_scheme'].currentData()
+
+ ui_settings['transformer_quantization'] = self.widgets['config_transformer_quantization'].currentData()
+ ui_settings['transformer_dtype_policy'] = self.widgets['config_transformer_dtype_policy'].currentData()
+ ui_settings['mixed_precision'] = self.widgets['config_mixed_precision'].currentData()
+ ui_settings['text_encoder_quantization'] = self.widgets['config_text_encoder_quantization'].currentData()
+ ui_settings['vae_precision'] = self.widgets['config_vae_precision'].currentData()
+ ui_settings['compile'] = self.widgets['config_compile'].currentData()
+ ui_settings['depth_anything_v2_variant'] = self.widgets['config_depth_anything_v2_variant'].currentData()
+ ui_settings['vae_config'] = self.widgets['config_vae_config'].currentData()
+ ui_settings['boost'] = self.widgets['config_boost'].currentData()
+ ui_settings['profile'] = self.widgets['config_profile'].currentData()
+ ui_settings['preload_in_VRAM'] = self.widgets['config_preload_in_VRAM'].value()
+
+ ui_settings['enhancer_enabled'] = self.widgets['config_enhancer_enabled'].currentData()
+ ui_settings['enhancer_mode'] = self.widgets['config_enhancer_mode'].currentData()
+ ui_settings['mmaudio_enabled'] = self.widgets['config_mmaudio_enabled'].currentData()
+
+ ui_settings['video_output_codec'] = self.widgets['config_video_output_codec'].currentData()
+ ui_settings['image_output_codec'] = self.widgets['config_image_output_codec'].currentData()
+ ui_settings['audio_output_codec'] = self.widgets['config_audio_output_codec'].currentData()
+ ui_settings['embed_source_images'] = self.widgets['config_embed_source_images'].isChecked()
+ ui_settings['save_path'] = self.widgets['config_save_path'].text()
+ ui_settings['image_save_path'] = self.widgets['config_image_save_path'].text()
+
+ ui_settings['notification_sound_enabled'] = self.widgets['config_notification_sound_enabled'].currentData()
+ ui_settings['notification_sound_volume'] = self.widgets['config_notification_sound_volume'].value()
+
+ ui_settings['last_model_type'] = self.state["model_type"]
+ ui_settings['last_model_per_family'] = self.state["last_model_per_family"]
+ ui_settings['last_model_per_type'] = self.state["last_model_per_type"]
+ ui_settings['last_advanced_choice'] = self.state["advanced"]
+ ui_settings['last_resolution_choice'] = self.widgets['resolution'].currentData()
+ ui_settings['last_resolution_per_group'] = self.state["last_resolution_per_group"]
+
+ wgp.server_config.update(ui_settings)
+
+ wgp.fl.set_checkpoints_paths(ui_settings['checkpoints_paths'])
+ wgp.three_levels_hierarchy = ui_settings["model_hierarchy_type"] == 1
+ wgp.attention_mode = ui_settings["attention_mode"]
+ wgp.default_profile = ui_settings["profile"]
+ wgp.compile = ui_settings["compile"]
+ wgp.text_encoder_quantization = ui_settings["text_encoder_quantization"]
+ wgp.vae_config = ui_settings["vae_config"]
+ wgp.boost = ui_settings["boost"]
+ wgp.save_path = ui_settings["save_path"]
+ wgp.image_save_path = ui_settings["image_save_path"]
+ wgp.preload_model_policy = ui_settings["preload_model_policy"]
+ wgp.transformer_quantization = ui_settings["transformer_quantization"]
+ wgp.transformer_dtype_policy = ui_settings["transformer_dtype_policy"]
+ wgp.transformer_types = ui_settings["transformer_types"]
+ wgp.reload_needed = True
+
+ with open(wgp.server_config_filename, "w", encoding="utf-8") as writer:
+ json.dump(wgp.server_config, writer, indent=4)
+
+ self.config_status_label.setText("Settings saved successfully. Restart may be required for some changes.")
self.header_info.setText(wgp.generate_header(self.state['model_type'], wgp.compile, wgp.attention_mode))
- if family_mock.choices is not None or choice_mock.choices is not None:
- self.update_model_dropdowns(wgp.transformer_type)
- self.refresh_ui_from_model_change(wgp.transformer_type)
+ self.update_model_dropdowns(wgp.transformer_type)
+ self.refresh_ui_from_model_change(wgp.transformer_type)
+
except Exception as e:
self.config_status_label.setText(f"Error applying changes: {e}")
import traceback; traceback.print_exc()
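
The rewritten `_on_apply_config_changes` above boils down to one pattern: read the widget values into a dict, merge that dict into the persistent server config, mirror a few values onto module-level attributes, then write the config back out as JSON. Here is a minimal sketch of just the merge-and-persist part; the `server_config` dict contents and `wgp_config.json` filename are placeholders, not the plugin's real values (the plugin uses `wgp.server_config` and `wgp.server_config_filename`).

```python
import json

# Placeholder in-memory config; the plugin's real one lives in wgp.server_config.
server_config = {"attention_mode": "auto", "save_path": "outputs"}
CONFIG_PATH = "wgp_config.json"  # placeholder for wgp.server_config_filename

def apply_ui_settings(ui_settings: dict) -> None:
    # Merge the values read from the widgets into the in-memory config first,
    # so code that reads the config right afterwards sees the new values.
    server_config.update(ui_settings)
    # Then persist the whole config so the settings survive a restart.
    with open(CONFIG_PATH, "w", encoding="utf-8") as writer:
        json.dump(server_config, writer, indent=4)

apply_ui_settings({"attention_mode": "sdpa", "queue_color_scheme": "pastel"})
```

Writing the merged dict (rather than only the changed keys) is what lets the early-return guards above bail out cleanly: if the config is locked or a generation is running, nothing is merged and nothing is written.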
From 73eac71c0e3963b4151dba1018ac2778bbcc6e43 Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Fri, 17 Oct 2025 23:28:50 +1100
Subject: [PATCH 154/155] fix queue index problems in plugin
---
plugins/wan2gp/main.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py
index 95d0beeb3..0fe492479 100644
--- a/plugins/wan2gp/main.py
+++ b/plugins/wan2gp/main.py
@@ -179,7 +179,6 @@ def dropEvent(self, event: QDropEvent):
source_row = self.currentRow()
target_item = self.itemAt(event.position().toPoint())
dest_row = target_item.row() if target_item else self.rowCount()
- if source_row < dest_row: dest_row -=1
if source_row != dest_row: self.rowsMoved.emit(source_row, dest_row)
event.acceptProposedAction()
else:
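
For context on the one-line fix above: dropping the `dest_row -= 1` adjustment means `rowsMoved` now carries the raw source and destination rows, leaving any index correction to the queue-reorder handler on the other side of the signal. A minimal sketch of the resulting `dropEvent`, with assumed class and signal wiring rather than the plugin's exact code:

```python
from PyQt6.QtCore import pyqtSignal
from PyQt6.QtGui import QDropEvent
from PyQt6.QtWidgets import QAbstractItemView, QTableWidget

class QueueTable(QTableWidget):
    # Custom signal: (source_row, dest_row), both raw indices.
    rowsMoved = pyqtSignal(int, int)

    def __init__(self, parent=None):
        super().__init__(parent)
        # Internal moves are what make Qt deliver dropEvent for row reordering.
        self.setDragDropMode(QAbstractItemView.DragDropMode.InternalMove)

    def dropEvent(self, event: QDropEvent) -> None:
        source_row = self.currentRow()
        target_item = self.itemAt(event.position().toPoint())
        # Dropping below the last row means "move to the end".
        dest_row = target_item.row() if target_item else self.rowCount()
        # No off-by-one correction here: the raw indices are emitted and the
        # rowsMoved consumer decides how to reorder the backing queue.
        if source_row != dest_row:
            self.rowsMoved.emit(source_row, dest_row)
        event.acceptProposedAction()
```

The receiver of `rowsMoved` presumably reorders the underlying queue and rebuilds the table, which is why correcting the index in both places produced the off-by-one this patch removes.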
From f9f4bd1342fe9bd3a7a902a0b0dac1d63ee2492e Mon Sep 17 00:00:00 2001
From: Chris Malone
Date: Fri, 17 Oct 2025 23:50:04 +1100
Subject: [PATCH 155/155] fix queue width
---
plugins/wan2gp/main.py | 15 ++++++++++-----
1 file changed, 10 insertions(+), 5 deletions(-)
diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py
index 0fe492479..f91236d0c 100644
--- a/plugins/wan2gp/main.py
+++ b/plugins/wan2gp/main.py
@@ -563,9 +563,18 @@ def setup_generator_tab(self):
results_layout.addWidget(self.results_list)
right_layout.addWidget(results_group)
- right_layout.addWidget(QLabel("Queue:"))
+ right_layout.addWidget(QLabel("Queue"))
self.queue_table = self.create_widget(QueueTableWidget, 'queue_table')
+ self.queue_table.verticalHeader().setVisible(False)
+ self.queue_table.setColumnCount(4)
+ self.queue_table.setHorizontalHeaderLabels(["Qty", "Prompt", "Length", "Steps"])
+ header = self.queue_table.horizontalHeader()
+ header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
+ header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
+ header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)
+ header.setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents)
right_layout.addWidget(self.queue_table)
+
queue_btn_layout = QHBoxLayout()
self.remove_queue_btn = self.create_widget(QPushButton, 'remove_queue_btn', "Remove Selected")
self.clear_queue_btn = self.create_widget(QPushButton, 'clear_queue_btn', "Clear Queue")
@@ -1594,14 +1603,10 @@ def update_queue_table(self):
table_data = wgp.get_queue_table(queue_to_display)
self.queue_table.setRowCount(0)
self.queue_table.setRowCount(len(table_data))
- self.queue_table.setColumnCount(4)
- self.queue_table.setHorizontalHeaderLabels(["Qty", "Prompt", "Length", "Steps"])
for row_idx, row_data in enumerate(table_data):
prompt_text = str(row_data[1]).split('>')[1].split('<')[0] if '>' in str(row_data[1]) else str(row_data[1])
for col_idx, cell_data in enumerate([row_data[0], prompt_text, row_data[2], row_data[3]]):
self.queue_table.setItem(row_idx, col_idx, QTableWidgetItem(str(cell_data)))
- self.queue_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)
- self.queue_table.resizeColumnsToContents()
def _on_remove_selected_from_queue(self):
selected_row = self.queue_table.currentRow()