From 0ae1ae54a8facd289d6e8b7f842b02f1b7f12f1a Mon Sep 17 00:00:00 2001
From: Christopher Anderson
Date: Fri, 20 Jun 2025 21:27:01 +1000
Subject: [PATCH 001/155] AMD installation instructions for RDNA3.x+Windows

---
 docs/AMD-INSTALLATION.md | 146 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 146 insertions(+)
 create mode 100644 docs/AMD-INSTALLATION.md

diff --git a/docs/AMD-INSTALLATION.md b/docs/AMD-INSTALLATION.md
new file mode 100644
index 000000000..4f05589eb
--- /dev/null
+++ b/docs/AMD-INSTALLATION.md
@@ -0,0 +1,146 @@
+# Installation Guide
+
+This guide covers installation for specific RDNA3 and RDNA3.5 AMD CPUs (APUs) and GPUs
+running under Windows.
+
+tl;dr: Radeon RX 7900 GOOD, RX 9700 BAD, RX 6800 BAD. (I know, life isn't fair).
+
+Currently supported (but not necessarily tested):
+
+**gfx110x**:
+
+* Radeon RX 7600
+* Radeon RX 7700 XT
+* Radeon RX 7800 XT
+* Radeon RX 7900 GRE
+* Radeon RX 7900 XT
+* Radeon RX 7900 XTX
+
+**gfx1151**:
+
+* Ryzen 7000 series APUs (Phoenix)
+* Ryzen Z1 (e.g., handheld devices like the ROG Ally)
+
+**gfx1201**:
+
+* Ryzen 8000 series APUs (Strix Point)
+* A [frame.work](https://frame.work/au/en/desktop) desktop/laptop
+
+
+## Requirements
+
+- Python 3.11 (3.12 might work, 3.10 definitely will not!)
+
+## Installation Environment
+
+This installation uses PyTorch 2.7.0 because that's what's currently available in
+terms of pre-compiled wheels.
+
+### Installing Python
+
+Download Python 3.11 from [python.org/downloads/windows](https://www.python.org/downloads/windows/). Hit Ctrl+F and search for "3.11". Don't use this direct link: [https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe](https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe) -- that was an IQ test.
+
+After installing, make sure `python --version` works in your terminal and returns 3.11.x.
+
+If not, you probably need to fix your PATH. Go to:
+
+* Windows + Pause/Break
+* Advanced System Settings
+* Environment Variables
+* Edit your `Path` under User Variables
+
+Example correct entries:
+
+```cmd
+C:\Users\YOURNAME\AppData\Local\Programs\Python\Launcher\
+C:\Users\YOURNAME\AppData\Local\Programs\Python\Python311\Scripts\
+C:\Users\YOURNAME\AppData\Local\Programs\Python\Python311\
+```
+
+If that doesn't work, scream into a bucket.
+
+### Installing Git
+
+Get Git from [git-scm.com/downloads/win](https://git-scm.com/downloads/win). The default install is fine.
+
+
+## Install (Windows, using `venv`)
+
+### Step 1: Download and Set Up Environment
+
+```cmd
+:: Navigate to your desired install directory
+cd \your-path-to-wan2gp
+
+:: Clone the repository
+git clone https://github.com/deepbeepmeep/Wan2GP.git
+cd Wan2GP
+
+:: Create virtual environment using Python 3.11
+python -m venv wan2gp-env
+
+:: Activate the virtual environment
+wan2gp-env\Scripts\activate
+```
+
+### Step 2: Install PyTorch
+
+The pre-compiled wheels you need are hosted at [scottt's rocm-TheRock releases](https://github.com/scottt/rocm-TheRock/releases). Find the heading that says:
+
+**Pytorch wheels for gfx110x, gfx1151, and gfx1201**
+
+Don't click this link: [https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x](https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x). It's just here to check if you're skimming.
+
+Copy the links of the binaries closest to the ones in the example below (adjust them if you're not running Python 3.11) and run the resulting `pip install` command.
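+
+The wheel filenames tell you what each build targets: `cp311` means CPython 3.11, `win_amd64` means 64-bit Windows, and the `gfx110x` in the release tag should match your GPU family from the list at the top of this guide. The exact build hashes change between releases, so treat the URLs below as an illustration of the command's shape rather than something to paste verbatim: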
+ +```cmd +pip install ^ + https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torch-2.7.0a0+rocm_git3f903c3-cp311-cp311-win_amd64.whl ^ + https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchaudio-2.7.0a0+52638ef-cp311-cp311-win_amd64.whl ^ + https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchvision-0.22.0+9eb57cd-cp311-cp311-win_amd64.whl +``` + +### Step 3: Install Dependencies + +```cmd +:: Install core dependencies +pip install -r requirements.txt +``` + +## Attention Modes + +WanGP supports several attention implementations, only one of which will work for you: + +- **SDPA** (default): Available by default with PyTorch. This uses the built-in aotriton accel library, so is actually pretty fast. + +## Performance Profiles + +Choose a profile based on your hardware: + +- **Profile 3 (LowRAM_HighVRAM)**: Loads entire model in VRAM, requires 24GB VRAM for 8-bit quantized 14B model +- **Profile 4 (LowRAM_LowVRAM)**: Default, loads model parts as needed, slower but lower VRAM requirement + +## Running Wan2GP + +In future, you will have to do this: + +```cmd +cd \path-to\wan2gp +wan2gp\Scripts\activate.bat +python wgp.py +``` + +For now, you should just be able to type `python wgp.py` (because you're already in the virtual environment) + +## Troubleshooting + +- If you use a HIGH VRAM mode, don't be a fool. Make sure you use VAE Tiled Decoding. + +### Memory Issues + +- Use lower resolution or shorter videos +- Enable quantization (default) +- Use Profile 4 for lower VRAM usage +- Consider using 1.3B models instead of 14B models + +For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md) From ee170f17659d2750103773c91a766595d6393cf5 Mon Sep 17 00:00:00 2001 From: Ciprian Mandache Date: Wed, 9 Jul 2025 22:48:30 +0300 Subject: [PATCH 002/155] dockerize and add deb-based sys runner + builder --- Dockerfile | 92 +++++++++++++++++++++ README.md | 113 +++++++++++++++++++------ entrypoint.sh | 18 ++++ run-docker-cuda-deb.sh | 183 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 379 insertions(+), 27 deletions(-) create mode 100644 Dockerfile create mode 100755 entrypoint.sh create mode 100755 run-docker-cuda-deb.sh diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..927c579fd --- /dev/null +++ b/Dockerfile @@ -0,0 +1,92 @@ +FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 + +# Build arg for GPU architectures - specify which CUDA compute capabilities to compile for +# Common values: +# 7.0 - Tesla V100 +# 7.5 - RTX 2060, 2070, 2080, Titan RTX +# 8.0 - A100, A800 (Ampere data center) +# 8.6 - RTX 3060, 3070, 3080, 3090 (Ampere consumer) +# 8.9 - RTX 4070, 4080, 4090 (Ada Lovelace) +# 9.0 - H100, H800 (Hopper data center) +# 12.0 - RTX 5070, 5080, 5090 (Blackwell) - Note: sm_120 architecture +# +# Examples: +# RTX 3060: --build-arg CUDA_ARCHITECTURES="8.6" +# RTX 4090: --build-arg CUDA_ARCHITECTURES="8.9" +# Multiple: --build-arg CUDA_ARCHITECTURES="8.0;8.6;8.9" +# +# Note: Including 8.9 or 9.0 may cause compilation issues on some setups +# Default includes 8.0 and 8.6 for broad Ampere compatibility +ARG CUDA_ARCHITECTURES="8.0;8.6" + +ENV DEBIAN_FRONTEND=noninteractive + +# Install system dependencies +RUN apt update && \ + apt install -y \ + python3 python3-pip git wget curl cmake ninja-build \ + libgl1 libglib2.0-0 ffmpeg && \ + apt clean + +WORKDIR /workspace + +COPY requirements.txt . 
+ +# Upgrade pip first +RUN pip install --upgrade pip setuptools wheel + +# Install requirements if exists +RUN pip install -r requirements.txt + +# Install PyTorch with CUDA support +RUN pip install --extra-index-url https://download.pytorch.org/whl/cu124 \ + torch==2.6.0+cu124 torchvision==0.21.0+cu124 + +# Install SageAttention from git (patch GPU detection) +ENV TORCH_CUDA_ARCH_LIST="${CUDA_ARCHITECTURES}" +ENV FORCE_CUDA="1" +ENV MAX_JOBS="1" + +COPY < WanGP by DeepBeepMeep : The best Open Source Video Generative Models Accessible to the GPU Poor

WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models with: + - Low VRAM requirements (as low as 6 GB of VRAM is sufficient for certain models) - Support for old GPUs (RTX 10XX, 20xx, ...) - Very Fast on the latest GPUs @@ -13,40 +15,46 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models - Auto download of the required model adapted to your specific architecture - Tools integrated to facilitate Video Generation : Mask Editor, Prompt Enhancer, Temporal and Spatial Generation, MMAudio, Video Browser, Pose / Depth / Flow extractor - Loras Support to customize each model -- Queuing system : make your shopping list of videos to generate and come back later +- Queuing system : make your shopping list of videos to generate and come back later **Discord Server to get Help from Other Users and show your Best Videos:** https://discord.gg/g7efUW9jGV **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates + ### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** : -**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation). -Of course you will get as well *Multitalk* vanilla and also *Multitalk 720p* as a bonus. +**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to _Sliding Windows_ support and _Adaptive Projected Guidance_ (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation). + +Of course you will get as well _Multitalk_ vanilla and also _Multitalk 720p_ as a bonus. -And since I am mister nice guy I have enclosed as an exclusivity an *Audio Separator* that will save you time to isolate each voice when using Multitalk with two people. +And since I am mister nice guy I have enclosed as an exclusivity an _Audio Separator_ that will save you time to isolate each voice when using Multitalk with two people. -As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the *Share Your Best Video* channel your *Master Pieces*. The best ones will be added to the *Announcements Channel* and will bring eternal fame to its authors. +As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the _Share Your Best Video_ channel your _Master Pieces_. The best ones will be added to the _Announcements Channel_ and will bring eternal fame to its authors. But wait, there is more: + - Sliding Windows support has been added anywhere with Wan models, so imagine with text2video recently upgraded in 6.5 into a video2video, you can now upsample very long videos regardless of your VRAM. 
The good old image2video model can now reuse the last image to produce new videos (as requested by many of you) - I have added also the capability to transfer the audio of the original control video (Misc. advanced tab) and an option to preserve the fps into the generated video, so from now on you will be to upsample / restore your old families video and keep the audio at their original pace. Be aware that the duration will be limited to 1000 frames as I still need to add streaming support for unlimited video sizes. Also, of interest too: + - Extract video info from Videos that have not been generated by WanGP, even better you can also apply post processing (Upsampling / MMAudio) on non WanGP videos - Force the generated video fps to your liking, works wery well with Vace when using a Control Video - Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time) ### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features: + - View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations -- In one click use the newly generated video as a Control Video or Source Video to be continued -- Manage multiple settings for the same model and switch between them using a dropdown box +- In one click use the newly generated video as a Control Video or Source Video to be continued +- Manage multiple settings for the same model and switch between them using a dropdown box - WanGP will keep the last generated videos in the Gallery and will remember the last model you used if you restart the app but kept the Web page open - Custom resolutions : add a file in the WanGP folder with the list of resolutions you want to see in WanGP (look at the instruction readme in this folder) Taking care of your life is not enough, you want new stuff to play with ? -- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the *Extensions* tab of the WanGP *Configuration* to enable MMAudio + +- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the _Extensions_ tab of the WanGP _Configuration_ to enable MMAudio - Forgot to upsample your video during the generation ? want to try another MMAudio variation ? Fear not you can also apply upsampling or add an MMAudio track once the video generation is done. Even better you can ask WangGP for multiple variations of MMAudio to pick the one you like best - MagCache support: a new step skipping approach, supposed to be better than TeaCache. Makes a difference if you usually generate with a high number of steps - SageAttention2++ support : not just the compatibility but also a slightly reduced VRAM usage @@ -58,24 +66,29 @@ Taking care of your life is not enough, you want new stuff to play with ? **If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one** ### June 23 2025: WanGP v6.3, Vace Unleashed. Thought we couldnt squeeze Vace even more ? 
+ - Multithreaded preprocessing when possible for faster generations - Multithreaded frames Lanczos Upsampling as a bonus -- A new Vace preprocessor : *Flow* to extract fluid motion -- Multi Vace Controlnets: you can now transfer several properties at the same time. This opens new possibilities to explore, for instance if you transfer *Human Movement* and *Shapes* at the same time for some reasons the lighting of your character will take into account much more the environment of your character. +- A new Vace preprocessor : _Flow_ to extract fluid motion +- Multi Vace Controlnets: you can now transfer several properties at the same time. This opens new possibilities to explore, for instance if you transfer _Human Movement_ and _Shapes_ at the same time for some reasons the lighting of your character will take into account much more the environment of your character. - Injected Frames Outpainting, in case you missed it in WanGP 6.21 Don't know how to use all of the Vace features ? Check the Vace Guide embedded in WanGP as it has also been updated. - ### June 19 2025: WanGP v6.2, Vace even more Powercharged -👋 Have I told you that I am a big fan of Vace ? Here are more goodies to unleash its power: -- If you ever wanted to watch Star Wars in 4:3, just use the new *Outpainting* feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is *Outpainting* can be combined with all the other Vace modifications, for instance you can change the main character of your favorite movie at the same time -- More processing can combined at the same time (for instance the depth process can be applied outside the mask) + +👋 Have I told you that I am a big fan of Vace ? Here are more goodies to unleash its power: + +- If you ever wanted to watch Star Wars in 4:3, just use the new _Outpainting_ feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is _Outpainting_ can be combined with all the other Vace modifications, for instance you can change the main character of your favorite movie at the same time +- More processing can combined at the same time (for instance the depth process can be applied outside the mask) - Upgraded the depth extractor to Depth Anything 2 which is much more detailed -As a bonus, I have added two finetunes based on the Safe-Forcing technology (which requires only 4 steps to generate a video): Wan 2.1 text2video Self-Forcing and Vace Self-Forcing. I know there is Lora around but the quality of the Lora is worse (at least with Vace) compared to the full model. Don't hesitate to share your opinion about this on the discord server. +As a bonus, I have added two finetunes based on the Safe-Forcing technology (which requires only 4 steps to generate a video): Wan 2.1 text2video Self-Forcing and Vace Self-Forcing. I know there is Lora around but the quality of the Lora is worse (at least with Vace) compared to the full model. Don't hesitate to share your opinion about this on the discord server. 
+ ### June 17 2025: WanGP v6.1, Vace Powercharged + 👋 Lots of improvements for Vace the Mother of all Models: + - masks can now be combined with on the fly processing of a control video, for instance you can extract the motion of a specific person defined by a mask - on the fly modification of masks : reversed masks (with the same mask you can modify the background instead of the people covered by the masks), enlarged masks (you can cover more area if for instance the person you are trying to inject is larger than the one in the mask), ... - view these modified masks directly inside WanGP during the video generation to check they are really as expected @@ -87,52 +100,62 @@ Of course all these new stuff work on all Vace finetunes (including Vace Fusioni Thanks also to Reevoy24 for adding a Notfication sound at the end of a generation and for fixing the background color of the current generation summary. ### June 12 2025: WanGP v6.0 -👋 *Finetune models*: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add by yourself the support for this model in WanGP by just creating a finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them. + +👋 _Finetune models_: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add by yourself the support for this model in WanGP by just creating a finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them. To celebrate the new finetunes support, here are a few finetune gifts (directly accessible from the model selection menu): -- *Fast Hunyuan Video* : generate model t2v in only 6 steps -- *Hunyuan Vido AccVideo* : generate model t2v in only 5 steps -- *Wan FusioniX*: it is a combo of AccVideo / CausVid ans other models and can generate high quality Wan videos in only 8 steps + +- _Fast Hunyuan Video_ : generate model t2v in only 6 steps +- _Hunyuan Vido AccVideo_ : generate model t2v in only 5 steps +- _Wan FusioniX_: it is a combo of AccVideo / CausVid ans other models and can generate high quality Wan videos in only 8 steps One more thing... -The new finetune system can be used to combine complementaty models : what happens when you combine Fusionix Text2Video and Vace Control Net ? +The new finetune system can be used to combine complementaty models : what happens when you combine Fusionix Text2Video and Vace Control Net ? You get **Vace FusioniX**: the Ultimate Vace Model, Fast (10 steps, no need for guidance) and with a much better quality Video than the original slower model (despite being the best Control Net out there). Here goes one more finetune... -Check the *Finetune Guide* to create finetune models definitions and share them on the WanGP discord server. +Check the _Finetune Guide_ to create finetune models definitions and share them on the WanGP discord server. 
### June 11 2025: WanGP v5.5 -👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\ -*Hunyuan Video Custom Edit*: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping his poses. Similar to Vace but less restricted than the Wan models in terms of content... +👋 _Hunyuan Video Custom Audio_: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\ +_Hunyuan Video Custom Edit_: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping his poses. Similar to Vace but less restricted than the Wan models in terms of content... ### June 6 2025: WanGP v5.41 + 👋 Bonus release: Support for **AccVideo** Lora to speed up x2 Video generations in Wan models. Check the Loras documentation to get the usage instructions of AccVideo.\ -You will need to do a *pip install -r requirements.txt* +You will need to do a _pip install -r requirements.txt_ ### June 6 2025: WanGP v5.4 + 👋 World Exclusive : **Hunyuan Video Avatar** Support ! You won't need 80 GB of VRAM nor 32 GB oF VRAM, just 10 GB of VRAM will be sufficient to generate up to 15s of high quality speech / song driven Video at a high speed with no quality degradation. Support for TeaCache included.\ Here is a link to the original repo where you will find some very interesting documentation and examples. https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar. Kudos to the Hunyuan Video Avatar team for the best model of its kind.\ Also many thanks to Reevoy24 for his repackaging / completing the documentation ### May 28 2025: WanGP v5.31 + 👋 Added **Phantom 14B**, a model that you can use to transfer objects / people in the video. My preference goes to Vace that remains the king of controlnets. VACE improvements: Better sliding window transitions, image mask support in Matanyone, new Extend Video feature, and enhanced background removal options. ### May 26, 2025: WanGP v5.3 + 👋 Settings management revolution! Now you can: -- Select any generated video and click *Use Selected Video Settings* to instantly reuse its configuration + +- Select any generated video and click _Use Selected Video Settings_ to instantly reuse its configuration - Drag & drop videos to automatically extract their settings metadata - Export/import settings as JSON files for easy sharing and backup ### May 20, 2025: WanGP v5.2 + 👋 **CausVid support** - Generate videos in just 4-12 steps with the new distilled Wan model! Also added experimental MoviiGen for 1080p generation (20GB+ VRAM required). Check the Loras documentation to get the usage instructions of CausVid. ### May 18, 2025: WanGP v5.1 + 👋 **LTX Video 13B Distilled** - Generate high-quality videos in less than one minute! ### May 17, 2025: WanGP v5.0 + 👋 **One App to Rule Them All!** Added Hunyuan Video and LTX Video support, plus Vace 14B and integrated prompt enhancer. 
See full changelog: **[Changelog](docs/CHANGELOG.md)** @@ -150,6 +173,7 @@ See full changelog: **[Changelog](docs/CHANGELOG.md)** **One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/) **Manual installation:** + ```bash git clone https://github.com/deepbeepmeep/Wan2GP.git cd Wan2GP @@ -160,6 +184,7 @@ pip install -r requirements.txt ``` **Run the application:** + ```bash python wgp.py # Text-to-video (default) python wgp.py --i2v # Image-to-video @@ -168,6 +193,7 @@ python wgp.py --i2v # Image-to-video **Update the application:** If using Pinokio use Pinokio to update otherwise: Get in the directory where WanGP is installed and: + ```bash git pull pip install -r requirements.txt @@ -175,16 +201,48 @@ pip install -r requirements.txt ## 📦 Installation +### 🐳 Docker Installation + +**For Debian-based systems (Ubuntu, Debian, etc.):** + +```bash +./run-docker-cuda-deb.sh +``` + +This automated script will: + +- Detect your GPU model and VRAM automatically +- Select optimal CUDA architecture for your GPU +- Install NVIDIA Docker runtime if needed +- Build a Docker image with all dependencies +- Run WanGP with optimal settings for your hardware + +**Docker environment includes:** + +- NVIDIA CUDA 12.4.1 with cuDNN support +- PyTorch 2.6.0 with CUDA 12.4 support +- SageAttention compiled for your specific GPU architecture +- Optimized environment variables for performance (TF32, threading, etc.) +- Automatic cache directory mounting for faster subsequent runs +- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally + +**Supported GPUs:** RTX 50XX, RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more. + +### Manual Installation + For detailed installation instructions for different GPU generations: + - **[Installation Guide](docs/INSTALLATION.md)** - Complete setup instructions for RTX 10XX to RTX 50XX ## 🎯 Usage ### Basic Usage + - **[Getting Started Guide](docs/GETTING_STARTED.md)** - First steps and basic usage - **[Models Overview](docs/MODELS.md)** - Available models and their capabilities ### Advanced Features + - **[Loras Guide](docs/LORAS.md)** - Using and managing Loras for customization - **[Finetunes](docs/FINETUNES.md)** - Add manually new models to WanGP - **[VACE ControlNet](docs/VACE.md)** - Advanced video control and manipulation @@ -198,6 +256,7 @@ For detailed installation instructions for different GPU generations: ## 🔗 Related Projects ### Other Models for the GPU Poor + - **[HuanyuanVideoGP](https://github.com/deepbeepmeep/HunyuanVideoGP)** - One of the best open source Text to Video generators - **[Hunyuan3D-2GP](https://github.com/deepbeepmeep/Hunyuan3D-2GP)** - Image to 3D and text to 3D tool - **[FluxFillGP](https://github.com/deepbeepmeep/FluxFillGP)** - Inpainting/outpainting tools based on Flux @@ -209,4 +268,4 @@ For detailed installation instructions for different GPU generations:

Made with ❤️ by DeepBeepMeep -

+

diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100755 index 000000000..4c363abd1 --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +export HOME=/home/user +export PYTHONUNBUFFERED=1 +export HF_HOME=/home/user/.cache/huggingface + +export OMP_NUM_THREADS=$(nproc) +export MKL_NUM_THREADS=$(nproc) +export OPENBLAS_NUM_THREADS=$(nproc) +export NUMEXPR_NUM_THREADS=$(nproc) + +export TORCH_ALLOW_TF32_CUBLAS=1 +export TORCH_ALLOW_TF32_CUDNN=1 + +# Disable audio warnings in Docker +export SDL_AUDIODRIVER=dummy +export PULSE_RUNTIME_PATH=/tmp/pulse-runtime + +exec su -p user -c "python3 wgp.py --listen $*" diff --git a/run-docker-cuda-deb.sh b/run-docker-cuda-deb.sh new file mode 100755 index 000000000..1a6201aa4 --- /dev/null +++ b/run-docker-cuda-deb.sh @@ -0,0 +1,183 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ───────────────────────── helpers ───────────────────────── + +install_nvidia_smi_if_missing() { + if command -v nvidia-smi &>/dev/null; then + return + fi + + echo "⚠️ nvidia-smi not found. Installing nvidia-utils…" + if [ "$EUID" -ne 0 ]; then + SUDO='sudo' + else + SUDO='' + fi + + $SUDO apt-get update + $SUDO apt-get install -y nvidia-utils-535 || $SUDO apt-get install -y nvidia-utils + + if ! command -v nvidia-smi &>/dev/null; then + echo "❌ Failed to install nvidia-smi. Cannot detect GPU architecture." + exit 1 + fi + echo "✅ nvidia-smi installed successfully." +} + +detect_gpu_name() { + install_nvidia_smi_if_missing + nvidia-smi --query-gpu=name --format=csv,noheader,nounits | head -1 +} + +map_gpu_to_arch() { + local name="$1" + case "$name" in + *"RTX 50"* | *"5090"* | *"5080"* | *"5070"*) echo "12.0" ;; + *"H100"* | *"H800"*) echo "9.0" ;; + *"RTX 40"* | *"4090"* | *"4080"* | *"4070"* | *"4060"*) echo "8.9" ;; + *"RTX 30"* | *"3090"* | *"3080"* | *"3070"* | *"3060"*) echo "8.6" ;; + *"A100"* | *"A800"* | *"A40"*) echo "8.0" ;; + *"Tesla V100"*) echo "7.0" ;; + *"RTX 20"* | *"2080"* | *"2070"* | *"2060"* | *"Titan RTX"*) echo "7.5" ;; + *"GTX 16"* | *"1660"* | *"1650"*) echo "7.5" ;; + *"GTX 10"* | *"1080"* | *"1070"* | *"1060"* | *"Tesla P100"*) echo "6.1" ;; + *"Tesla K80"* | *"Tesla K40"*) echo "3.7" ;; + *) + echo "❌ Unknown GPU model: $name" + echo "Please update the map_gpu_to_arch function for this model." 
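+        # Note: if your GPU is not listed in this table, you can still build the image
+        # yourself with an explicit compute capability (the comments at the top of the
+        # Dockerfile map common GPUs to values), for example for an RTX 3060-class card:
+        #   docker build --build-arg CUDA_ARCHITECTURES="8.6" -t deepbeepmeep/wan2gp .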
+ exit 1 + ;; + esac +} + +get_gpu_vram() { + install_nvidia_smi_if_missing + # Get VRAM in MB, convert to GB + local vram_mb=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | head -1) + echo $((vram_mb / 1024)) +} + +map_gpu_to_profile() { + local name="$1" + local vram_gb="$2" + + # WanGP Profile descriptions from the actual UI: + # Profile 1: HighRAM_HighVRAM - 48GB+ RAM, 24GB+ VRAM (fastest for short videos, RTX 3090/4090) + # Profile 2: HighRAM_LowVRAM - 48GB+ RAM, 12GB+ VRAM (recommended, most versatile) + # Profile 3: LowRAM_HighVRAM - 32GB+ RAM, 24GB+ VRAM (RTX 3090/4090 with limited RAM) + # Profile 4: LowRAM_LowVRAM - 32GB+ RAM, 12GB+ VRAM (default, little VRAM or longer videos) + # Profile 5: VerylowRAM_LowVRAM - 16GB+ RAM, 10GB+ VRAM (fail safe, slow but works) + + case "$name" in + # High-end data center GPUs with 24GB+ VRAM - Profile 1 (HighRAM_HighVRAM) + *"RTX 50"* | *"5090"* | *"A100"* | *"A800"* | *"H100"* | *"H800"*) + if [ "$vram_gb" -ge 24 ]; then + echo "1" # HighRAM_HighVRAM - fastest for short videos + else + echo "2" # HighRAM_LowVRAM - most versatile + fi + ;; + # High-end consumer GPUs (RTX 3090/4090) - Profile 1 or 3 + *"RTX 40"* | *"4090"* | *"RTX 30"* | *"3090"*) + if [ "$vram_gb" -ge 24 ]; then + echo "3" # LowRAM_HighVRAM - good for limited RAM systems + else + echo "2" # HighRAM_LowVRAM - most versatile + fi + ;; + # Mid-range GPUs (RTX 3070/3080/4070/4080) - Profile 2 recommended + *"4080"* | *"4070"* | *"3080"* | *"3070"* | *"RTX 20"* | *"2080"* | *"2070"*) + if [ "$vram_gb" -ge 12 ]; then + echo "2" # HighRAM_LowVRAM - recommended for these GPUs + else + echo "4" # LowRAM_LowVRAM - default for little VRAM + fi + ;; + # Lower-end GPUs with 6-12GB VRAM - Profile 4 or 5 + *"4060"* | *"3060"* | *"2060"* | *"GTX 16"* | *"1660"* | *"1650"*) + if [ "$vram_gb" -ge 10 ]; then + echo "4" # LowRAM_LowVRAM - default + else + echo "5" # VerylowRAM_LowVRAM - fail safe + fi + ;; + # Older/lower VRAM GPUs - Profile 5 (fail safe) + *"GTX 10"* | *"1080"* | *"1070"* | *"1060"* | *"Tesla"*) + echo "5" # VerylowRAM_LowVRAM - fail safe + ;; + *) + echo "4" # LowRAM_LowVRAM - default fallback + ;; + esac +} + +# ───────────────────────── main ──────────────────────────── + +GPU_NAME=$(detect_gpu_name) +echo "🔍 Detected GPU: $GPU_NAME" + +VRAM_GB=$(get_gpu_vram) +echo "🧠 Detected VRAM: ${VRAM_GB}GB" + +CUDA_ARCH=$(map_gpu_to_arch "$GPU_NAME") +echo "🚀 Using CUDA architecture: $CUDA_ARCH" + +PROFILE=$(map_gpu_to_profile "$GPU_NAME" "$VRAM_GB") +echo "⚙️ Selected profile: $PROFILE" + +docker build --build-arg CUDA_ARCHITECTURES="$CUDA_ARCH" -t deepbeepmeep/wan2gp . + +# sudo helper for later commands +if [ "$EUID" -ne 0 ]; then + SUDO='sudo' +else + SUDO='' +fi + +# Ensure NVIDIA runtime is available +if ! docker info 2>/dev/null | grep -q 'Runtimes:.*nvidia'; then + echo "⚠️ NVIDIA Docker runtime not found. Installing nvidia-docker2…" + $SUDO apt-get update + $SUDO apt-get install -y curl ca-certificates gnupg + curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | $SUDO apt-key add - + distribution=$( + . /etc/os-release + echo $ID$VERSION_ID + ) + curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | + $SUDO tee /etc/apt/sources.list.d/nvidia-docker.list + $SUDO apt-get update + $SUDO apt-get install -y nvidia-docker2 + echo "🔄 Restarting Docker service…" + $SUDO systemctl restart docker + echo "✅ NVIDIA Docker runtime installed." +else + echo "✅ NVIDIA Docker runtime found." 
+fi + +# Prepare cache dirs & volume mounts +cache_dirs=(numba matplotlib huggingface torch) +cache_mounts=() +for d in "${cache_dirs[@]}"; do + mkdir -p "$HOME/.cache/$d" + chmod 700 "$HOME/.cache/$d" + cache_mounts+=(-v "$HOME/.cache/$d:/home/user/.cache/$d") +done + +echo "🔧 Optimization settings:" +echo " Profile: $PROFILE" + +# Run the container +docker run --rm -it \ + --name wan2gp \ + --gpus all \ + --runtime=nvidia \ + -p 7860:7860 \ + -v "$(pwd):/workspace" \ + "${cache_mounts[@]}" \ + deepbeepmeep/wan2gp \ + --profile "$PROFILE" \ + --attention sage \ + --compile \ + --perc-reserved-mem-max 1 From 3307defa7c005e8566170d0784254aa9829a1759 Mon Sep 17 00:00:00 2001 From: Ciprian Mandache Date: Sat, 30 Aug 2025 06:51:17 +0300 Subject: [PATCH 003/155] add more checks + logging + upgrade peft --- entrypoint.sh | 100 +++++++++++++++++++++++++++++++++++++++++ requirements.txt | 8 ++-- run-docker-cuda-deb.sh | 27 +++++++++++ 3 files changed, 131 insertions(+), 4 deletions(-) diff --git a/entrypoint.sh b/entrypoint.sh index 4c363abd1..9af052d07 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -15,4 +15,104 @@ export TORCH_ALLOW_TF32_CUDNN=1 export SDL_AUDIODRIVER=dummy export PULSE_RUNTIME_PATH=/tmp/pulse-runtime +# ═══════════════════════════ CUDA DEBUG CHECKS ═══════════════════════════ + +echo "🔍 CUDA Environment Debug Information:" +echo "═══════════════════════════════════════════════════════════════════════" + +# Check CUDA driver on host (if accessible) +if command -v nvidia-smi >/dev/null 2>&1; then + echo "✅ nvidia-smi available" + echo "📊 GPU Information:" + nvidia-smi --query-gpu=name,driver_version,memory.total,memory.free --format=csv,noheader,nounits 2>/dev/null || echo "❌ nvidia-smi failed to query GPU" + echo "🏃 Running Processes:" + nvidia-smi --query-compute-apps=pid,name,used_memory --format=csv,noheader,nounits 2>/dev/null || echo "ℹ️ No running CUDA processes" +else + echo "❌ nvidia-smi not available in container" +fi + +# Check CUDA runtime libraries +echo "" +echo "🔧 CUDA Runtime Check:" +if ls /usr/local/cuda*/lib*/libcudart.so* >/dev/null 2>&1; then + echo "✅ CUDA runtime libraries found:" + ls /usr/local/cuda*/lib*/libcudart.so* 2>/dev/null +else + echo "❌ CUDA runtime libraries not found" +fi + +# Check CUDA devices +echo "" +echo "🖥️ CUDA Device Files:" +if ls /dev/nvidia* >/dev/null 2>&1; then + echo "✅ NVIDIA device files found:" + ls -la /dev/nvidia* 2>/dev/null +else + echo "❌ No NVIDIA device files found - Docker may not have GPU access" +fi + +# Check CUDA environment variables +echo "" +echo "🌍 CUDA Environment Variables:" +echo " CUDA_HOME: ${CUDA_HOME:-not set}" +echo " CUDA_ROOT: ${CUDA_ROOT:-not set}" +echo " CUDA_PATH: ${CUDA_PATH:-not set}" +echo " LD_LIBRARY_PATH: ${LD_LIBRARY_PATH:-not set}" +echo " TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-not set}" +echo " CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-not set}" + +# Check PyTorch CUDA availability +echo "" +echo "🐍 PyTorch CUDA Check:" +python3 -c " +import sys +try: + import torch + print('✅ PyTorch imported successfully') + print(f' Version: {torch.__version__}') + print(f' CUDA available: {torch.cuda.is_available()}') + if torch.cuda.is_available(): + print(f' CUDA version: {torch.version.cuda}') + print(f' cuDNN version: {torch.backends.cudnn.version()}') + print(f' Device count: {torch.cuda.device_count()}') + for i in range(torch.cuda.device_count()): + props = torch.cuda.get_device_properties(i) + print(f' Device {i}: {props.name} (SM {props.major}.{props.minor}, 
{props.total_memory//1024//1024}MB)') + else: + print('❌ CUDA not available to PyTorch') + print(' This could mean:') + print(' - CUDA runtime not properly installed') + print(' - GPU not accessible to container') + print(' - Driver/runtime version mismatch') +except ImportError as e: + print(f'❌ Failed to import PyTorch: {e}') +except Exception as e: + print(f'❌ PyTorch CUDA check failed: {e}') +" 2>&1 + +# Check for common CUDA issues +echo "" +echo "🩺 Common Issue Diagnostics:" + +# Check if running with proper Docker flags +if [ ! -e /dev/nvidia0 ] && [ ! -e /dev/nvidiactl ]; then + echo "❌ No NVIDIA device nodes - container likely missing --gpus all or --runtime=nvidia" +fi + +# Check CUDA library paths +if [ -z "$LD_LIBRARY_PATH" ] || ! echo "$LD_LIBRARY_PATH" | grep -q cuda; then + echo "⚠️ LD_LIBRARY_PATH may not include CUDA libraries" +fi + +# Check permissions on device files +if ls /dev/nvidia* >/dev/null 2>&1; then + if ! ls -la /dev/nvidia* | grep -q "rw-rw-rw-\|rw-r--r--"; then + echo "⚠️ NVIDIA device files may have restrictive permissions" + fi +fi + +echo "═══════════════════════════════════════════════════════════════════════" +echo "🚀 Starting application..." +echo "" + exec su -p user -c "python3 wgp.py --listen $*" diff --git a/requirements.txt b/requirements.txt index ea0c582f6..312dc9400 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,13 +12,13 @@ easydict ftfy dashscope imageio-ffmpeg -# flash_attn -gradio==5.23.0 +# flash_attn +gradio==5.23.0 numpy>=1.23.5,<2 einops moviepy==1.0.3 mmgp==3.5.1 -peft==0.15.0 +peft==0.17.0 mutagen pydantic==2.10.6 decord @@ -46,4 +46,4 @@ soundfile ffmpeg-python pyannote.audio # num2words -# spacy \ No newline at end of file +# spacy diff --git a/run-docker-cuda-deb.sh b/run-docker-cuda-deb.sh index 1a6201aa4..b35e9cc6b 100755 --- a/run-docker-cuda-deb.sh +++ b/run-docker-cuda-deb.sh @@ -114,6 +114,25 @@ map_gpu_to_profile() { # ───────────────────────── main ──────────────────────────── +echo "🔧 NVIDIA CUDA Setup Check:" + +# NVIDIA driver check +if command -v nvidia-smi &>/dev/null; then + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader,nounits | head -1) + echo "✅ NVIDIA Driver: $DRIVER_VERSION" + + # Quick CUDA 12.4 compatibility check + if [[ "$DRIVER_VERSION" =~ ^([0-9]+) ]]; then + MAJOR=${BASH_REMATCH[1]} + if [ "$MAJOR" -lt 520 ]; then + echo "⚠️ Driver $DRIVER_VERSION may not support CUDA 12.4 (need 520+)" + fi + fi +else + echo "❌ nvidia-smi not found - no NVIDIA drivers" + exit 1 +fi + GPU_NAME=$(detect_gpu_name) echo "🔍 Detected GPU: $GPU_NAME" @@ -156,6 +175,14 @@ else echo "✅ NVIDIA Docker runtime found." fi +# Quick NVIDIA runtime test +echo "🧪 Testing NVIDIA runtime..." +if timeout 15s docker run --rm --gpus all --runtime=nvidia nvidia/cuda:12.4-runtime-ubuntu22.04 nvidia-smi >/dev/null 2>&1; then + echo "✅ NVIDIA runtime working" +else + echo "❌ NVIDIA runtime test failed - check driver/runtime compatibility" +fi + # Prepare cache dirs & volume mounts cache_dirs=(numba matplotlib huggingface torch) cache_mounts=() From 808ef51688dfd30288c3a99f369448d0e76b9a1f Mon Sep 17 00:00:00 2001 From: Ciprian Mandache Date: Sat, 30 Aug 2025 07:02:35 +0300 Subject: [PATCH 004/155] fix readme --- README.md | 133 ++++++++++++++++++++---------------------------------- 1 file changed, 49 insertions(+), 84 deletions(-) diff --git a/README.md b/README.md index 50c03ee2e..fb04ce13d 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,11 @@ # WanGP ---- - +-----

WanGP by DeepBeepMeep : The best Open Source Video Generative Models Accessible to the GPU Poor

WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models with: - - Low VRAM requirements (as low as 6 GB of VRAM is sufficient for certain models) - Support for old GPUs (RTX 10XX, 20xx, ...) - Very Fast on the latest GPUs @@ -45,39 +43,31 @@ Also in the news: ### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me Maybe you knew that already but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that it is *CFG* is set to 1). This helps to get much faster generations but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase. -### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me - -Maybe you knew that already but most _Loras accelerators_ we use today (Causvid, FusioniX) don't use _Guidance_ at all (that it is _CFG_ is set to 1). This helps to get much faster generations but the downside is that _Negative Prompts_ are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the _Negative Prompt_ during the _attention_ processing phase. - -So WanGP 6.7 gives you NAG, but not any NAG, a _Low VRAM_ implementation, the default one ends being VRAM greedy. You will find NAG in the _General_ advanced tab for most Wan models. +So WanGP 6.7 gives you NAG, but not any NAG, a *Low VRAM* implementation, the default one ends being VRAM greedy. You will find NAG in the *General* advanced tab for most Wan models. Use NAG especially when Guidance is set to 1. To turn it on set the **NAG scale** to something around 10. There are other NAG parameters **NAG tau** and **NAG alpha** which I recommend to change only if you don't get good results by just playing with the NAG scale. Don't hesitate to share on this discord server the best combinations for these 3 parameters. The authors of NAG claim that NAG can also be used when using a Guidance (CFG > 1) and to improve the prompt adherence. ### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** : +**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation). -**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to _Sliding Windows_ support and _Adaptive Projected Guidance_ (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation). +Of course you will get as well *Multitalk* vanilla and also *Multitalk 720p* as a bonus. -Of course you will get as well _Multitalk_ vanilla and also _Multitalk 720p_ as a bonus. 
+And since I am mister nice guy I have enclosed as an exclusivity an *Audio Separator* that will save you time to isolate each voice when using Multitalk with two people. -And since I am mister nice guy I have enclosed as an exclusivity an _Audio Separator_ that will save you time to isolate each voice when using Multitalk with two people. - -As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the _Share Your Best Video_ channel your _Master Pieces_. The best ones will be added to the _Announcements Channel_ and will bring eternal fame to its authors. +As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the *Share Your Best Video* channel your *Master Pieces*. The best ones will be added to the *Announcements Channel* and will bring eternal fame to its authors. But wait, there is more: - - Sliding Windows support has been added anywhere with Wan models, so imagine with text2video recently upgraded in 6.5 into a video2video, you can now upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you) - I have added also the capability to transfer the audio of the original control video (Misc. advanced tab) and an option to preserve the fps into the generated video, so from now on you will be to upsample / restore your old families video and keep the audio at their original pace. Be aware that the duration will be limited to 1000 frames as I still need to add streaming support for unlimited video sizes. Also, of interest too: - - Extract video info from Videos that have not been generated by WanGP, even better you can also apply post processing (Upsampling / MMAudio) on non WanGP videos - Force the generated video fps to your liking, works wery well with Vace when using a Control Video - Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time) ### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features: - - View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations - In one click use the newly generated video as a Control Video or Source Video to be continued - Manage multiple settings for the same model and switch between them using a dropdown box @@ -85,8 +75,7 @@ Also, of interest too: - Custom resolutions : add a file in the WanGP folder with the list of resolutions you want to see in WanGP (look at the instruction readme in this folder) Taking care of your life is not enough, you want new stuff to play with ? - -- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the _Extensions_ tab of the WanGP _Configuration_ to enable MMAudio +- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the *Extensions* tab of the WanGP *Configuration* to enable MMAudio - Forgot to upsample your video during the generation ? want to try another MMAudio variation ? 
Fear not you can also apply upsampling or add an MMAudio track once the video generation is done. Even better you can ask WangGP for multiple variations of MMAudio to pick the one you like best - MagCache support: a new step skipping approach, supposed to be better than TeaCache. Makes a difference if you usually generate with a high number of steps - SageAttention2++ support : not just the compatibility but also a slightly reduced VRAM usage @@ -98,29 +87,24 @@ Taking care of your life is not enough, you want new stuff to play with ? **If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one** ### June 23 2025: WanGP v6.3, Vace Unleashed. Thought we couldnt squeeze Vace even more ? - - Multithreaded preprocessing when possible for faster generations - Multithreaded frames Lanczos Upsampling as a bonus -- A new Vace preprocessor : _Flow_ to extract fluid motion -- Multi Vace Controlnets: you can now transfer several properties at the same time. This opens new possibilities to explore, for instance if you transfer _Human Movement_ and _Shapes_ at the same time for some reasons the lighting of your character will take into account much more the environment of your character. +- A new Vace preprocessor : *Flow* to extract fluid motion +- Multi Vace Controlnets: you can now transfer several properties at the same time. This opens new possibilities to explore, for instance if you transfer *Human Movement* and *Shapes* at the same time for some reasons the lighting of your character will take into account much more the environment of your character. - Injected Frames Outpainting, in case you missed it in WanGP 6.21 Don't know how to use all of the Vace features ? Check the Vace Guide embedded in WanGP as it has also been updated. -### June 19 2025: WanGP v6.2, Vace even more Powercharged +### June 19 2025: WanGP v6.2, Vace even more Powercharged 👋 Have I told you that I am a big fan of Vace ? Here are more goodies to unleash its power: - -- If you ever wanted to watch Star Wars in 4:3, just use the new _Outpainting_ feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is _Outpainting_ can be combined with all the other Vace modifications, for instance you can change the main character of your favorite movie at the same time -- More processing can combined at the same time (for instance the depth process can be applied outside the mask) +- If you ever wanted to watch Star Wars in 4:3, just use the new *Outpainting* feature and it will add the missing bits of image at the top and the bottom of the screen. The best thing is *Outpainting* can be combined with all the other Vace modifications, for instance you can change the main character of your favorite movie at the same time +- More processing can combined at the same time (for instance the depth process can be applied outside the mask) - Upgraded the depth extractor to Depth Anything 2 which is much more detailed As a bonus, I have added two finetunes based on the Safe-Forcing technology (which requires only 4 steps to generate a video): Wan 2.1 text2video Self-Forcing and Vace Self-Forcing. I know there is Lora around but the quality of the Lora is worse (at least with Vace) compared to the full model. Don't hesitate to share your opinion about this on the discord server. 
- ### June 17 2025: WanGP v6.1, Vace Powercharged - 👋 Lots of improvements for Vace the Mother of all Models: - - masks can now be combined with on the fly processing of a control video, for instance you can extract the motion of a specific person defined by a mask - on the fly modification of masks : reversed masks (with the same mask you can modify the background instead of the people covered by the masks), enlarged masks (you can cover more area if for instance the person you are trying to inject is larger than the one in the mask), ... - view these modified masks directly inside WanGP during the video generation to check they are really as expected @@ -132,62 +116,52 @@ Of course all these new stuff work on all Vace finetunes (including Vace Fusioni Thanks also to Reevoy24 for adding a Notfication sound at the end of a generation and for fixing the background color of the current generation summary. ### June 12 2025: WanGP v6.0 - -👋 _Finetune models_: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add by yourself the support for this model in WanGP by just creating a finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them. +👋 *Finetune models*: You find the 20 models supported by WanGP not sufficient ? Too impatient to wait for the next release to get the support for a newly released model ? Your prayers have been answered: if a new model is compatible with a model architecture supported by WanGP, you can add by yourself the support for this model in WanGP by just creating a finetune model definition. You can then store this model in the cloud (for instance in Huggingface) and the very light finetune definition file can be easily shared with other users. WanGP will download automatically the finetuned model for them. To celebrate the new finetunes support, here are a few finetune gifts (directly accessible from the model selection menu): - -- _Fast Hunyuan Video_ : generate model t2v in only 6 steps -- _Hunyuan Vido AccVideo_ : generate model t2v in only 5 steps -- _Wan FusioniX_: it is a combo of AccVideo / CausVid ans other models and can generate high quality Wan videos in only 8 steps +- *Fast Hunyuan Video* : generate model t2v in only 6 steps +- *Hunyuan Vido AccVideo* : generate model t2v in only 5 steps +- *Wan FusioniX*: it is a combo of AccVideo / CausVid ans other models and can generate high quality Wan videos in only 8 steps One more thing... -The new finetune system can be used to combine complementaty models : what happens when you combine Fusionix Text2Video and Vace Control Net ? +The new finetune system can be used to combine complementaty models : what happens when you combine Fusionix Text2Video and Vace Control Net ? You get **Vace FusioniX**: the Ultimate Vace Model, Fast (10 steps, no need for guidance) and with a much better quality Video than the original slower model (despite being the best Control Net out there). Here goes one more finetune... -Check the _Finetune Guide_ to create finetune models definitions and share them on the WanGP discord server. 
+Check the *Finetune Guide* to create finetune models definitions and share them on the WanGP discord server. ### June 11 2025: WanGP v5.5 +👋 *Hunyuan Video Custom Audio*: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\ +*Hunyuan Video Custom Edit*: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping his poses. Similar to Vace but less restricted than the Wan models in terms of content... -👋 _Hunyuan Video Custom Audio_: it is similar to Hunyuan Video Avatar except there isn't any lower limit on the number of frames and you can use your reference images in a different context than the image itself\ -_Hunyuan Video Custom Edit_: Hunyuan Video Controlnet, use it to do inpainting and replace a person in a video while still keeping his poses. Similar to Vace but less restricted than the Wan models in terms of content... ### June 6 2025: WanGP v5.41 - 👋 Bonus release: Support for **AccVideo** Lora to speed up x2 Video generations in Wan models. Check the Loras documentation to get the usage instructions of AccVideo.\ -You will need to do a _pip install -r requirements.txt_ +You will need to do a *pip install -r requirements.txt* ### June 6 2025: WanGP v5.4 - 👋 World Exclusive : **Hunyuan Video Avatar** Support ! You won't need 80 GB of VRAM nor 32 GB oF VRAM, just 10 GB of VRAM will be sufficient to generate up to 15s of high quality speech / song driven Video at a high speed with no quality degradation. Support for TeaCache included.\ Here is a link to the original repo where you will find some very interesting documentation and examples. https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar. Kudos to the Hunyuan Video Avatar team for the best model of its kind.\ Also many thanks to Reevoy24 for his repackaging / completing the documentation ### May 28 2025: WanGP v5.31 - 👋 Added **Phantom 14B**, a model that you can use to transfer objects / people in the video. My preference goes to Vace that remains the king of controlnets. VACE improvements: Better sliding window transitions, image mask support in Matanyone, new Extend Video feature, and enhanced background removal options. ### May 26, 2025: WanGP v5.3 - 👋 Settings management revolution! Now you can: - -- Select any generated video and click _Use Selected Video Settings_ to instantly reuse its configuration +- Select any generated video and click *Use Selected Video Settings* to instantly reuse its configuration - Drag & drop videos to automatically extract their settings metadata - Export/import settings as JSON files for easy sharing and backup ### May 20, 2025: WanGP v5.2 - 👋 **CausVid support** - Generate videos in just 4-12 steps with the new distilled Wan model! Also added experimental MoviiGen for 1080p generation (20GB+ VRAM required). Check the Loras documentation to get the usage instructions of CausVid. ### May 18, 2025: WanGP v5.1 - 👋 **LTX Video 13B Distilled** - Generate high-quality videos in less than one minute! ### May 17, 2025: WanGP v5.0 - 👋 **One App to Rule Them All!** Added Hunyuan Video and LTX Video support, plus Vace 14B and integrated prompt enhancer. 
See full changelog: **[Changelog](docs/CHANGELOG.md)** @@ -202,10 +176,36 @@ See full changelog: **[Changelog](docs/CHANGELOG.md)** ## 🚀 Quick Start +### 🐳 Docker: + +**For Debian-based systems (Ubuntu, Debian, etc.):** + +```bash +./run-docker-cuda-deb.sh +``` + +This automated script will: + +- Detect your GPU model and VRAM automatically +- Select optimal CUDA architecture for your GPU +- Install NVIDIA Docker runtime if needed +- Build a Docker image with all dependencies +- Run WanGP with optimal settings for your hardware + +**Docker environment includes:** + +- NVIDIA CUDA 12.4.1 with cuDNN support +- PyTorch 2.6.0 with CUDA 12.4 support +- SageAttention compiled for your specific GPU architecture +- Optimized environment variables for performance (TF32, threading, etc.) +- Automatic cache directory mounting for faster subsequent runs +- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally + +**Supported GPUs:** RTX 50XX, RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more. + **One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/) **Manual installation:** - ```bash git clone https://github.com/deepbeepmeep/Wan2GP.git cd Wan2GP @@ -216,7 +216,6 @@ pip install -r requirements.txt ``` **Run the application:** - ```bash python wgp.py # Text-to-video (default) python wgp.py --i2v # Image-to-video @@ -225,7 +224,6 @@ python wgp.py --i2v # Image-to-video **Update the application:** If using Pinokio use Pinokio to update otherwise: Get in the directory where WanGP is installed and: - ```bash git pull pip install -r requirements.txt @@ -233,48 +231,16 @@ pip install -r requirements.txt ## 📦 Installation -### 🐳 Docker Installation - -**For Debian-based systems (Ubuntu, Debian, etc.):** - -```bash -./run-docker-cuda-deb.sh -``` - -This automated script will: - -- Detect your GPU model and VRAM automatically -- Select optimal CUDA architecture for your GPU -- Install NVIDIA Docker runtime if needed -- Build a Docker image with all dependencies -- Run WanGP with optimal settings for your hardware - -**Docker environment includes:** - -- NVIDIA CUDA 12.4.1 with cuDNN support -- PyTorch 2.6.0 with CUDA 12.4 support -- SageAttention compiled for your specific GPU architecture -- Optimized environment variables for performance (TF32, threading, etc.) -- Automatic cache directory mounting for faster subsequent runs -- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally - -**Supported GPUs:** RTX 50XX, RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more. 
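To make the Docker section's "detect your GPU model and VRAM automatically, then select the optimal CUDA architecture" description more concrete, here is a minimal sketch of that kind of detection. The actual helper is the shell script `run-docker-cuda-deb.sh`; this Python version is only an illustration, reusing the same `torch.cuda.get_device_capability` call that WanGP's attention code relies on. How the detected value would be fed into the Docker build is an assumption for illustration.

```python
# Illustrative sketch only; not the logic of run-docker-cuda-deb.sh itself.
import torch

def describe_gpu(device_index: int = 0) -> dict:
    """Report the facts the Docker helper is described as auto-detecting: name, VRAM and compute capability."""
    if not torch.cuda.is_available():
        raise RuntimeError("No CUDA-capable GPU visible to PyTorch")
    props = torch.cuda.get_device_properties(device_index)
    major, minor = torch.cuda.get_device_capability(device_index)
    return {
        "name": props.name,
        "vram_gb": round(props.total_memory / 1024**3, 1),
        "compute_capability": f"{major}.{minor}",  # compute capability string such as "8.6"
    }

if __name__ == "__main__":
    info = describe_gpu()
    print(f"Detected {info['name']} with {info['vram_gb']} GB VRAM")
    print(f"Compute capability {info['compute_capability']} would drive the CUDA architecture chosen for the build")
```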
- -### Manual Installation - For detailed installation instructions for different GPU generations: - - **[Installation Guide](docs/INSTALLATION.md)** - Complete setup instructions for RTX 10XX to RTX 50XX ## 🎯 Usage ### Basic Usage - - **[Getting Started Guide](docs/GETTING_STARTED.md)** - First steps and basic usage - **[Models Overview](docs/MODELS.md)** - Available models and their capabilities ### Advanced Features - - **[Loras Guide](docs/LORAS.md)** - Using and managing Loras for customization - **[Finetunes](docs/FINETUNES.md)** - Add manually new models to WanGP - **[VACE ControlNet](docs/VACE.md)** - Advanced video control and manipulation @@ -288,7 +254,6 @@ For detailed installation instructions for different GPU generations: ## 🔗 Related Projects ### Other Models for the GPU Poor - - **[HuanyuanVideoGP](https://github.com/deepbeepmeep/HunyuanVideoGP)** - One of the best open source Text to Video generators - **[Hunyuan3D-2GP](https://github.com/deepbeepmeep/Hunyuan3D-2GP)** - Image to 3D and text to 3D tool - **[FluxFillGP](https://github.com/deepbeepmeep/FluxFillGP)** - Inpainting/outpainting tools based on Flux From 1ae91d36f1d7135e326b8c1129584482c69df44c Mon Sep 17 00:00:00 2001 From: deepbeepmeep <84379123+deepbeepmeep@users.noreply.github.com> Date: Tue, 2 Sep 2025 02:50:49 +0200 Subject: [PATCH 005/155] fixed InfiniteTalk starting from last frame --- wgp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wgp.py b/wgp.py index a151252ba..0345e642e 100644 --- a/wgp.py +++ b/wgp.py @@ -4739,7 +4739,7 @@ def remove_temp_filenames(temp_filenames_list): raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ] if infinitetalk and video_guide is not None: - src_image = get_video_frame(video_guide, aligned_guide_start_frame-1, return_last_if_missing = True, return_PIL = True) + src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True) new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size) src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) refresh_preview["video_guide"] = src_image @@ -9290,4 +9290,4 @@ def create_ui(): else: url = "http://" + server_name webbrowser.open(url + ":" + str(server_port), new = 0, autoraise = True) - demo.launch(favicon_path="favicon.png", server_name=server_name, server_port=server_port, share=args.share, allowed_paths=list({save_path, image_save_path})) \ No newline at end of file + demo.launch(favicon_path="favicon.png", server_name=server_name, server_port=server_port, share=args.share, allowed_paths=list({save_path, image_save_path})) From 898b542cc6df6246b614a5ce93135482517062ff Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 2 Sep 2025 20:39:31 +0200 Subject: [PATCH 006/155] pain reliever --- README.md | 11 + models/hyvideo/hunyuan_handler.py | 2 +- models/hyvideo/modules/utils.py | 2 +- models/qwen/transformer_qwenimage.py | 3 +- models/wan/any2video.py | 11 +- models/wan/modules/model.py | 27 +- models/wan/multitalk/attention.py | 26 +- models/wan/multitalk/multitalk_utils.py | 111 ++++-- models/wan/wan_handler.py | 22 ++ requirements.txt | 5 +- shared/attention.py | 73 +++- shared/gradio/gallery.py | 496 ++++++++++++++++++++++++ shared/utils/qwen_vl_utils.py | 2 - wgp.py | 190 ++++----- 14 files changed, 799 
insertions(+), 182 deletions(-) create mode 100644 shared/gradio/gallery.py diff --git a/README.md b/README.md index 47736da47..be76cbcc8 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,17 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : +### September 2 2025: WanGP v8.3 - At last the pain stops + +- This single new feature should give you the strength to face all the potential bugs of this new release: +**Images Management (multiple additions or deletions, reordering) for Start Images / End Images / Images References.** + +- Unofficial **Video to Video (Non Sparse this time) for InfinitTalk**. Use the Strength Noise slider to decide how much motion of the original window you want to keep. I have also *greatly reduced the VRAM requirements for Multitalk / Infinitalk* (especially the multispeakers version & when generating at 1080p). + +- **Experimental Sage 3 Attention support**: you will need to deserve this one, first you need a Blackwell GPU (RTX50xx), then you will have to compile Sage 3, install it and cross your fingers that there isn't any crash. + + + ### August 29 2025: WanGP v8.21 - Here Goes Your Weekend - **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is internally only image is used by Sliding Window. However thanks to the new *Smooth Transition* mode, each new clip is connected to the previous and all the camera work is done by InfiniteTalk. If you dont get any transition, increase the number of frames of a Sliding Window (81 frames recommended) diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py index d95bd7e35..da60f72cb 100644 --- a/models/hyvideo/hunyuan_handler.py +++ b/models/hyvideo/hunyuan_handler.py @@ -53,7 +53,7 @@ def query_model_def(base_model_type, model_def): if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True - if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]: + if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]: extra_model_def["one_image_ref_needed"] = True return extra_model_def diff --git a/models/hyvideo/modules/utils.py b/models/hyvideo/modules/utils.py index 02a733e1b..c26399728 100644 --- a/models/hyvideo/modules/utils.py +++ b/models/hyvideo/modules/utils.py @@ -14,7 +14,7 @@ ) -@lru_cache +# @lru_cache def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False): block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile) return block_mask diff --git a/models/qwen/transformer_qwenimage.py b/models/qwen/transformer_qwenimage.py index 6d908064a..8ec34ed01 100644 --- a/models/qwen/transformer_qwenimage.py +++ b/models/qwen/transformer_qwenimage.py @@ -204,7 +204,7 @@ def forward(self, video_fhw, txt_seq_lens, device): frame, height, width = fhw rope_key = f"{idx}_{height}_{width}" - if not torch.compiler.is_compiling(): + if not torch.compiler.is_compiling() and False: if rope_key not in self.rope_cache: self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx) video_freq = self.rope_cache[rope_key] @@ -224,7 +224,6 @@ def forward(self, video_fhw, txt_seq_lens, device): return vid_freqs, txt_freqs - @functools.lru_cache(maxsize=None) def _compute_video_freqs(self, frame, height, 
width, idx=0): seq_lens = frame * height * width freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1) diff --git a/models/wan/any2video.py b/models/wan/any2video.py index ee58c0a09..8eb2ffef2 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -19,7 +19,8 @@ import torchvision.transforms.functional as TF import torch.nn.functional as F from .distributed.fsdp import shard_model -from .modules.model import WanModel, clear_caches +from .modules.model import WanModel +from mmgp.offload import get_cache, clear_caches from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE from .modules.vae2_2 import Wan2_2_VAE @@ -496,6 +497,8 @@ def generate(self, text_len = self.model.text_len context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0) context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0) + if input_video is not None: height, width = input_video.shape[-2:] + # NAG_prompt = "static, low resolution, blurry" # context_NAG = self.text_encoder([NAG_prompt], self.device)[0] # context_NAG = context_NAG.to(self.dtype) @@ -530,9 +533,10 @@ def generate(self, if image_start is None: if infinitetalk: if input_frames is not None: - image_ref = input_frames[:, -1] - if input_video is None: input_video = input_frames[:, -1:] + image_ref = input_frames[:, 0] + if input_video is None: input_video = input_frames[:, 0:1] new_shot = "Q" in video_prompt_type + denoising_strength = 0.5 else: if pre_video_frame is None: new_shot = True @@ -888,6 +892,7 @@ def clear(): latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents else: latent_noise_factor = t / 1000 + latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor if vace: overlap_noise_factor = overlap_noise / 1000 for zz in z: diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index b9c98bf81..d3dc783e3 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -12,6 +12,7 @@ import numpy as np from typing import Union,Optional from mmgp import offload +from mmgp.offload import get_cache, clear_caches from shared.attention import pay_attention from torch.backends.cuda import sdp_kernel from ..multitalk.multitalk_utils import get_attn_map_with_target @@ -19,22 +20,6 @@ __all__ = ['WanModel'] -def get_cache(cache_name): - all_cache = offload.shared_state.get("_cache", None) - if all_cache is None: - all_cache = {} - offload.shared_state["_cache"]= all_cache - cache = offload.shared_state.get(cache_name, None) - if cache is None: - cache = {} - offload.shared_state[cache_name] = cache - return cache - -def clear_caches(): - all_cache = offload.shared_state.get("_cache", None) - if all_cache is not None: - all_cache.clear() - def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 @@ -579,19 +564,23 @@ def forward( y = self.norm_x(x) y = y.to(attention_dtype) if ref_images_count == 0: - x += self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map) + ylist= [y] + del y + x += self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map) else: y_shape = y.shape y = y.reshape(y_shape[0], grid_sizes[0], -1) y = y[:, ref_images_count:] y = y.reshape(y_shape[0], -1, 
y_shape[-1]) grid_sizes_alt = [grid_sizes[0]-ref_images_count, *grid_sizes[1:]] - y = self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map) + ylist= [y] + y = None + y = self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map) y = y.reshape(y_shape[0], grid_sizes[0]-ref_images_count, -1) x = x.reshape(y_shape[0], grid_sizes[0], -1) x[:, ref_images_count:] += y x = x.reshape(y_shape[0], -1, y_shape[-1]) - del y + del y y = self.norm2(x) diff --git a/models/wan/multitalk/attention.py b/models/wan/multitalk/attention.py index 27d488f6d..669b5c1a7 100644 --- a/models/wan/multitalk/attention.py +++ b/models/wan/multitalk/attention.py @@ -221,13 +221,16 @@ def __init__( self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor: + def forward(self, xlist: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor: N_t, N_h, N_w = shape + x = xlist[0] + xlist.clear() x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) # get q for hidden_state B, N, C = x.shape q = self.q_linear(x) + del x q_shape = (B, N, self.num_heads, self.head_dim) q = q.view(q_shape).permute((0, 2, 1, 3)) @@ -247,9 +250,6 @@ def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=No q = rearrange(q, "B H M K -> B M H K") encoder_k = rearrange(encoder_k, "B H M K -> B M H K") encoder_v = rearrange(encoder_v, "B H M K -> B M H K") - - attn_bias = None - # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,) qkv_list = [q, encoder_k, encoder_v] q = encoder_k = encoder_v = None x = pay_attention(qkv_list) @@ -302,7 +302,7 @@ def __init__( self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim) def forward(self, - x: torch.Tensor, + xlist: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, x_ref_attn_map=None, @@ -310,14 +310,17 @@ def forward(self, encoder_hidden_states = encoder_hidden_states.squeeze(0) if x_ref_attn_map == None: - return super().forward(x, encoder_hidden_states, shape) + return super().forward(xlist, encoder_hidden_states, shape) N_t, _, _ = shape + x = xlist[0] + xlist.clear() x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) # get q for hidden_state B, N, C = x.shape q = self.q_linear(x) + del x q_shape = (B, N, self.num_heads, self.head_dim) q = q.view(q_shape).permute((0, 2, 1, 3)) @@ -339,7 +342,9 @@ def forward(self, normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) - q = self.rope_1d(q, normalized_pos) + qlist = [q] + del q + q = self.rope_1d(qlist, normalized_pos, "q") q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) _, N_a, _ = encoder_hidden_states.shape @@ -347,7 +352,7 @@ def forward(self, encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim) encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4)) encoder_k, encoder_v = encoder_kv.unbind(0) - + del encoder_kv if self.qk_norm: encoder_k = self.add_k_norm(encoder_k) @@ -356,13 +361,14 @@ def forward(self, per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2 encoder_pos = torch.concat([per_frame]*N_t, dim=0) encoder_k = 
rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) - encoder_k = self.rope_1d(encoder_k, encoder_pos) + enclist = [encoder_k] + del encoder_k + encoder_k = self.rope_1d(enclist, encoder_pos, "encoder_k") encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) q = rearrange(q, "B H M K -> B M H K") encoder_k = rearrange(encoder_k, "B H M K -> B M H K") encoder_v = rearrange(encoder_v, "B H M K -> B M H K") - # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,) qkv_list = [q, encoder_k, encoder_v] q = encoder_k = encoder_v = None x = pay_attention(qkv_list) diff --git a/models/wan/multitalk/multitalk_utils.py b/models/wan/multitalk/multitalk_utils.py index 6e2b2c307..7851c452b 100644 --- a/models/wan/multitalk/multitalk_utils.py +++ b/models/wan/multitalk/multitalk_utils.py @@ -16,7 +16,7 @@ import binascii import os.path as osp from skimage import color - +from mmgp.offload import get_cache, clear_caches VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") ASPECT_RATIO_627 = { @@ -73,42 +73,70 @@ def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): # @torch.compile -def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None): - - ref_k = ref_k.to(visual_q.dtype).to(visual_q.device) +def calculate_x_ref_attn_map_per_head(visual_q, ref_k, ref_target_masks, ref_images_count, attn_bias=None): + dtype = visual_q.dtype + ref_k = ref_k.to(dtype).to(visual_q.device) scale = 1.0 / visual_q.shape[-1] ** 0.5 visual_q = visual_q * scale visual_q = visual_q.transpose(1, 2) ref_k = ref_k.transpose(1, 2) - attn = visual_q @ ref_k.transpose(-2, -1) + visual_q_shape = visual_q.shape + visual_q = visual_q.view(-1, visual_q_shape[-1] ) + number_chunks = visual_q_shape[-2]*ref_k.shape[-2] / 53090100 * 2 + chunk_size = int(visual_q_shape[-2] / number_chunks) + chunks =torch.split(visual_q, chunk_size) + maps_lists = [ [] for _ in ref_target_masks] + for q_chunk in chunks: + attn = q_chunk @ ref_k.transpose(-2, -1) + x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens + del attn + ref_target_masks = ref_target_masks.to(dtype) + x_ref_attn_map_source = x_ref_attn_map_source.to(dtype) + + for class_idx, ref_target_mask in enumerate(ref_target_masks): + ref_target_mask = ref_target_mask[None, None, None, ...] 
+ x_ref_attnmap = x_ref_attn_map_source * ref_target_mask + x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens + maps_lists[class_idx].append(x_ref_attnmap) + + del x_ref_attn_map_source - if attn_bias is not None: attn += attn_bias + x_ref_attn_maps = [] + for class_idx, maps_list in enumerate(maps_lists): + attn_map_fuse = torch.concat(maps_list, dim= -1) + attn_map_fuse = attn_map_fuse.view(1, visual_q_shape[1], -1).squeeze(1) + x_ref_attn_maps.append( attn_map_fuse ) - x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens + return torch.concat(x_ref_attn_maps, dim=0) + +def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count): + dtype = visual_q.dtype + ref_k = ref_k.to(dtype).to(visual_q.device) + scale = 1.0 / visual_q.shape[-1] ** 0.5 + visual_q = visual_q * scale + visual_q = visual_q.transpose(1, 2) + ref_k = ref_k.transpose(1, 2) + attn = visual_q @ ref_k.transpose(-2, -1) + + x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens + del attn x_ref_attn_maps = [] - ref_target_masks = ref_target_masks.to(visual_q.dtype) - x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype) + ref_target_masks = ref_target_masks.to(dtype) + x_ref_attn_map_source = x_ref_attn_map_source.to(dtype) for class_idx, ref_target_mask in enumerate(ref_target_masks): ref_target_mask = ref_target_mask[None, None, None, ...] x_ref_attnmap = x_ref_attn_map_source * ref_target_mask x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H - - if mode == 'mean': - x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens - elif mode == 'max': - x_ref_attnmap = x_ref_attnmap.max(-1) # B, x_seqlens - + x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens (mean of heads) x_ref_attn_maps.append(x_ref_attnmap) - del attn del x_ref_attn_map_source return torch.concat(x_ref_attn_maps, dim=0) - def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0): """Args: query (torch.tensor): B M H K @@ -120,6 +148,11 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli N_t, N_h, N_w = shape x_seqlens = N_h * N_w + if x_seqlens <= 1508: + split_num = 10 # 540p + else: + split_num = 20 if x_seqlens <= 3600 else 40 # 720p / 1080p + ref_k = ref_k[:, :x_seqlens] if ref_images_count > 0 : visual_q_shape = visual_q.shape @@ -133,9 +166,14 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli split_chunk = heads // split_num - for i in range(split_num): - x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count) - x_ref_attn_maps += x_ref_attn_maps_perhead + if split_chunk == 1: + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map_per_head(visual_q[:, :, i:(i+1), :], ref_k[:, :, i:(i+1), :], ref_target_masks, ref_images_count) + x_ref_attn_maps += x_ref_attn_maps_perhead + else: + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count) + x_ref_attn_maps += x_ref_attn_maps_perhead x_ref_attn_maps /= split_num return x_ref_attn_maps @@ -158,7 +196,6 @@ def 
__init__(self, self.base = 10000 - @lru_cache(maxsize=32) def precompute_freqs_cis_1d(self, pos_indices): freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) @@ -167,7 +204,7 @@ def precompute_freqs_cis_1d(self, pos_indices): freqs = repeat(freqs, "... n -> ... (n r)", r=2) return freqs - def forward(self, x, pos_indices): + def forward(self, qlist, pos_indices, cache_entry = None): """1D RoPE. Args: @@ -176,16 +213,26 @@ def forward(self, x, pos_indices): Returns: query with the same shape as input. """ - freqs_cis = self.precompute_freqs_cis_1d(pos_indices) - - x_ = x.float() - - freqs_cis = freqs_cis.float().to(x.device) - cos, sin = freqs_cis.cos(), freqs_cis.sin() - cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') - x_ = (x_ * cos) + (rotate_half(x_) * sin) - - return x_.type_as(x) + xq= qlist[0] + qlist.clear() + cache = get_cache("multitalk_rope") + freqs_cis= cache.get(cache_entry, None) + if freqs_cis is None: + freqs_cis = cache[cache_entry] = self.precompute_freqs_cis_1d(pos_indices) + cos, sin = freqs_cis.cos().unsqueeze(0).unsqueeze(0), freqs_cis.sin().unsqueeze(0).unsqueeze(0) + # cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + # real * cos - imag * sin + # imag * cos + real * sin + xq_dtype = xq.dtype + xq_out = xq.to(torch.float) + xq = None + xq_rot = rotate_half(xq_out) + xq_out *= cos + xq_rot *= sin + xq_out += xq_rot + del xq_rot + xq_out = xq_out.to(xq_dtype) + return xq_out diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index c3ad01294..4d59aa354 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -1,5 +1,6 @@ import torch import numpy as np +import gradio as gr def test_class_i2v(base_model_type): return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ] @@ -116,6 +117,11 @@ def query_model_def(base_model_type, model_def): extra_model_def["no_background_removal"] = True # extra_model_def["at_least_one_image_ref_needed"] = True + + if base_model_type in ["phantom_1.3B", "phantom_14B"]: + extra_model_def["one_image_ref_needed"] = True + + return extra_model_def @staticmethod @@ -235,6 +241,14 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): if "I" in video_prompt_type: video_prompt_type = video_prompt_type.replace("KI", "QKI") ui_defaults["video_prompt_type"] = video_prompt_type + + if settings_version < 2.28: + if base_model_type in "infinitetalk": + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if "U" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("U", "RU") + ui_defaults["video_prompt_type"] = video_prompt_type + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ @@ -309,3 +323,11 @@ def validate_generative_settings(base_model_type, model_def, inputs): if ("V" in image_prompt_type or "L" in image_prompt_type) and image_refs is None: video_prompt_type = video_prompt_type.replace("I", "").replace("K","") inputs["video_prompt_type"] = video_prompt_type + + + if base_model_type in ["vace_standin_14B"]: + image_refs = inputs["image_refs"] + video_prompt_type = inputs["video_prompt_type"] + if image_refs is not None and len(image_refs) == 1 and "K" in video_prompt_type: + gr.Info("Warning, Ref Image for Standin Missing: if 'Landscape and then People or Objects' is selected beside the Landscape Image 
Ref there should be another Image Ref that contains a Face.") + diff --git a/requirements.txt b/requirements.txt index a2e1fde85..7704d6db9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,8 @@ soundfile mutagen pyloudnorm librosa==0.11.0 - +speechbrain==1.0.3 + # UI & interaction gradio==5.23.0 dashscope @@ -43,7 +44,7 @@ pydantic==2.10.6 # Math & modeling torchdiffeq>=0.2.5 tensordict>=0.6.1 -mmgp==3.5.10 +mmgp==3.5.11 peft==0.15.0 matplotlib diff --git a/shared/attention.py b/shared/attention.py index a95332dc5..cc6ece02d 100644 --- a/shared/attention.py +++ b/shared/attention.py @@ -3,6 +3,7 @@ from importlib.metadata import version from mmgp import offload import torch.nn.functional as F +import warnings major, minor = torch.cuda.get_device_capability(None) bfloat16_supported = major >= 8 @@ -42,34 +43,51 @@ def sageattn_varlen_wrapper( sageattn_varlen_wrapper = None -import warnings - try: - from sageattention import sageattn - from .sage2_core import sageattn as alt_sageattn, is_sage2_supported + from .sage2_core import sageattn as sageattn2, is_sage2_supported sage2_supported = is_sage2_supported() except ImportError: - sageattn = None - alt_sageattn = None + sageattn2 = None sage2_supported = False -# @torch.compiler.disable() -def sageattn_wrapper( +@torch.compiler.disable() +def sageattn2_wrapper( qkv_list, attention_length ): q,k, v = qkv_list - if True: - qkv_list = [q,k,v] - del q, k ,v - o = alt_sageattn(qkv_list, tensor_layout="NHD") - else: - o = sageattn(q, k, v, tensor_layout="NHD") - del q, k ,v + qkv_list = [q,k,v] + del q, k ,v + o = sageattn2(qkv_list, tensor_layout="NHD") + qkv_list.clear() + + return o +try: + from sageattn import sageattn_blackwell as sageattn3 +except ImportError: + sageattn3 = None + +@torch.compiler.disable() +def sageattn3_wrapper( + qkv_list, + attention_length + ): + q,k, v = qkv_list + # qkv_list = [q,k,v] + # del q, k ,v + # o = sageattn3(qkv_list, tensor_layout="NHD") + q = q.transpose(1,2) + k = k.transpose(1,2) + v = v.transpose(1,2) + o = sageattn3(q, k, v) + o = o.transpose(1,2) qkv_list.clear() return o + + + # try: # if True: # from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda @@ -94,7 +112,7 @@ def sageattn_wrapper( # return o # except ImportError: -# sageattn = sageattn_qk_int8_pv_fp8_window_cuda +# sageattn2 = sageattn_qk_int8_pv_fp8_window_cuda @torch.compiler.disable() def sdpa_wrapper( @@ -124,21 +142,28 @@ def get_attention_modes(): ret.append("xformers") if sageattn_varlen_wrapper != None: ret.append("sage") - if sageattn != None and version("sageattention").startswith("2") : + if sageattn2 != None and version("sageattention").startswith("2") : ret.append("sage2") + if sageattn3 != None: # and version("sageattention").startswith("3") : + ret.append("sage3") return ret def get_supported_attention_modes(): ret = get_attention_modes() + major, minor = torch.cuda.get_device_capability() + if major < 10: + if "sage3" in ret: + ret.remove("sage3") + if not sage2_supported: if "sage2" in ret: ret.remove("sage2") - major, minor = torch.cuda.get_device_capability() if major < 7: if "sage" in ret: ret.remove("sage") + return ret __all__ = [ @@ -201,7 +226,7 @@ def pay_attention( from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG - if b > 1 and k_lens != None and attn in ("sage2", "sdpa"): + if b > 1 and k_lens != None and attn in ("sage2", "sage3", "sdpa"): assert attention_mask == None # Poor's man var k len attention assert q_lens == 
None @@ -234,7 +259,7 @@ def pay_attention( q_chunks, k_chunks, v_chunks = None, None, None o = torch.cat(o, dim = 0) return o - elif (q_lens != None or k_lens != None) and attn in ("sage2", "sdpa"): + elif (q_lens != None or k_lens != None) and attn in ("sage2", "sage3", "sdpa"): assert b == 1 szq = q_lens[0].item() if q_lens != None else lq szk = k_lens[0].item() if k_lens != None else lk @@ -284,13 +309,19 @@ def pay_attention( max_seqlen_q=lq, max_seqlen_kv=lk, ).unflatten(0, (b, lq)) + elif attn=="sage3": + import math + if cross_attn or True: + qkv_list = [q,k,v] + del q,k,v + x = sageattn3_wrapper(qkv_list, lq) elif attn=="sage2": import math if cross_attn or True: qkv_list = [q,k,v] del q,k,v - x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0) + x = sageattn2_wrapper(qkv_list, lq) #.unsqueeze(0) # else: # layer = offload.shared_state["layer"] # embed_sizes = offload.shared_state["embed_sizes"] diff --git a/shared/gradio/gallery.py b/shared/gradio/gallery.py new file mode 100644 index 000000000..6bd86a8d3 --- /dev/null +++ b/shared/gradio/gallery.py @@ -0,0 +1,496 @@ +from __future__ import annotations +import os, io, tempfile, mimetypes +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal + +import gradio as gr +import PIL +from PIL import Image as PILImage + +FilePath = str +ImageLike = Union["PIL.Image.Image", Any] + +IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"} +VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v", ".mpeg", ".mpg", ".ogv"} + +def get_state(state): + return state if isinstance(state, dict) else state.value + +def get_list( objs): + if objs is None: + return [] + return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs] + +class AdvancedMediaGallery: + def __init__( + self, + label: str = "Media", + *, + media_mode: Literal["image", "video"] = "image", + height = None, + columns: Union[int, Tuple[int, ...]] = 6, + show_label: bool = True, + initial: Optional[Sequence[Union[FilePath, ImageLike]]] = None, + elem_id: Optional[str] = None, + elem_classes: Optional[Sequence[str]] = ("adv-media-gallery",), + accept_filter: bool = True, # restrict Add-button dialog to allowed extensions + single_image_mode: bool = False, # start in single-image mode (Add replaces) + ): + assert media_mode in ("image", "video") + self.label = label + self.media_mode = media_mode + self.height = height + self.columns = columns + self.show_label = show_label + self.elem_id = elem_id + self.elem_classes = list(elem_classes) if elem_classes else None + self.accept_filter = accept_filter + + items = self._normalize_initial(initial or [], media_mode) + + # Components (filled on mount) + self.container: Optional[gr.Column] = None + self.gallery: Optional[gr.Gallery] = None + self.upload_btn: Optional[gr.UploadButton] = None + self.btn_remove: Optional[gr.Button] = None + self.btn_left: Optional[gr.Button] = None + self.btn_right: Optional[gr.Button] = None + self.btn_clear: Optional[gr.Button] = None + + # Single dict state + self.state: Optional[gr.State] = None + self._initial_state: Dict[str, Any] = { + "items": items, + "selected": (len(items) - 1) if items else None, + "single": bool(single_image_mode), + "mode": self.media_mode, + } + + # ---------------- helpers ---------------- + + def _normalize_initial(self, items: Sequence[Union[FilePath, ImageLike]], mode: str) -> List[Any]: + out: List[Any] = [] + if mode == "image": + for it in items: + p = self._ensure_image_item(it) + if p is not 
None: + out.append(p) + else: + for it in items: + if isinstance(item, tuple): item = item[0] + if isinstance(it, str) and self._is_video_path(it): + out.append(os.path.abspath(it)) + return out + + def _ensure_image_item(self, item: Union[FilePath, ImageLike]) -> Optional[Any]: + # Accept a path to an image, or a PIL.Image/np.ndarray -> save temp PNG and return its path + if isinstance(item, tuple): item = item[0] + if isinstance(item, str): + return os.path.abspath(item) if self._is_image_path(item) else None + if PILImage is None: + return None + try: + if isinstance(item, PILImage.Image): + img = item + else: + import numpy as np # type: ignore + if isinstance(item, np.ndarray): + img = PILImage.fromarray(item) + elif hasattr(item, "read"): + data = item.read() + img = PILImage.open(io.BytesIO(data)).convert("RGBA") + else: + return None + tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + img.save(tmp.name) + return tmp.name + except Exception: + return None + + @staticmethod + def _extract_path(obj: Any) -> Optional[str]: + # Try to get a filesystem path (for mode filtering); otherwise None. + if isinstance(obj, str): + return obj + try: + import pathlib + if isinstance(obj, pathlib.Path): # type: ignore + return str(obj) + except Exception: + pass + if isinstance(obj, dict): + return obj.get("path") or obj.get("name") + for attr in ("path", "name"): + if hasattr(obj, attr): + try: + val = getattr(obj, attr) + if isinstance(val, str): + return val + except Exception: + pass + return None + + @staticmethod + def _is_image_path(p: str) -> bool: + ext = os.path.splitext(p)[1].lower() + if ext in IMAGE_EXTS: + return True + mt, _ = mimetypes.guess_type(p) + return bool(mt and mt.startswith("image/")) + + @staticmethod + def _is_video_path(p: str) -> bool: + ext = os.path.splitext(p)[1].lower() + if ext in VIDEO_EXTS: + return True + mt, _ = mimetypes.guess_type(p) + return bool(mt and mt.startswith("video/")) + + def _filter_items_by_mode(self, items: List[Any]) -> List[Any]: + # Enforce image-only or video-only collection regardless of how files were added. + out: List[Any] = [] + if self.media_mode == "image": + for it in items: + p = self._extract_path(it) + if p is None: + # No path: likely an image object added programmatically => keep + out.append(it) + elif self._is_image_path(p): + out.append(os.path.abspath(p)) + else: + for it in items: + p = self._extract_path(it) + if p is not None and self._is_video_path(p): + out.append(os.path.abspath(p)) + return out + + @staticmethod + def _concat_and_optionally_dedupe(cur: List[Any], add: List[Any]) -> List[Any]: + # Keep it simple: dedupe by path when available, else allow duplicates. 
+ seen_paths = set() + def key(x: Any) -> Optional[str]: + if isinstance(x, str): return os.path.abspath(x) + try: + import pathlib + if isinstance(x, pathlib.Path): # type: ignore + return os.path.abspath(str(x)) + except Exception: + pass + if isinstance(x, dict): + p = x.get("path") or x.get("name") + return os.path.abspath(p) if isinstance(p, str) else None + for attr in ("path", "name"): + if hasattr(x, attr): + try: + v = getattr(x, attr) + return os.path.abspath(v) if isinstance(v, str) else None + except Exception: + pass + return None + + out: List[Any] = [] + for lst in (cur, add): + for it in lst: + k = key(it) + if k is None or k not in seen_paths: + out.append(it) + if k is not None: + seen_paths.add(k) + return out + + @staticmethod + def _paths_from_payload(payload: Any) -> List[Any]: + # Return as raw objects (paths/dicts/UploadedFile) to feed Gallery directly. + if payload is None: + return [] + if isinstance(payload, (list, tuple, set)): + return list(payload) + return [payload] + + # ---------------- event handlers ---------------- + + def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) : + # Mirror the selected index into state and the gallery (server-side selected_index) + idx = None + if evt is not None and hasattr(evt, "index"): + ix = evt.index + if isinstance(ix, int): + idx = ix + elif isinstance(ix, (tuple, list)) and ix and isinstance(ix[0], int): + if isinstance(self.columns, int) and len(ix) >= 2: + idx = ix[0] * max(1, int(self.columns)) + ix[1] + else: + idx = ix[0] + st = get_state(state) + n = len(get_list(gallery)) + sel = idx if (idx is not None and 0 <= idx < n) else None + st["selected"] = sel + return gr.update(selected_index=sel), st + + def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) : + # Fires when users add/drag/drop/delete via the Gallery itself. + items_filtered = self._filter_items_by_mode(list(value or [])) + st = get_state(state) + st["items"] = items_filtered + # Keep selection if still valid, else default to last + old_sel = st.get("selected", None) + if old_sel is None or not (0 <= old_sel < len(items_filtered)): + new_sel = (len(items_filtered) - 1) if items_filtered else None + else: + new_sel = old_sel + st["selected"] = new_sel + return gr.update(value=items_filtered, selected_index=new_sel), st + + + def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery): + """ + Insert added items right AFTER the currently selected index. + Keeps the same ordering as chosen in the file picker, dedupes by path, + and re-selects the last inserted item. 
+ """ + # New items (respect image/video mode) + new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload)) + + st = get_state(state) + cur: List[Any] = get_list(gallery) + sel = st.get("selected", None) + if sel is None: + sel = (len(cur) -1) if len(cur)>0 else 0 + single = bool(st.get("single", False)) + + # Nothing to add: keep as-is + if not new_items: + return gr.update(value=cur, selected_index=st.get("selected")), st + + # Single-image mode: replace + if single: + st["items"] = [new_items[-1]] + st["selected"] = 0 + return gr.update(value=st["items"], selected_index=0), st + + # ---------- helpers ---------- + def key_of(it: Any) -> Optional[str]: + # Prefer class helper if present + if hasattr(self, "_extract_path"): + p = self._extract_path(it) # type: ignore + else: + p = it if isinstance(it, str) else None + if p is None and isinstance(it, dict): + p = it.get("path") or it.get("name") + if p is None and hasattr(it, "path"): + try: p = getattr(it, "path") + except Exception: p = None + if p is None and hasattr(it, "name"): + try: p = getattr(it, "name") + except Exception: p = None + return os.path.abspath(p) if isinstance(p, str) else None + + # Dedupe the incoming batch by path, preserve order + seen_new = set() + incoming: List[Any] = [] + for it in new_items: + k = key_of(it) + if k is None or k not in seen_new: + incoming.append(it) + if k is not None: + seen_new.add(k) + + # Remove any existing occurrences of the incoming items from current list, + # BUT keep the currently selected item even if it's also in incoming. + cur_clean: List[Any] = [] + # sel_item = cur[sel] if (sel is not None and 0 <= sel < len(cur)) else None + # for idx, it in enumerate(cur): + # k = key_of(it) + # if it is sel_item: + # cur_clean.append(it) + # continue + # if k is not None and k in seen_new: + # continue # drop duplicate; we'll reinsert at the target spot + # cur_clean.append(it) + + # # Compute insertion position: right AFTER the (possibly shifted) selected item + # if sel_item is not None: + # # find sel_item's new index in cur_clean + # try: + # pos_sel = cur_clean.index(sel_item) + # except ValueError: + # # Shouldn't happen, but fall back to end + # pos_sel = len(cur_clean) - 1 + # insert_pos = pos_sel + 1 + # else: + # insert_pos = len(cur_clean) # no selection -> append at end + insert_pos = min(sel, len(cur) -1) + cur_clean = cur + # Build final list and selection + merged = cur_clean[:insert_pos+1] + incoming + cur_clean[insert_pos+1:] + new_sel = insert_pos + len(incoming) # select the last inserted item + + st["items"] = merged + st["selected"] = new_sel + return gr.update(value=merged, selected_index=new_sel), st + + def _on_remove(self, state: Dict[str, Any], gallery) : + st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None) + if sel is None or not (0 <= sel < len(items)): + return gr.update(value=items, selected_index=st.get("selected")), st + items.pop(sel) + if not items: + st["items"] = []; st["selected"] = None + return gr.update(value=[], selected_index=None), st + new_sel = min(sel, len(items) - 1) + st["items"] = items; st["selected"] = new_sel + return gr.update(value=items, selected_index=new_sel), st + + def _on_move(self, delta: int, state: Dict[str, Any], gallery) : + st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None) + if sel is None or not (0 <= sel < len(items)): + return gr.update(value=items, selected_index=sel), st + j = sel + delta + if j < 0 or j >= 
len(items): + return gr.update(value=items, selected_index=sel), st + items[sel], items[j] = items[j], items[sel] + st["items"] = items; st["selected"] = j + return gr.update(value=items, selected_index=j), st + + def _on_clear(self, state: Dict[str, Any]) : + st = {"items": [], "selected": None, "single": state.get("single", False), "mode": self.media_mode} + return gr.update(value=[], selected_index=None), st + + def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) : + st = get_state(state); st["single"] = bool(to_single) + items: List[Any] = list(st["items"]); sel = st.get("selected", None) + if st["single"]: + keep = items[sel] if (sel is not None and 0 <= sel < len(items)) else (items[-1] if items else None) + items = [keep] if keep is not None else [] + sel = 0 if items else None + st["items"] = items; st["selected"] = sel + + upload_update = gr.update(file_count=("single" if st["single"] else "multiple")) + left_update = gr.update(visible=not st["single"]) + right_update = gr.update(visible=not st["single"]) + clear_update = gr.update(visible=not st["single"]) + gallery_update= gr.update(value=items, selected_index=sel) + + return upload_update, left_update, right_update, clear_update, gallery_update, st + + # ---------------- build & wire ---------------- + + def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False): + if parent is not None: + with parent: + col = self._build_ui() + else: + col = self._build_ui() + if not update_form: + self._wire_events() + return col + + def _build_ui(self) -> gr.Column: + with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col: + self.container = col + + self.state = gr.State(dict(self._initial_state)) + + self.gallery = gr.Gallery( + label=self.label, + value=self._initial_state["items"], + height=self.height, + columns=self.columns, + show_label=self.show_label, + preview= True, + selected_index=self._initial_state["selected"], # server-side selection + ) + + # One-line controls + exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None + with gr.Row(equal_height=True, elem_classes=["amg-controls"]): + self.upload_btn = gr.UploadButton( + "Set" if self._initial_state["single"] else "Add", + file_types=exts, + file_count=("single" if self._initial_state["single"] else "multiple"), + variant="primary", + size="sm", + min_width=1, + ) + self.btn_remove = gr.Button("Remove", size="sm", min_width=1) + self.btn_left = gr.Button("◀ Left", size="sm", visible=not self._initial_state["single"], min_width=1) + self.btn_right = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1) + self.btn_clear = gr.Button("Clear", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1) + + return col + + def _wire_events(self): + # Selection: mirror into state and keep gallery.selected_index in sync + self.gallery.select( + self._on_select, + inputs=[self.state, self.gallery], + outputs=[self.gallery, self.state], + ) + + # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.) 
+ self.gallery.change( + self._on_gallery_change, + inputs=[self.gallery, self.state], + outputs=[self.gallery, self.state], + ) + + # Add via UploadButton + self.upload_btn.upload( + self._on_add, + inputs=[self.upload_btn, self.state, self.gallery], + outputs=[self.gallery, self.state], + ) + + # Remove selected + self.btn_remove.click( + self._on_remove, + inputs=[self.state, self.gallery], + outputs=[self.gallery, self.state], + ) + + # Reorder using selected index, keep same item selected + self.btn_left.click( + lambda st, gallery: self._on_move(-1, st, gallery), + inputs=[self.state, self.gallery], + outputs=[self.gallery, self.state], + ) + self.btn_right.click( + lambda st, gallery: self._on_move(+1, st, gallery), + inputs=[self.state, self.gallery], + outputs=[self.gallery, self.state], + ) + + # Clear all + self.btn_clear.click( + self._on_clear, + inputs=[self.state], + outputs=[self.gallery, self.state], + ) + + # ---------------- public API ---------------- + + def set_one_image_mode(self, enabled: bool = True): + """Toggle single-image mode at runtime.""" + return ( + self._on_toggle_single, + [gr.State(enabled), self.state], + [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state], + ) + + def get_toggable_elements(self): + return [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state] + +# import gradio as gr + +# with gr.Blocks() as demo: +# amg = AdvancedMediaGallery(media_mode="image", height=190, columns=8) +# amg.mount() +# g = amg.gallery +# # buttons to switch modes live (optional) +# def process(g): +# pass +# with gr.Row(): +# gr.Button("toto").click(process, g) +# gr.Button("ONE image").click(*amg.set_one_image_mode(True)) +# gr.Button("MULTI image").click(*amg.set_one_image_mode(False)) + +# demo.launch() diff --git a/shared/utils/qwen_vl_utils.py b/shared/utils/qwen_vl_utils.py index 3c682e6ad..3fe02cc1a 100644 --- a/shared/utils/qwen_vl_utils.py +++ b/shared/utils/qwen_vl_utils.py @@ -9,7 +9,6 @@ import sys import time import warnings -from functools import lru_cache from io import BytesIO import requests @@ -257,7 +256,6 @@ def _read_video_decord(ele: dict,) -> torch.Tensor: FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) -@lru_cache(maxsize=1) def get_video_reader_backend() -> str: if FORCE_QWENVL_VIDEO_READER is not None: video_reader_backend = FORCE_QWENVL_VIDEO_READER diff --git a/wgp.py b/wgp.py index a151252ba..c757ad87f 100644 --- a/wgp.py +++ b/wgp.py @@ -49,6 +49,7 @@ from preprocessing.matanyone import app as matanyone_app from tqdm import tqdm import requests +from shared.gradio.gallery import AdvancedMediaGallery # import torch._dynamo as dynamo # dynamo.config.recompile_limit = 2000 # default is 256 @@ -58,9 +59,9 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.5.10" -WanGP_version = "8.21" -settings_version = 2.27 +target_mmgp_version = "3.5.11" +WanGP_version = "8.3" +settings_version = 2.28 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -466,7 +467,7 @@ def process_prompt_and_add_tasks(state, model_choice): image_mask = None if "G" in video_prompt_type: - gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start a Step no {int(num_inference_steps * (1. 
- denoising_strength))} ") + gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(num_inference_steps * (1. - denoising_strength))} ") else: denoising_strength = 1.0 if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]: @@ -4733,46 +4734,61 @@ def remove_temp_filenames(temp_filenames_list): # special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding) audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length) + if video_guide is not None: keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ] - if infinitetalk and video_guide is not None: - src_image = get_video_frame(video_guide, aligned_guide_start_frame-1, return_last_if_missing = True, return_PIL = True) - new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size) - src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) - refresh_preview["video_guide"] = src_image - src_video = convert_image_to_tensor(src_image).unsqueeze(1) - if sample_fit_canvas != None: - image_size = src_video.shape[-2:] - sample_fit_canvas = None - if ltxv and video_guide is not None: - preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") - status_info = "Extracting " + processes_names[preprocess_type] - send_cmd("progress", [0, get_latest_status(state, status_info)]) - # start one frame ealier to facilitate latents merging later - src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size ) - if src_video != None: - src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ] - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) - src_video = src_video.permute(3, 0, 1, 2) - src_video = src_video.float().div_(127.5).sub_(1.) 
# c, f, h, w - if sample_fit_canvas != None: - image_size = src_video.shape[-2:] - sample_fit_canvas = None - if t2v and "G" in video_prompt_type: - video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps) - if video_guide_processed == None: - src_video = pre_video_guide - else: - if sample_fit_canvas != None: - image_size = video_guide_processed.shape[-3: -1] + if ltxv: + preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") + status_info = "Extracting " + processes_names[preprocess_type] + send_cmd("progress", [0, get_latest_status(state, status_info)]) + # start one frame ealier to facilitate latents merging later + src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size ) + if src_video != None: + src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ] + refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) + src_video = src_video.permute(3, 0, 1, 2) + src_video = src_video.float().div_(127.5).sub_(1.) # c, f, h, w + if sample_fit_canvas != None: + image_size = src_video.shape[-2:] + sample_fit_canvas = None + + elif hunyuan_custom_edit: + if "P" in video_prompt_type: + progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")] + else: + progress_args = [0, get_latest_status(state,"Extracting Video and Mask")] + + send_cmd("progress", progress_args) + src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0) + refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) + if src_mask != None: + refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy()) + + elif "R" in video_prompt_type: # sparse video to video + src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True) + new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size) + src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + refresh_preview["video_guide"] = src_image + src_video = convert_image_to_tensor(src_image).unsqueeze(1) + if sample_fit_canvas != None: + image_size = src_video.shape[-2:] sample_fit_canvas = None - src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2) - if pre_video_guide != None: - src_video = torch.cat( [pre_video_guide, src_video], dim=1) + + elif "G" in video_prompt_type: # video to video + video_guide_processed = preprocess_video(width = 
image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps) + if video_guide_processed is None: + src_video = pre_video_guide + else: + if sample_fit_canvas != None: + image_size = video_guide_processed.shape[-3: -1] + sample_fit_canvas = None + src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2) + if pre_video_guide != None: + src_video = torch.cat( [pre_video_guide, src_video], dim=1) if vace : image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications @@ -4834,17 +4850,8 @@ def remove_temp_filenames(temp_filenames_list): if sample_fit_canvas != None: image_size = src_video[0].shape[-2:] sample_fit_canvas = None - elif hunyuan_custom_edit: - if "P" in video_prompt_type: - progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")] - else: - progress_args = [0, get_latest_status(state,"Extracting Video and Mask")] - send_cmd("progress", progress_args) - src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0) - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) - if src_mask != None: - refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy()) + if len(refresh_preview) > 0: new_inputs= locals() new_inputs.update(refresh_preview) @@ -6013,10 +6020,16 @@ def video_to_source_video(state, input_file_list, choice): def image_to_ref_image_add(state, input_file_list, choice, target, target_name): file_list, file_settings_list = get_file_list(state, input_file_list) if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update() - gr.Info(f"Selected Image was added to {target_name}") - if target == None: - target =[] - target.append( file_list[choice]) + model_type = state["model_type"] + model_def = get_model_def(model_type) + if model_def.get("one_image_ref_needed", False): + gr.Info(f"Selected Image was set to {target_name}") + target =[file_list[choice]] + else: + gr.Info(f"Selected Image was added to {target_name}") + if target == None: + target =[] + target.append( file_list[choice]) return target def image_to_ref_image_set(state, input_file_list, choice, target, target_name): @@ -6229,6 +6242,7 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw if not "WanGP" in configs.get("type", ""): configs = None except: configs = None + if configs is None: return None, False current_model_filename = state["model_filename"] @@ -6615,11 +6629,12 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), 
gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible) def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt): - video_prompt_type = del_in_sequence(video_prompt_type, "UVQKI") + video_prompt_type = del_in_sequence(video_prompt_type, "RGUVQKI") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt) control_video_visible = "V" in video_prompt_type ref_images_visible = "I" in video_prompt_type - return video_prompt_type, gr.update(visible = control_video_visible), gr.update(visible = ref_images_visible ) + denoising_strength_visible = "G" in video_prompt_type + return video_prompt_type, gr.update(visible = control_video_visible), gr.update(visible = ref_images_visible ), gr.update(visible = denoising_strength_visible ) # def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide): # video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0] @@ -6996,6 +7011,12 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non v2i_switch_supported = (vace or t2v or standin) and not image_outputs ti2v_2_2 = base_model_type in ["ti2v_2_2"] + def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ): + with gr.Row(visible = visible) as gallery_row: + gallery_amg = AdvancedMediaGallery(media_mode="image", height=None, columns=4, label=label, initial = value , single_image_mode = single_image_mode ) + gallery_amg.mount(update_form=update_form) + return gallery_row, gallery_amg.gallery, [gallery_row] + gallery_amg.get_toggable_elements() + image_mode_value = ui_defaults.get("image_mode", 1 if image_outputs else 0 ) if not v2i_switch_supported and not image_outputs: image_mode_value = 0 @@ -7009,15 +7030,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Tab("Text to Image", id = "t2i", elem_classes="compact_tab"): pass - with gr.Column(visible= test_class_i2v(model_type) or hunyuan_i2v or diffusion_forcing or ltxv or recammaster or vace or ti2v_2_2) as image_prompt_column: if vace or infinitetalk: image_prompt_type_value= ui_defaults.get("image_prompt_type","") image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= not image_outputs , scale= 3) - image_start = gr.Gallery(visible = False) - image_end = gr.Gallery(visible = False) + image_start_row, image_start, image_start_extra = get_image_gallery(visible = False ) + image_end_row, image_end, image_end_extra = get_image_gallery(visible = False ) video_source = gr.Video(label= "Video Source", visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None)) model_mode = gr.Dropdown(visible = False) keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VLG"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) @@ -7034,13 +7054,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non image_prompt_type_choices += [("Continue Video", "V")] image_prompt_type = gr.Radio( image_prompt_type_choices, 
value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3) - # image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) - image_start = gr.Gallery(preview= True, - label="Images as starting points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) - image_end = gr.Gallery(preview= True, - label="Images as ending points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) + image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) + image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value ) video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) if not diffusion_forcing: model_mode = gr.Dropdown( @@ -7061,8 +7076,8 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= "V" in image_prompt_type_value, scale = 2, label= "Truncate Video beyond this number of Frames of Video (empty=Keep All)" ) elif recammaster: image_prompt_type = gr.Radio(choices=[("Source Video", "V")], value="V") - image_start = gr.Gallery(value = None, visible = False) - image_end = gr.Gallery(value = None, visible= False) + image_start_row, image_start, image_start_extra = get_image_gallery(visible = False ) + image_end_row, image_end, image_end_extra = get_image_gallery(visible = False ) video_source = gr.Video(label= "Video Source", visible = True, value= ui_defaults.get("video_source", None),) model_mode = gr.Dropdown( choices=[ @@ -7095,21 +7110,16 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3) any_start_image = True any_end_image = True - image_start = gr.Gallery(preview= True, - label="Images as starting points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) - - image_end = gr.Gallery(preview= True, - label="Images as ending points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) + image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) + image_end_row, image_end, image_end_extra = get_image_gallery(label= 
"Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value ) if hunyuan_i2v: video_source = gr.Video(value=None, visible=False) else: video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) else: image_prompt_type = gr.Radio(choices=[("", "")], value="") - image_start = gr.Gallery(value=None) - image_end = gr.Gallery(value=None) + image_start_row, image_start, image_start_extra = get_image_gallery(visible = False ) + image_end_row, image_end, image_end_extra = get_image_gallery(visible = False ) video_source = gr.Video(value=None, visible=False) model_mode = gr.Dropdown(value=None, visible=False) keep_frames_video_source = gr.Text(visible=False) @@ -7184,12 +7194,14 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non if infinitetalk: video_prompt_type_video_guide_alt = gr.Dropdown( choices=[ - ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "UV"), - ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QUV"), - ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"), ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"), + ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"), + ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"), + ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"), + ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"), + ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"), ], - value=filter_letters(video_prompt_type_value, "UVQKI"), + value=filter_letters(video_prompt_type_value, "RGUVQKI"), label="Video to Video", scale = 3, visible= True, show_label= False, ) else: @@ -7318,11 +7330,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value ) any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or infinitetalk or (flux or qwen) and model_reference_image - image_refs = gr.Gallery(preview= True, label ="Start Image" if hunyuan_video_avatar else "Reference Images" + (" (each Image will start a new Clip)" if infinitetalk else ""), - type ="pil", show_label= True, - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, - value= ui_defaults.get("image_refs", None), - ) + + image_refs_single_image_mode = model_def.get("one_image_ref_needed", False) + image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "") + image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = 
ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode) frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" ) image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs) @@ -7935,7 +7946,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn, NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, - min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] # presets_column, + min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] + image_start_extra + image_end_extra + image_refs_extra # presets_column, if update_form: locals_dict = locals() gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs @@ -7953,11 +7964,11 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ]) audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type]) audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row]) - image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] ) + image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start_row, image_end_row, video_source, keep_frames_video_source] ) # video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand]) - video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col]) + video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, 
frames_positions,video_guide_outpainting_col]) video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand]) - video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs ]) + video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs_row, denoising_strength ]) video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand]) video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type]) multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt]) @@ -8348,6 +8359,7 @@ def check(mode): ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"), ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"), ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"), + ("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3"), ], value= attention_mode, label="Attention Type", @@ -8663,7 +8675,7 @@ def generate_about_tab(): gr.Markdown("- Blackforest Labs for the innovative Flux image generators (https://github.com/black-forest-labs/flux)") gr.Markdown("- Alibaba Qwen Team for their state of the art Qwen Image generators (https://github.com/QwenLM/Qwen-Image)") gr.Markdown("- Lightricks for their super fast LTX Video models (https://github.com/Lightricks/LTX-Video)") - gr.Markdown("- Hugging Face for the providing hosting for the models and developing must have open source libraries such as Tranformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)") + gr.Markdown("- Hugging Face for providing hosting for the models and developing must have open source libraries such as Tranformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)") gr.Markdown("
Huge acknowledgments to these great open source projects used in WanGP:") gr.Markdown("- Rife: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)") gr.Markdown("- DwPose: Open Pose extractor (https://github.com/IDEA-Research/DWPose)") @@ -8672,7 +8684,7 @@ def generate_about_tab(): gr.Markdown("- Pyannote: speaker diarization (https://github.com/pyannote/pyannote-audio)") gr.Markdown("
Special thanks to the following people for their support:") - gr.Markdown("- Cocktail Peanuts : QA and simple installation via Pinokio.computer") + gr.Markdown("- Cocktail Peanuts : QA dpand simple installation via Pinokio.computer") gr.Markdown("- Tophness : created (former) multi tabs and queuing frameworks") gr.Markdown("- AmericanPresidentJimmyCarter : added original support for Skip Layer Guidance") gr.Markdown("- Remade_AI : for their awesome Loras collection") From f2db023a3d314023736d2a5d5cd369b51ff4febd Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 2 Sep 2025 21:03:16 +0200 Subject: [PATCH 007/155] a release without a gaming breaking bug is not a release --- shared/gradio/gallery.py | 1 + 1 file changed, 1 insertion(+) diff --git a/shared/gradio/gallery.py b/shared/gradio/gallery.py index 6bd86a8d3..ef53d4d86 100644 --- a/shared/gradio/gallery.py +++ b/shared/gradio/gallery.py @@ -397,6 +397,7 @@ def _build_ui(self) -> gr.Column: columns=self.columns, show_label=self.show_label, preview= True, + type="pil", selected_index=self._initial_state["selected"], # server-side selection ) From 959ab9e0c13a31cb8973a43c1d7fa96149901319 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 2 Sep 2025 23:13:02 +0200 Subject: [PATCH 008/155] no more pain --- README.md | 3 ++- shared/gradio/gallery.py | 20 +++++++++++++------- wgp.py | 30 ++++++++++++++++++------------ 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/README.md b/README.md index be76cbcc8..eef872837 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : -### September 2 2025: WanGP v8.3 - At last the pain stops +### September 2 2025: WanGP v8.31 - At last the pain stops - This single new feature should give you the strength to face all the potential bugs of this new release: **Images Management (multiple additions or deletions, reordering) for Start Images / End Images / Images References.** @@ -30,6 +30,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models - **Experimental Sage 3 Attention support**: you will need to deserve this one, first you need a Blackwell GPU (RTX50xx), then you will have to compile Sage 3, install it and cross your fingers that there isn't any crash. +*update 8.31: one shouldnt talk about bugs if one dont want to attract bugs* ### August 29 2025: WanGP v8.21 - Here Goes Your Weekend diff --git a/shared/gradio/gallery.py b/shared/gradio/gallery.py index ef53d4d86..2f30ef16e 100644 --- a/shared/gradio/gallery.py +++ b/shared/gradio/gallery.py @@ -224,7 +224,9 @@ def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) : n = len(get_list(gallery)) sel = idx if (idx is not None and 0 <= idx < n) else None st["selected"] = sel - return gr.update(selected_index=sel), st + # return gr.update(selected_index=sel), st + # return gr.update(), st + return st def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) : # Fires when users add/drag/drop/delete via the Gallery itself. 
@@ -238,8 +240,10 @@ def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) : else: new_sel = old_sel st["selected"] = new_sel - return gr.update(value=items_filtered, selected_index=new_sel), st + # return gr.update(value=items_filtered, selected_index=new_sel), st + # return gr.update(value=items_filtered), st + return gr.update(), st def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery): """ @@ -338,7 +342,8 @@ def _on_remove(self, state: Dict[str, Any], gallery) : return gr.update(value=[], selected_index=None), st new_sel = min(sel, len(items) - 1) st["items"] = items; st["selected"] = new_sel - return gr.update(value=items, selected_index=new_sel), st + # return gr.update(value=items, selected_index=new_sel), st + return gr.update(value=items), st def _on_move(self, delta: int, state: Dict[str, Any], gallery) : st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None) @@ -352,8 +357,8 @@ def _on_move(self, delta: int, state: Dict[str, Any], gallery) : return gr.update(value=items, selected_index=j), st def _on_clear(self, state: Dict[str, Any]) : - st = {"items": [], "selected": None, "single": state.get("single", False), "mode": self.media_mode} - return gr.update(value=[], selected_index=None), st + st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode} + return gr.update(value=[], selected_index=0), st def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) : st = get_state(state); st["single"] = bool(to_single) @@ -397,7 +402,8 @@ def _build_ui(self) -> gr.Column: columns=self.columns, show_label=self.show_label, preview= True, - type="pil", + # type="pil", + file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS), selected_index=self._initial_state["selected"], # server-side selection ) @@ -424,7 +430,7 @@ def _wire_events(self): self.gallery.select( self._on_select, inputs=[self.state, self.gallery], - outputs=[self.gallery, self.state], + outputs=[self.state], ) # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.) 
diff --git a/wgp.py b/wgp.py index ec51a74ff..6527357b5 100644 --- a/wgp.py +++ b/wgp.py @@ -60,7 +60,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.5.11" -WanGP_version = "8.3" +WanGP_version = "8.31" settings_version = 2.28 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -186,6 +186,17 @@ def compute_sliding_window_no(current_video_length, sliding_window_size, discard return 1 + math.ceil(left_after_first_window / (sliding_window_size - discard_last_frames - reuse_frames)) +def clean_image_list(gradio_list): + if not isinstance(gradio_list, list): gradio_list = [gradio_list] + gradio_list = [ tup[0] if isinstance(tup, tuple) else tup for tup in gradio_list ] + + if any( not isinstance(image, (Image.Image, str)) for image in gradio_list): return None + if any( isinstance(image, str) and not has_image_file_extension(image) for image in gradio_list): return None + gradio_list = [ convert_image( Image.open(img) if isinstance(img, str) else img ) for img in gradio_list ] + return gradio_list + + + def process_prompt_and_add_tasks(state, model_choice): if state.get("validate_success",0) != 1: @@ -436,11 +447,10 @@ def process_prompt_and_add_tasks(state, model_choice): if image_refs == None or len(image_refs) == 0: gr.Info("You must provide at least one Refererence Image") return - if any(isinstance(image[0], str) for image in image_refs) : + image_refs = clean_image_list(image_refs) + if image_refs == None : gr.Info("A Reference Image should be an Image") return - if isinstance(image_refs, list): - image_refs = [ convert_image(tup[0]) for tup in image_refs ] else: image_refs = None @@ -497,12 +507,10 @@ def process_prompt_and_add_tasks(state, model_choice): if image_start == None or isinstance(image_start, list) and len(image_start) == 0: gr.Info("You must provide a Start Image") return - if not isinstance(image_start, list): - image_start = [image_start] - if not all( not isinstance(img[0], str) for img in image_start) : + image_start = clean_image_list(image_start) + if image_start == None : gr.Info("Start Image should be an Image") return - image_start = [ convert_image(tup[0]) for tup in image_start ] else: image_start = None @@ -510,15 +518,13 @@ def process_prompt_and_add_tasks(state, model_choice): if image_end == None or isinstance(image_end, list) and len(image_end) == 0: gr.Info("You must provide an End Image") return - if not isinstance(image_end, list): - image_end = [image_end] - if not all( not isinstance(img[0], str) for img in image_end) : + image_end = clean_image_list(image_end) + if image_end == None : gr.Info("End Image should be an Image") return if len(image_start) != len(image_end): gr.Info("The number of Start and End Images should be the same ") return - image_end = [ convert_image(tup[0]) for tup in image_end ] else: image_end = None From 2ac31cc12a825ca3c6ea15116f2b79d7115914e8 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 2 Sep 2025 23:16:02 +0200 Subject: [PATCH 009/155] no more pain --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index eef872837..77f7fa895 100644 --- a/README.md +++ b/README.md @@ -27,10 +27,10 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models - Unofficial **Video to Video (Non Sparse this time) for InfinitTalk**. 
Use the Strength Noise slider to decide how much motion of the original window you want to keep. I have also *greatly reduced the VRAM requirements for Multitalk / Infinitalk* (especially the multispeakers version & when generating at 1080p). -- **Experimental Sage 3 Attention support**: you will need to deserve this one, first you need a Blackwell GPU (RTX50xx), then you will have to compile Sage 3, install it and cross your fingers that there isn't any crash. +- **Experimental Sage 3 Attention support**: you will need to deserve this one, first you need a Blackwell GPU (RTX50xx) and request an access to Sage 3 Github repo, then you will have to compile Sage 3, install it and cross your fingers ... -*update 8.31: one shouldnt talk about bugs if one dont want to attract bugs* +*update 8.31: one shouldnt talk about bugs if one doesn't want to attract bugs* ### August 29 2025: WanGP v8.21 - Here Goes Your Weekend From 0871a3be580f1453a7c294d87b1882547ac69083 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Wed, 3 Sep 2025 09:50:52 +0200 Subject: [PATCH 010/155] a few fixes --- defaults/flux_dev_uso.json | 2 +- models/wan/wan_handler.py | 4 ++-- requirements.txt | 2 +- wgp.py | 39 +++++++++++++++++++++++--------------- 4 files changed, 28 insertions(+), 19 deletions(-) diff --git a/defaults/flux_dev_uso.json b/defaults/flux_dev_uso.json index ab5ac5459..8b5dbb603 100644 --- a/defaults/flux_dev_uso.json +++ b/defaults/flux_dev_uso.json @@ -10,7 +10,7 @@ "reference_image": true, "flux-model": "flux-dev-uso" }, - "prompt": "add a hat", + "prompt": "the man is wearing a hat", "embedded_guidance_scale": 4, "resolution": "1024x1024", "batch_size": 1 diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 4d59aa354..6d91fe243 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -118,8 +118,8 @@ def query_model_def(base_model_type, model_def): # extra_model_def["at_least_one_image_ref_needed"] = True - if base_model_type in ["phantom_1.3B", "phantom_14B"]: - extra_model_def["one_image_ref_needed"] = True + # if base_model_type in ["phantom_1.3B", "phantom_14B"]: + # extra_model_def["one_image_ref_needed"] = True return extra_model_def diff --git a/requirements.txt b/requirements.txt index 7704d6db9..cdc21ad53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,7 +44,7 @@ pydantic==2.10.6 # Math & modeling torchdiffeq>=0.2.5 tensordict>=0.6.1 -mmgp==3.5.11 +mmgp==3.5.12 peft==0.15.0 matplotlib diff --git a/wgp.py b/wgp.py index 6527357b5..f11afa94f 100644 --- a/wgp.py +++ b/wgp.py @@ -59,8 +59,8 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.5.11" -WanGP_version = "8.31" +target_mmgp_version = "3.5.12" +WanGP_version = "8.32" settings_version = 2.28 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -1361,6 +1361,14 @@ def _parse_args(): help="save proprocessed audio track with extract speakers for debugging or editing" ) + parser.add_argument( + "--vram-safety-coefficient", + type=float, + default=0.8, + help="max VRAM (between 0 and 1) that should be allocated to preloaded models" + ) + + parser.add_argument( "--share", action="store_true", @@ -2767,6 +2775,7 @@ def load_models(model_type, override_profile = -1): transformer_dtype = torch.bfloat16 if "bf16" in model_filename or "BF16" in model_filename else transformer_dtype transformer_dtype = torch.float16 if "fp16" in 
model_filename or"FP16" in model_filename else transformer_dtype perc_reserved_mem_max = args.perc_reserved_mem_max + vram_safety_coefficient = args.vram_safety_coefficient model_file_list = [model_filename] model_type_list = [model_type] module_type_list = [None] @@ -2803,7 +2812,7 @@ def load_models(model_type, override_profile = -1): loras_transformer = ["transformer"] if "transformer2" in pipe: loras_transformer += ["transformer2"] - offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs) + offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , vram_safety_coefficient = vram_safety_coefficient , convertWeightsFloatTo = transformer_dtype, **kwargs) if len(args.gpu) > 0: torch.set_default_device(args.gpu) transformer_type = model_type @@ -4164,6 +4173,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri download_models() gen = get_gen_info(state) + original_process_status = None while True: with gen_lock: process_status = gen.get("process_status", None) @@ -6116,8 +6126,7 @@ def has_video_file_extension(filename): def has_image_file_extension(filename): extension = os.path.splitext(filename)[-1] - return extension in [".jpeg", ".jpg", ".png", ".webp", ".bmp", ".tiff"] - + return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"] def add_videos_to_gallery(state, input_file_list, choice, files_to_load): gen = get_gen_info(state) if files_to_load == None: @@ -7016,10 +7025,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non any_reference_image = False v2i_switch_supported = (vace or t2v or standin) and not image_outputs ti2v_2_2 = base_model_type in ["ti2v_2_2"] - + gallery_height = 350 def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ): with gr.Row(visible = visible) as gallery_row: - gallery_amg = AdvancedMediaGallery(media_mode="image", height=None, columns=4, label=label, initial = value , single_image_mode = single_image_mode ) + gallery_amg = AdvancedMediaGallery(media_mode="image", height=gallery_height, columns=4, label=label, initial = value , single_image_mode = single_image_mode ) gallery_amg.mount(update_form=update_form) return gallery_row, gallery_amg.gallery, [gallery_row] + gallery_amg.get_toggable_elements() @@ -7044,7 +7053,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl image_start_row, image_start, image_start_extra = get_image_gallery(visible = False ) image_end_row, image_end, image_end_extra = get_image_gallery(visible = False ) - video_source = gr.Video(label= "Video Source", visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None)) + video_source = gr.Video(label= "Video Source", height = gallery_height, visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None)) model_mode = gr.Dropdown(visible = False) keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VLG"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) 
any_video_source = True @@ -7062,7 +7071,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value ) - video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) + video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) if not diffusion_forcing: model_mode = gr.Dropdown( choices=[ @@ -7084,7 +7093,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl image_prompt_type = gr.Radio(choices=[("Source Video", "V")], value="V") image_start_row, image_start, image_start_extra = get_image_gallery(visible = False ) image_end_row, image_end, image_end_extra = get_image_gallery(visible = False ) - video_source = gr.Video(label= "Video Source", visible = True, value= ui_defaults.get("video_source", None),) + video_source = gr.Video(label= "Video Source", height = gallery_height, visible = True, value= ui_defaults.get("video_source", None),) model_mode = gr.Dropdown( choices=[ ("Pan Right", 1), @@ -7121,7 +7130,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl if hunyuan_i2v: video_source = gr.Video(value=None, visible=False) else: - video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) + video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) else: image_prompt_type = gr.Radio(choices=[("", "")], value="") image_start_row, image_start, image_start_extra = get_image_gallery(visible = False ) @@ -7311,8 +7320,8 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl visible = False, label="Start / Reference Images", scale = 2 ) - image_guide = gr.Image(label= "Control Image", type ="pil", visible= image_outputs and "V" in video_prompt_type_value, value= ui_defaults.get("image_guide", None)) - video_guide = gr.Video(label= "Control Video", visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None)) + image_guide = gr.Image(label= "Control Image", height = gallery_height, type ="pil", visible= image_outputs and "V" in video_prompt_type_value, value= ui_defaults.get("image_guide", None)) + video_guide = gr.Video(label= "Control Video", height = gallery_height, visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None)) denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label="Denoising Strength (the Lower the Closer to the Control Video)", visible = "G" in video_prompt_type_value, show_reset_button= False) keep_frames_video_guide_visible = not image_outputs and "V" in video_prompt_type_value and not model_def.get("keep_frames_video_guide_not_supported", False) @@ -7331,8 +7340,8 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl 
video_guide_outpainting_left = gr.Slider(0, 100, value= video_guide_outpainting_list[2], step=5, label="Left %", show_reset_button= False) video_guide_outpainting_right = gr.Slider(0, 100, value= video_guide_outpainting_list[3], step=5, label="Right %", show_reset_button= False) any_image_mask = image_outputs and vace - image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_outputs and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("image_mask", None)) - video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None)) + image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_outputs and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("image_mask", None)) + video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("video_mask", None)) mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value ) any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or infinitetalk or (flux or qwen) and model_reference_image From a60eea23713c41dc89d7287383d5bf025334f494 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Wed, 3 Sep 2025 19:39:17 +0200 Subject: [PATCH 011/155] fixes and polish --- models/flux/flux_handler.py | 21 +++++++++--- models/flux/flux_main.py | 67 ++++++++++++++++++++++++++++++++----- models/flux/sampling.py | 1 - models/qwen/qwen_handler.py | 1 + models/wan/wan_handler.py | 2 ++ shared/utils/utils.py | 6 ++-- wgp.py | 35 ++++++++++++------- 7 files changed, 105 insertions(+), 28 deletions(-) diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py index 162ec4c69..b8b7b9b42 100644 --- a/models/flux/flux_handler.py +++ b/models/flux/flux_handler.py @@ -26,12 +26,13 @@ def query_model_def(base_model_type, model_def): model_def_output["no_background_removal"] = True model_def_output["image_ref_choices"] = { - "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "I"), - ("Up to two Images are Style Images", "IJ")], - "default": "I", - "letters_filter": "IJ", + "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"), + ("Up to two Images are Style Images", "KIJ")], + "default": "KI", + "letters_filter": "KIJ", "label": "Reference Images / Style Images" } + model_def_output["lock_image_refs_ratios"] = True return model_def_output @@ -107,6 +108,16 @@ def load_model(model_filename, model_type, base_model_type, model_def, quantizeT pipe["feature_embedder"] = flux_model.feature_embedder return flux_model, pipe + @staticmethod + def 
fix_settings(base_model_type, settings_version, model_def, ui_defaults): + flux_model = model_def.get("flux-model", "flux-dev") + flux_uso = flux_model == "flux-dev-uso" + if flux_uso and settings_version < 2.29: + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if "I" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("I", "KI") + ui_defaults["video_prompt_type"] = video_prompt_type + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): flux_model = model_def.get("flux-model", "flux-dev") @@ -116,6 +127,6 @@ def update_default_settings(base_model_type, model_def, ui_defaults): }) if model_def.get("reference_image", False): ui_defaults.update({ - "video_prompt_type": "I" if flux_uso else "KI", + "video_prompt_type": "KI", }) diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py index 55a2b910a..9bb8e7379 100644 --- a/models/flux/flux_main.py +++ b/models/flux/flux_main.py @@ -9,6 +9,9 @@ from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack from .modules.layers import get_linear_split_map from transformers import SiglipVisionModel, SiglipImageProcessor +import torchvision.transforms.functional as TVF +import math +from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image from .util import ( aspect_ratio_to_height_width, @@ -21,6 +24,44 @@ from PIL import Image +def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1): + target_height_ref1 = int(target_height_ref1 // 64 * 64) + target_width_ref1 = int(target_width_ref1 // 64 * 64) + h, w = image.shape[-2:] + if h < target_height_ref1 or w < target_width_ref1: + # compute the aspect ratio + aspect_ratio = w / h + if h < target_height_ref1: + new_h = target_height_ref1 + new_w = new_h * aspect_ratio + if new_w < target_width_ref1: + new_w = target_width_ref1 + new_h = new_w / aspect_ratio + else: + new_w = target_width_ref1 + new_h = new_w / aspect_ratio + if new_h < target_height_ref1: + new_h = target_height_ref1 + new_w = new_h * aspect_ratio + else: + aspect_ratio = w / h + tgt_aspect_ratio = target_width_ref1 / target_height_ref1 + if aspect_ratio > tgt_aspect_ratio: + new_h = target_height_ref1 + new_w = new_h * aspect_ratio + else: + new_w = target_width_ref1 + new_h = new_w / aspect_ratio + # resize the image with TVF.resize + image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w))) + # compute the center-crop parameters + top = (image.shape[-2] - target_height_ref1) // 2 + left = (image.shape[-1] - target_width_ref1) // 2 + # center-crop with TVF.crop + image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1) + return image + + def stitch_images(img1, img2): # Resize img2 to match img1's height width1, height1 = img1.size @@ -129,11 +170,11 @@ def generate( if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors" device="cuda" flux_dev_uso = self.name in ['flux-dev-uso'] - image_stiching = not self.name in ['flux-dev-uso'] - + image_stiching = not self.name in ['flux-dev-uso'] #and False + # image_refs_relative_size = 100 + crop = False input_ref_images = [] if input_ref_images is None else input_ref_images[:] ref_style_imgs = [] - if "I" in video_prompt_type and len(input_ref_images) > 0: if flux_dev_uso : if "J" in video_prompt_type: @@ -148,7 +189,7 @@ def generate( if "K" in video_prompt_type : w, h = input_ref_images[0].size height, width = calculate_new_dimensions(height, width, h, w,
fit_into_canvas) - + # actual rescale will happen in prepare_kontext for new_img in input_ref_images[1:]: stiched = stitch_images(stiched, new_img) input_ref_images = [stiched] @@ -157,14 +198,24 @@ def generate( if "K" in video_prompt_type: # image latents tiling method w, h = input_ref_images[0].size - height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS) + if crop : + img = convert_image_to_tensor(input_ref_images[0]) + img = resize_and_centercrop_image(img, height, width) + input_ref_images[0] = convert_tensor_to_image(img) + else: + height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) + input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS) first_ref = 1 for i in range(first_ref,len(input_ref_images)): w, h = input_ref_images[i].size - image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas) - input_ref_images[0] = input_ref_images[0].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) + if crop: + img = convert_image_to_tensor(input_ref_images[i]) + img = resize_and_centercrop_image(img, int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100)) + input_ref_images[i] = convert_tensor_to_image(img) + else: + image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas) + input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) else: input_ref_images = None diff --git a/models/flux/sampling.py b/models/flux/sampling.py index 5534e9f6b..f43ae15e0 100644 --- a/models/flux/sampling.py +++ b/models/flux/sampling.py @@ -153,7 +153,6 @@ def prepare_kontext( # Kontext is trained on specific resolutions, using one of them is recommended _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) - width = 2 * int(width / 16) height = 2 * int(height / 16) diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py index c6004e164..6fc488aec 100644 --- a/models/qwen/qwen_handler.py +++ b/models/qwen/qwen_handler.py @@ -15,6 +15,7 @@ def query_model_def(base_model_type, model_def): ("Default", "default"), ("Lightning", "lightning")], "guidance_max_phases" : 1, + "lock_image_refs_ratios": True, } diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 6d91fe243..9adc3a884 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -117,6 +117,8 @@ def query_model_def(base_model_type, model_def): extra_model_def["no_background_removal"] = True # extra_model_def["at_least_one_image_ref_needed"] = True + if base_model_type in ["standin"] or vace_class: + extra_model_def["lock_image_refs_ratios"] = True # if base_model_type in ["phantom_1.3B", "phantom_14B"]: # extra_model_def["one_image_ref_needed"] = True diff --git a/shared/utils/utils.py b/shared/utils/utils.py index a55807a17..7ddf1eb6e 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -18,11 +18,11 @@ import tempfile import subprocess import json - +from functools import lru_cache from PIL import Image - +video_info_cache = [] def seed_everything(seed: int): random.seed(seed) np.random.seed(seed) @@ -77,7 +77,9 @@ def truncate_for_filesystem(s, max_bytes=255): 
else: r = m - 1 return s[:l] +@lru_cache(maxsize=100) def get_video_info(video_path): + global video_info_cache import cv2 cap = cv2.VideoCapture(video_path) diff --git a/wgp.py b/wgp.py index f11afa94f..a3cc51010 100644 --- a/wgp.py +++ b/wgp.py @@ -60,8 +60,8 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.5.12" -WanGP_version = "8.32" -settings_version = 2.28 +WanGP_version = "8.33" +settings_version = 2.29 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -3313,6 +3313,7 @@ def check(src, cond): if not all_letters(src, pos): return False if neg is not None and any_letters(src, neg): return False return True + image_outputs = configs.get("image_mode",0) == 1 map_video_prompt = {"V" : "Control Video", ("VA", "U") : "Mask Video", "I" : "Reference Images"} map_image_prompt = {"V" : "Source Video", "L" : "Last Video", "S" : "Start Image", "E" : "End Image"} map_audio_prompt = {"A" : "Audio Source", "B" : "Audio Source #2"} @@ -3364,6 +3365,7 @@ def check(src, cond): if multiple_submodels: video_guidance_scale += f" + Model Switch at {video_switch_threshold if video_model_switch_phase ==1 else video_switch_threshold2}" video_flow_shift = configs.get("flow_shift", None) + if image_outputs: video_flow_shift = None video_video_guide_outpainting = configs.get("video_guide_outpainting", "") video_outpainting = "" if len(video_video_guide_outpainting) > 0 and not video_video_guide_outpainting.startswith("#") \ @@ -4545,7 +4547,8 @@ def remove_temp_filenames(temp_filenames_list): send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") from shared.utils.utils import resize_and_remove_background - image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (any_background_ref or vace or standin) ) # no fit for vace ref images as it is done later + # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested + image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (any_background_ref or model_def.get("lock_image_refs_ratios", False)) ) # no fit for vace ref images as it is done later update_task_thumbnails(task, locals()) send_cmd("output") joint_pass = boost ==1 #and profile != 1 and profile != 3 @@ -5912,6 +5915,8 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if target == "settings": return inputs + image_outputs = inputs.get("image_mode",0) == 1 + pop=[] if "force_fps" in inputs and len(inputs["force_fps"])== 0: pop += ["force_fps"] @@ -5977,7 +5982,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if guidance_max_phases < 3 or guidance_phases < 3: pop += ["guidance3_scale", "switch_threshold2", "model_switch_phase"] - if ltxv: + if ltxv or image_outputs: pop += ["flow_shift"] if model_def.get("no_negative_prompt", False) : @@ -6876,11 +6881,15 @@ def detect_auto_save_form(state, evt:gr.SelectData): return gr.update() def compute_video_length_label(fps, current_video_length): - return f"Number of frames ({fps} frames = 1s), current duration: 
{(current_video_length / fps):.1f}s", + if fps is None: + return f"Number of frames" + else: + return f"Number of frames ({fps} frames = 1s), current duration: {(current_video_length / fps):.1f}s", -def refresh_video_length_label(state, current_video_length): - fps = get_model_fps(get_base_model_type(state["model_type"])) - return gr.update(label= compute_video_length_label(fps, current_video_length)) +def refresh_video_length_label(state, current_video_length, force_fps, video_guide, video_source): + base_model_type = get_base_model_type(state["model_type"]) + computed_fps = get_computed_fps(force_fps, base_model_type , video_guide, video_source ) + return gr.update(label= compute_video_length_label(computed_fps, current_video_length)) def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None): global inputs_names #, advanced @@ -7469,8 +7478,9 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97) + computed_fps = get_computed_fps(ui_defaults.get("force_fps",""), base_model_type , video_guide, video_source ) video_length = gr.Slider(min_frames, get_max_frames(737 if test_any_sliding_window(base_model_type) else 337), value=current_video_length, - step=frames_step, label=compute_video_length_label(fps, current_video_length) , visible = True, interactive= True) + step=frames_step, label=compute_video_length_label(computed_fps, current_video_length) , visible = True, interactive= True) with gr.Row(visible = not lock_inference_steps) as inference_steps_row: num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps", visible = True) @@ -7643,7 +7653,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai with gr.Column(visible = (t2v or vace) and not fantasy) as audio_prompt_type_remux_row: - gr.Markdown("You may transfer the exising audio tracks of a Control Video") + gr.Markdown("You may transfer the existing audio tracks of a Control Video") audio_prompt_type_remux = gr.Dropdown( choices=[ ("No Remux", ""), @@ -7955,7 +7965,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, save_lset_prompt_drop, advanced_row, speed_tab, audio_tab, mmaudio_col, quality_tab, sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, audio_guide_row, RIFLEx_setting_col, - video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row, + video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux, audio_prompt_type_remux_row, video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right, video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, 
sample_solver_row, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row, @@ -7975,7 +7985,8 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai resolution.change(fn=record_last_resolution, inputs=[state, resolution]) - video_length.release(fn=refresh_video_length_label, inputs=[state, video_length ], outputs = video_length, trigger_mode="always_last" ) + # video_length.release(fn=refresh_video_length_label, inputs=[state, video_length ], outputs = video_length, trigger_mode="always_last" ) + gr.on(triggers=[video_length.release, force_fps.change, video_guide.change, video_source.change], fn=refresh_video_length_label, inputs=[state, video_length, force_fps, video_guide, video_source] , outputs = video_length, trigger_mode="always_last" ) guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ]) audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type]) audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row]) From e5abb1b9bc27cfe7f05562651aa23f14bba3d5e9 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 4 Sep 2025 01:10:42 +0200 Subject: [PATCH 012/155] fixed pytorch compilation --- models/wan/any2video.py | 7 ++++--- requirements.txt | 2 +- wgp.py | 6 +++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 8eb2ffef2..e7d54ef29 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -454,7 +454,8 @@ def generate(self, timesteps.append(0.) 
timesteps = [torch.tensor([t], device=self.device) for t in timesteps] if self.use_timestep_transform: - timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1] + timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1] + timesteps = torch.tensor(timesteps) sample_scheduler = None elif sample_solver == 'causvid': sample_scheduler = FlowMatchScheduler(num_inference_steps=sampling_steps, shift=shift, sigma_min=0, extra_one_step=True) @@ -1016,8 +1017,8 @@ def clear(): if sample_solver == "euler": dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1]) - dt = dt / self.num_timesteps - latents = latents - noise_pred * dt[:, None, None, None, None] + dt = dt.item() / self.num_timesteps + latents = latents - noise_pred * dt else: latents = sample_scheduler.step( noise_pred[:, :, :target_shape[1]], diff --git a/requirements.txt b/requirements.txt index cdc21ad53..0ae0ab32f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,7 +44,7 @@ pydantic==2.10.6 # Math & modeling torchdiffeq>=0.2.5 tensordict>=0.6.1 -mmgp==3.5.12 +mmgp==3.6.0 peft==0.15.0 matplotlib diff --git a/wgp.py b/wgp.py index a3cc51010..c57cc48cd 100644 --- a/wgp.py +++ b/wgp.py @@ -59,8 +59,8 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.5.12" -WanGP_version = "8.33" +target_mmgp_version = "3.6.0" +WanGP_version = "8.34" settings_version = 2.29 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -8509,7 +8509,7 @@ def check(mode): ("Off", "" ), ], value= compile, - label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)", + label="Compile Transformer : up to 10-20% faster, useful only if multiple gens at same frames no / resolution", interactive= not lock_ui_compile ) From 56a51b79f00a2cd370b179412618c3167fae8e10 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 4 Sep 2025 01:25:19 +0200 Subject: [PATCH 013/155] add multitalk support for audio in mp4 --- models/wan/multitalk/multitalk.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/models/wan/multitalk/multitalk.py b/models/wan/multitalk/multitalk.py index fbf9175ae..b04f65f32 100644 --- a/models/wan/multitalk/multitalk.py +++ b/models/wan/multitalk/multitalk.py @@ -59,7 +59,30 @@ def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=160 audio_emb = audio_emb.cpu().detach() return audio_emb - + +def extract_audio_from_video(filename, sample_rate): + raw_audio_path = filename.split('/')[-1].split('.')[0]+'.wav' + ffmpeg_command = [ + "ffmpeg", + "-y", + "-i", + str(filename), + "-vn", + "-acodec", + "pcm_s16le", + "-ar", + "16000", + "-ac", + "2", + str(raw_audio_path), + ] + subprocess.run(ffmpeg_command, check=True) + human_speech_array, sr = librosa.load(raw_audio_path, sr=sample_rate) + human_speech_array = loudness_norm(human_speech_array, sr) + os.remove(raw_audio_path) + + return human_speech_array + def audio_prepare_single(audio_path, sample_rate=16000, duration = 0): ext = os.path.splitext(audio_path)[1].lower() if ext in ['.mp4', '.mov', '.avi', '.mkv']: From 13b001d4eaac075973bad596b119dfdce47d550c Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 4 Sep 2025 11:42:09 +0200 Subject: [PATCH 014/155] fixed forms errors --- 
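As a quick illustration of the helper added in PATCH 013, here is a minimal usage sketch; the import path and the `input.mp4` file name are assumptions, and ffmpeg must be available on the PATH since the helper shells out to it.

```python
# Hypothetical sketch of calling the new extract_audio_from_video helper (assumed import path).
from models.wan.multitalk.multitalk import extract_audio_from_video

# Writes the .mp4 soundtrack to a .wav via ffmpeg, loads it with librosa (mono by default),
# applies loudness normalization, then deletes the intermediate .wav.
speech = extract_audio_from_video("input.mp4", 16000)
print(speech.shape)  # 1-D waveform at the requested sample rate
```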
wgp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgp.py b/wgp.py index c57cc48cd..653683c15 100644 --- a/wgp.py +++ b/wgp.py @@ -7478,7 +7478,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97) - computed_fps = get_computed_fps(ui_defaults.get("force_fps",""), base_model_type , video_guide, video_source ) + computed_fps = get_computed_fps(ui_defaults.get("force_fps",""), base_model_type , ui_defaults.get("video_guide", None), ui_defaults.get("video_source", None)) video_length = gr.Slider(min_frames, get_max_frames(737 if test_any_sliding_window(base_model_type) else 337), value=current_video_length, step=frames_step, label=compute_video_length_label(computed_fps, current_video_length) , visible = True, interactive= True) From b37c1a26a8395d090ecf0c5e07d59db3d7b05cf4 Mon Sep 17 00:00:00 2001 From: deepbeepmeep <84379123+deepbeepmeep@users.noreply.github.com> Date: Fri, 5 Sep 2025 09:12:01 +0200 Subject: [PATCH 015/155] removed infinitetalk forced video to video --- models/wan/any2video.py | 1 - 1 file changed, 1 deletion(-) diff --git a/models/wan/any2video.py b/models/wan/any2video.py index e7d54ef29..d394449dc 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -537,7 +537,6 @@ def generate(self, image_ref = input_frames[:, 0] if input_video is None: input_video = input_frames[:, 0:1] new_shot = "Q" in video_prompt_type - denoising_strength = 0.5 else: if pre_video_frame is None: new_shot = True From 8859e816d0a9969b4a1207f3eaf7333597cf1b26 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 5 Sep 2025 14:17:21 +0200 Subject: [PATCH 016/155] Take me to outer space --- README.md | 17 ++ models/hyvideo/hunyuan_handler.py | 4 + models/ltx_video/ltxv_handler.py | 1 + models/wan/any2video.py | 117 ++++++-------- models/wan/df_handler.py | 12 ++ models/wan/wan_handler.py | 54 +++++++ wgp.py | 256 ++++++++++++++---------------- 7 files changed, 256 insertions(+), 205 deletions(-) diff --git a/README.md b/README.md index 77f7fa895..eb497f1a9 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,23 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : +### September 5 2025: WanGP v8.4 - Take me to Outer Space +You have probably seen these short AI-generated movies created using *Nano Banana* and the *First Frame - Last Frame* feature of *Kling 2.0*. The idea is to generate an image, modify a part of it with Nano Banana and give these two images to Kling, which will generate the Video between them; then use the previous Last Frame as the new First Frame, rinse and repeat, and you get a full movie. + +I have made it easier to do just that with *Qwen Edit* and *Wan*: +- **End Frames can now be combined with Continue a Video** (and not just a Start Frame) +- **Multiple End Frames can be provided**; each End Frame will be used for a different Sliding Window + +You can plan all your shots in advance (one shot = one Sliding Window): I recommend using Wan 2.2 Image to Image with multiple End Frames (one for each shot / Sliding Window) and a different Text Prompt for each shot / Sliding Window (remember to enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*). + +The results can be quite impressive.
However, Wan 2.1 & 2.2 Image to Image are restricted to a single overlap frame when using Sliding Windows, which means only one frame is reused for the motion. This may be insufficient if you are trying to connect two shots with fast movement. + +This is where *InfiniteTalk* comes into play. Besides being one of the best models for generating animated audio-driven avatars, InfiniteTalk internally uses more than one motion frame, so it is quite good at maintaining motion between two shots. I have tweaked InfiniteTalk so that **its motion engine can be used even if no audio is provided**. +So here is how to use InfiniteTalk: enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*, and if you continue an existing Video, *Misc/Override Frames per Second* should be set to *Source Video*. Each Reference Frame provided will play the same role as the End Frame, except it won't be exactly an End Frame (it will correspond more to a middle frame; the actual End Frame will differ but will be close). + + +You will find below a 33s movie I have created using these two methods. The quality could be much better as I haven't tuned the settings at all (I couldn't be bothered; I used 10-step generation without Loras Accelerators for most of the gens). + ### September 2 2025: WanGP v8.31 - At last the pain stops - This single new feature should give you the strength to face all the potential bugs of this new release: diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py index da60f72cb..9cbaea7d1 100644 --- a/models/hyvideo/hunyuan_handler.py +++ b/models/hyvideo/hunyuan_handler.py @@ -56,6 +56,10 @@ def query_model_def(base_model_type, model_def): if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]: extra_model_def["one_image_ref_needed"] = True + + if base_model_type in ["hunyuan_i2v"]: + extra_model_def["image_prompt_types_allowed"] = "S" + return extra_model_def @staticmethod diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py index d35bcd457..c89b69a43 100644 --- a/models/ltx_video/ltxv_handler.py +++ b/models/ltx_video/ltxv_handler.py @@ -24,6 +24,7 @@ def query_model_def(base_model_type, model_def): extra_model_def["frames_minimum"] = 17 extra_model_def["frames_steps"] = 8 extra_model_def["sliding_window"] = True + extra_model_def["image_prompt_types_allowed"] = "TSEV" return extra_model_def @staticmethod diff --git a/models/wan/any2video.py b/models/wan/any2video.py index e7d54ef29..bb91dc66b 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -537,7 +537,6 @@ def generate(self, image_ref = input_frames[:, 0] if input_video is None: input_video = input_frames[:, 0:1] new_shot = "Q" in video_prompt_type - denoising_strength = 0.5 else: if pre_video_frame is None: new_shot = True @@ -556,74 +555,59 @@ def generate(self, _ , preframes_count, height, width = input_video.shape input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype) if infinitetalk: - image_for_clip = image_ref.to(input_video) + image_start = image_ref.to(input_video) control_pre_frames_count = 1 - control_video = image_for_clip.unsqueeze(1) + control_video = image_start.unsqueeze(1) else: - image_for_clip = input_video[:, -1] + image_start = input_video[:, -1] control_pre_frames_count = preframes_count control_video = input_video - lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] - if hasattr(self, "clip"): - clip_image_size = self.clip.model.image_size -
clip_image = resize_lanczos(image_for_clip, clip_image_size, clip_image_size)[:, None, :, :] - clip_context = self.clip.visual([clip_image]) if model_type != "flf2v_720p" else self.clip.visual([clip_image , clip_image ]) - clip_image = None - else: - clip_context = None - enc = torch.concat( [control_video, torch.zeros( (3, frame_num-control_pre_frames_count, height, width), - device=self.device, dtype= self.VAE_dtype)], - dim = 1).to(self.device) - color_reference_frame = image_for_clip.unsqueeze(1).clone() + + color_reference_frame = image_start.unsqueeze(1).clone() else: - preframes_count = control_pre_frames_count = 1 - any_end_frame = image_end is not None - add_frames_for_end_image = any_end_frame and model_type == "i2v" - if any_end_frame: - if add_frames_for_end_image: - frame_num +=1 - lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) - trim_frames = 1 - + preframes_count = control_pre_frames_count = 1 height, width = image_start.shape[1:] + control_video = image_start.unsqueeze(1).to(self.device) + color_reference_frame = control_video.clone() - lat_h = round( - height // self.vae_stride[1] // - self.patch_size[1] * self.patch_size[1]) - lat_w = round( - width // self.vae_stride[2] // - self.patch_size[2] * self.patch_size[2]) - height = lat_h * self.vae_stride[1] - width = lat_w * self.vae_stride[2] - image_start_frame = image_start.unsqueeze(1).to(self.device) - color_reference_frame = image_start_frame.clone() - if image_end is not None: - img_end_frame = image_end.unsqueeze(1).to(self.device) - - if hasattr(self, "clip"): - clip_image_size = self.clip.model.image_size - image_start = resize_lanczos(image_start, clip_image_size, clip_image_size) - if image_end is not None: image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) - if model_type == "flf2v_720p": - clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]]) - else: - clip_context = self.clip.visual([image_start[:, None, :, :]]) - else: - clip_context = None - - if any_end_frame: - enc= torch.concat([ - image_start_frame, - torch.zeros( (3, frame_num-2, height, width), device=self.device, dtype= self.VAE_dtype), - img_end_frame, - ], dim=1).to(self.device) + any_end_frame = image_end is not None + add_frames_for_end_image = any_end_frame and model_type == "i2v" + if any_end_frame: + color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot + if add_frames_for_end_image: + frame_num +=1 + lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) + trim_frames = 1 + + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + + if image_end is not None: + img_end_frame = image_end.unsqueeze(1).to(self.device) + + if hasattr(self, "clip"): + clip_image_size = self.clip.model.image_size + image_start = resize_lanczos(image_start, clip_image_size, clip_image_size) + image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) if image_end is not None else image_start + if model_type == "flf2v_720p": + clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]]) else: - enc= torch.concat([ - image_start_frame, - torch.zeros( (3, frame_num-1, height, width), device=self.device, dtype= self.VAE_dtype) - ], dim=1).to(self.device) + clip_context = self.clip.visual([image_start[:, None, :, :]]) + 
else: + clip_context = None + + if any_end_frame: + enc= torch.concat([ + control_video, + torch.zeros( (3, frame_num-control_pre_frames_count-1, height, width), device=self.device, dtype= self.VAE_dtype), + img_end_frame, + ], dim=1).to(self.device) + else: + enc= torch.concat([ + control_video, + torch.zeros( (3, frame_num-control_pre_frames_count, height, width), device=self.device, dtype= self.VAE_dtype) + ], dim=1).to(self.device) - image_start = image_end = image_start_frame = img_end_frame = image_for_clip = image_ref = None + image_start = image_end = img_end_frame = image_ref = control_video = None msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device) if any_end_frame: @@ -657,12 +641,11 @@ def generate(self, # Recam Master if recam: - # should be be in fact in input_frames since it is control video not a video to be extended target_camera = model_mode - height,width = input_video.shape[-2:] - input_video = input_video.to(dtype=self.dtype , device=self.device) - source_latents = self.vae.encode([input_video])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) - del input_video + height,width = input_frames.shape[-2:] + input_frames = input_frames.to(dtype=self.dtype , device=self.device) + source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) + del input_frames # Process target camera (recammaster) from shared.utils.cammmaster_tools import get_camera_embedding cam_emb = get_camera_embedding(target_camera) @@ -754,7 +737,9 @@ def generate(self, else: target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2]) - if multitalk and audio_proj != None: + if multitalk: + if audio_proj is None: + audio_proj = [ torch.zeros( (1, 1, 5, 12, 768 ), dtype=self.dtype, device=self.device), torch.zeros( (1, (frame_num - 1) // 4, 8, 12, 768 ), dtype=self.dtype, device=self.device) ] from .multitalk.multitalk import get_target_masks audio_proj = [audio.to(self.dtype) for audio in audio_proj] human_no = len(audio_proj[0]) diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py index bc79e2ebc..82e704aaa 100644 --- a/models/wan/df_handler.py +++ b/models/wan/df_handler.py @@ -26,6 +26,18 @@ def query_model_def(base_model_type, model_def): extra_model_def["tea_cache"] = True extra_model_def["guidance_max_phases"] = 1 + extra_model_def["model_modes"] = { + "choices": [ + ("Synchronous", 0), + ("Asynchronous (better quality but around 50% extra steps added)", 5), + ], + "default": 0, + "label" : "Generation Type" + } + + extra_model_def["image_prompt_types_allowed"] = "TSEV" + + return extra_model_def @staticmethod diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 9adc3a884..3a5dd64a6 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -5,6 +5,9 @@ def test_class_i2v(base_model_type): return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ] +def text_oneframe_overlap(base_model_type): + return test_class_i2v(base_model_type) and not test_multitalk(base_model_type) + def test_class_1_3B(base_model_type): return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"] @@ -120,6 +123,37 @@ def query_model_def(base_model_type, model_def): if base_model_type in ["standin"] or vace_class: extra_model_def["lock_image_refs_ratios"] = True + if base_model_type in ["recam_1.3B"]: + 
extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["model_modes"] = { + "choices": [ + ("Pan Right", 1), + ("Pan Left", 2), + ("Tilt Up", 3), + ("Tilt Down", 4), + ("Zoom In", 5), + ("Zoom Out", 6), + ("Translate Up (with rotation)", 7), + ("Translate Down (with rotation)", 8), + ("Arc Left (with rotation)", 9), + ("Arc Right (with rotation)", 10), + ], + "default": 1, + "label" : "Camera Movement Type" + } + if vace_class or base_model_type in ["infinitetalk"]: + image_prompt_types_allowed = "TVL" + elif base_model_type in ["ti2v_2_2"]: + image_prompt_types_allowed = "TSEVL" + elif i2v: + image_prompt_types_allowed = "SEVL" + else: + image_prompt_types_allowed = "" + extra_model_def["image_prompt_types_allowed"] = image_prompt_types_allowed + + if text_oneframe_overlap(base_model_type): + extra_model_def["sliding_window_defaults"] = { "overlap_min" : 1, "overlap_max" : 1, "overlap_step": 0, "overlap_default": 1} + # if base_model_type in ["phantom_1.3B", "phantom_14B"]: # extra_model_def["one_image_ref_needed"] = True @@ -251,6 +285,17 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): video_prompt_type = video_prompt_type.replace("U", "RU") ui_defaults["video_prompt_type"] = video_prompt_type + if settings_version < 2.31: + if base_model_type in "recam_1.3B": + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if not "V" in video_prompt_type: + video_prompt_type += "UV" + ui_defaults["video_prompt_type"] = video_prompt_type + ui_defaults["image_prompt_type"] = "" + + if text_oneframe_overlap(base_model_type): + ui_defaults["sliding_window_overlap"] = 1 + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ @@ -309,6 +354,15 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "image_prompt_type": "T", }) + if base_model_type in ["recam_1.3B"]: + ui_defaults.update({ + "video_prompt_type": "UV", + }) + + if text_oneframe_overlap(base_model_type): + ui_defaults.update["sliding_window_overlap"] = 1 + ui_defaults.update["color_correction_strength"]= 0 + if test_multitalk(base_model_type): ui_defaults["audio_guidance_scale"] = 4 diff --git a/wgp.py b/wgp.py index 653683c15..396e273e2 100644 --- a/wgp.py +++ b/wgp.py @@ -60,8 +60,8 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.34" -settings_version = 2.29 +WanGP_version = "8.4" +settings_version = 2.31 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -347,7 +347,7 @@ def process_prompt_and_add_tasks(state, model_choice): model_switch_phase = inputs["model_switch_phase"] switch_threshold = inputs["switch_threshold"] switch_threshold2 = inputs["switch_threshold2"] - + multi_prompts_gen_type = inputs["multi_prompts_gen_type"] if len(loras_multipliers) > 0: _, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases= guidance_phases) @@ -445,7 +445,7 @@ def process_prompt_and_add_tasks(state, model_choice): if "I" in video_prompt_type: if image_refs == None or len(image_refs) == 0: - gr.Info("You must provide at least one Refererence Image") + gr.Info("You must provide at least one Reference Image") return image_refs = clean_image_list(image_refs) if image_refs == None : @@ -511,9 +511,14 @@ def process_prompt_and_add_tasks(state, model_choice): if image_start == None : gr.Info("Start 
Image should be an Image") return + if multi_prompts_gen_type == 1 and len(image_start) > 1: + gr.Info("Only one Start Image is supported") + return else: image_start = None + if not any_letters(image_prompt_type, "SVL"): + image_prompt_type = image_prompt_type.replace("E", "") if "E" in image_prompt_type: if image_end == None or isinstance(image_end, list) and len(image_end) == 0: gr.Info("You must provide an End Image") @@ -522,32 +527,35 @@ def process_prompt_and_add_tasks(state, model_choice): if image_end == None : gr.Info("End Image should be an Image") return - if len(image_start) != len(image_end): - gr.Info("The number of Start and End Images should be the same ") - return + if multi_prompts_gen_type == 0: + if video_source is not None: + if len(image_end)> 1: + gr.Info("If a Video is to be continued and the option 'Each Text Prompt Will create a new generated Video' is set, there can be only one End Image") + return + elif len(image_start or []) != len(image_end or []): + gr.Info("The number of Start and End Images should be the same when the option 'Each Text Prompt Will create a new generated Video'") + return else: image_end = None if test_any_sliding_window(model_type) and image_mode == 0: if video_length > sliding_window_size: - full_video_length = video_length if video_source is None else video_length + sliding_window_overlap + full_video_length = video_length if video_source is None else video_length + sliding_window_overlap -1 extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation" no_windows = compute_sliding_window_no(full_video_length, sliding_window_size, sliding_window_discard_last_frames, sliding_window_overlap) gr.Info(f"The Number of Frames to generate ({video_length}{extra}) is greater than the Sliding Window Size ({sliding_window_size}), {no_windows} Windows will be generated") if "recam" in model_filename: - if video_source == None: - gr.Info("You must provide a Source Video") + if video_guide == None: + gr.Info("You must provide a Control Video") return - - frames = get_resampled_video(video_source, 0, 81, get_computed_fps(force_fps, model_type , video_guide, video_source )) + computed_fps = get_computed_fps(force_fps, model_type , video_guide, video_source ) + frames = get_resampled_video(video_guide, 0, 81, computed_fps) if len(frames)<81: - gr.Info("Recammaster source video should be at least 81 frames once the resampling at 16 fps has been done") + gr.Info(f"Recammaster Control video should be at least 81 frames once the resampling at {computed_fps} fps has been done") return - - if "hunyuan_custom_custom_edit" in model_filename: if len(keep_frames_video_guide) > 0: gr.Info("Filtering Frames with this model is not supported") @@ -558,13 +566,13 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("Only one Start Image must be provided if multiple prompts are used for different windows") return - if image_end != None and len(image_end) > 1: - gr.Info("Only one End Image must be provided if multiple prompts are used for different windows") - return + # if image_end != None and len(image_end) > 1: + # gr.Info("Only one End Image must be provided if multiple prompts are used for different windows") + # return override_inputs = { "image_start": image_start[0] if image_start !=None and len(image_start) > 0 else None, - "image_end": image_end[0] if image_end !=None and len(image_end) > 0 else None, + "image_end": image_end, #[0] if image_end !=None and len(image_end) > 0 else None, 
"image_refs": image_refs, "audio_guide": audio_guide, "audio_guide2": audio_guide2, @@ -640,19 +648,21 @@ def process_prompt_and_add_tasks(state, model_choice): override_inputs["prompt"] = single_prompt inputs.update(override_inputs) add_video_task(**inputs) + new_prompts_count = len(prompts) else: + new_prompts_count = 1 override_inputs["prompt"] = "\n".join(prompts) inputs.update(override_inputs) add_video_task(**inputs) - gen["prompts_max"] = len(prompts) + gen.get("prompts_max",0) + gen["prompts_max"] = new_prompts_count + gen.get("prompts_max",0) state["validate_success"] = 1 queue= gen.get("queue", []) return update_queue_data(queue) def get_preview_images(inputs): - inputs_to_query = ["image_start", "image_end", "video_source", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ] - labels = ["Start Image", "End Image", "Video Source", "Video Guide", "Image Guide", "Video Mask", "Image Mask", "Image Reference"] + inputs_to_query = ["image_start", "video_source", "image_end", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ] + labels = ["Start Image", "Video Source", "End Image", "Video Guide", "Image Guide", "Video Mask", "Image Mask", "Image Reference"] start_image_data = None start_image_labels = [] end_image_data = None @@ -3454,6 +3464,8 @@ def convert_image(image): from PIL import ImageOps from typing import cast + if isinstance(image, str): + image = Image.open(image) image = image.convert('RGB') return cast(Image, ImageOps.exif_transpose(image)) @@ -4506,7 +4518,7 @@ def remove_temp_filenames(temp_filenames_list): if test_any_sliding_window(model_type) : if video_source is not None: - current_video_length += sliding_window_overlap + current_video_length += sliding_window_overlap - 1 sliding_window = current_video_length > sliding_window_size reuse_frames = min(sliding_window_size - 4, sliding_window_overlap) else: @@ -4690,8 +4702,7 @@ def remove_temp_filenames(temp_filenames_list): while not abort: enable_RIFLEx = RIFLEx_setting == 0 and current_video_length > (6* get_model_fps(base_model_type)+1) or RIFLEx_setting == 1 - if sliding_window: - prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1] + prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1] new_extra_windows = gen.get("extra_windows",0) gen["extra_windows"] = 0 extra_windows += new_extra_windows @@ -4722,15 +4733,13 @@ def remove_temp_filenames(temp_filenames_list): image_start_tensor = image_start.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) image_start_tensor = convert_image_to_tensor(image_start_tensor) pre_video_guide = prefix_video = image_start_tensor.unsqueeze(1) - if image_end is not None: - image_end_tensor = image_end.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) - image_end_tensor = convert_image_to_tensor(image_end_tensor) else: if "L" in image_prompt_type: refresh_preview["video_source"] = get_video_frame(video_source, 0) prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, target_fps = fps, block_size = block_size ) prefix_video = prefix_video.permute(3, 0, 1, 2) prefix_video = prefix_video.float().div_(127.5).sub_(1.) 
# c, f, h, w + new_height, new_width = prefix_video.shape[-2:] pre_video_guide = prefix_video[:, -reuse_frames:] pre_video_frame = convert_tensor_to_image(prefix_video[:, -1]) source_video_overlap_frames_count = pre_video_guide.shape[1] @@ -4739,7 +4748,14 @@ def remove_temp_filenames(temp_filenames_list): image_size = pre_video_guide.shape[-2:] sample_fit_canvas = None guide_start_frame = prefix_video.shape[1] - + if image_end is not None: + image_end_list= image_end if isinstance(image_end, list) else [image_end] + if len(image_end_list) >= window_no: + new_height, new_width = image_size + image_end_tensor =image_end_list[window_no-1].resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + image_end_tensor = convert_image_to_tensor(image_end_tensor) + image_end_list= None + window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count) guide_end_frame = guide_start_frame + current_video_length - (source_video_overlap_frames_count if window_no == 1 else reuse_frames) alignment_shift = source_video_frames_count if reset_control_aligment else 0 @@ -4797,7 +4813,7 @@ def remove_temp_filenames(temp_filenames_list): image_size = src_video.shape[-2:] sample_fit_canvas = None - elif "G" in video_prompt_type: # video to video + else: # video to video video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps) if video_guide_processed is None: src_video = pre_video_guide @@ -5298,7 +5314,7 @@ def process_tasks(state): while True: with gen_lock: process_status = gen.get("process_status", None) - if process_status is None: + if process_status is None or process_status == "process:main": gen["process_status"] = "process:main" break time.sleep(1) @@ -6570,11 +6586,13 @@ def any_letters(source_str, letters): return True return False -def filter_letters(source_str, letters): +def filter_letters(source_str, letters, default= ""): ret = "" for letter in letters: if letter in source_str: ret += letter + if len(ret) == 0: + return default return ret def add_to_sequence(source_str, letters): @@ -6601,9 +6619,18 @@ def refresh_audio_prompt_type_sources(state, audio_prompt_type, audio_prompt_typ audio_prompt_type = add_to_sequence(audio_prompt_type, audio_prompt_type_sources) return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in audio_prompt_type)) -def refresh_image_prompt_type(state, image_prompt_type): - any_video_source = len(filter_letters(image_prompt_type, "VLG"))>0 - return gr.update(visible = "S" in image_prompt_type ), gr.update(visible = "E" in image_prompt_type ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source) +def refresh_image_prompt_type_radio(state, image_prompt_type, image_prompt_type_radio): + image_prompt_type = del_in_sequence(image_prompt_type, "VLTS") + image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio) + any_video_source = len(filter_letters(image_prompt_type, "VL"))>0 + end_visible = any_letters(image_prompt_type, "SVL") + return image_prompt_type, gr.update(visible = "S" in image_prompt_type ), gr.update(visible = end_visible and ("E" in image_prompt_type) ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source), gr.update(visible = 
end_visible) + +def refresh_image_prompt_type_endcheckbox(state, image_prompt_type, image_prompt_type_radio, end_checkbox): + image_prompt_type = del_in_sequence(image_prompt_type, "E") + if end_checkbox: image_prompt_type += "E" + image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio) + return image_prompt_type, gr.update(visible = "E" in image_prompt_type ) def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs): model_type = state["model_type"] @@ -6680,9 +6707,12 @@ def get_prompt_labels(multi_prompts_gen_type, image_outputs = False): new_line_text = "each new line of prompt will be used for a window" if multi_prompts_gen_type != 0 else "each new line of prompt will generate " + ("a new image" if image_outputs else "a new video") return "Prompts (" + new_line_text + ", # lines = comments, ! lines = macros)", "Prompts (" + new_line_text + ", # lines = comments)" +def get_image_end_label(multi_prompts_gen_type): + return "Images as ending points for new Videos in the Generation Queue" if multi_prompts_gen_type == 0 else "Images as ending points for each new Window of the same Video Generation" + def refresh_prompt_labels(multi_prompts_gen_type, image_mode): prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type, image_mode == 1) - return gr.update(label=prompt_label), gr.update(label = wizard_prompt_label) + return gr.update(label=prompt_label), gr.update(label = wizard_prompt_label), gr.update(label=get_image_end_label(multi_prompts_gen_type)) def show_preview_column_modal(state, column_no): column_no = int(column_no) @@ -7054,101 +7084,46 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl with gr.Tab("Text to Image", id = "t2i", elem_classes="compact_tab"): pass - with gr.Column(visible= test_class_i2v(model_type) or hunyuan_i2v or diffusion_forcing or ltxv or recammaster or vace or ti2v_2_2) as image_prompt_column: - if vace or infinitetalk: - image_prompt_type_value= ui_defaults.get("image_prompt_type","") - image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value - image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= not image_outputs , scale= 3) - - image_start_row, image_start, image_start_extra = get_image_gallery(visible = False ) - image_end_row, image_end, image_end_extra = get_image_gallery(visible = False ) - video_source = gr.Video(label= "Video Source", height = gallery_height, visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None)) - model_mode = gr.Dropdown(visible = False) - keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VLG"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) + image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "") + model_mode_choices = model_def.get("model_modes", None) + with gr.Column(visible= len(image_prompt_types_allowed)> 0 or model_mode_choices is not None) as image_prompt_column: + image_prompt_type_value= ui_defaults.get("image_prompt_type","") + image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False) + image_prompt_type_choices = [] + if "T" in image_prompt_types_allowed: + image_prompt_type_choices += [("Text 
Prompt Only", "")] + any_start_image = True + if "S" in image_prompt_types_allowed: + image_prompt_type_choices += [("Start Video with Image", "S")] + any_start_image = True + if "V" in image_prompt_types_allowed: any_video_source = True - - elif diffusion_forcing or ltxv or ti2v_2_2: - image_prompt_type_value= ui_defaults.get("image_prompt_type","T") - # image_prompt_type = gr.Radio( [("Start Video with Image", "S"),("Start and End Video with Images", "SE"), ("Continue Video", "V"),("Text Prompt Only", "T")], value =image_prompt_type_value, label="Location", show_label= False, visible= True, scale= 3) - image_prompt_type_choices = [("Text Prompt Only", "T"),("Start Video with Image", "S")] - if ltxv: - image_prompt_type_choices += [("Use both a Start and an End Image", "SE")] - if sliding_window_enabled: - any_video_source = True - image_prompt_type_choices += [("Continue Video", "V")] - image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3) - - image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) - image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value ) - video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) - if not diffusion_forcing: - model_mode = gr.Dropdown( - choices=[ - ], value=None, - visible= False - ) - else: - model_mode = gr.Dropdown( - choices=[ - ("Synchronous", 0), - ("Asynchronous (better quality but around 50% extra steps added)", 5), - ], - value=ui_defaults.get("model_mode", 0), - label="Generation Type", scale = 3, - visible= True - ) - keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= "V" in image_prompt_type_value, scale = 2, label= "Truncate Video beyond this number of Frames of Video (empty=Keep All)" ) - elif recammaster: - image_prompt_type = gr.Radio(choices=[("Source Video", "V")], value="V") - image_start_row, image_start, image_start_extra = get_image_gallery(visible = False ) - image_end_row, image_end, image_end_extra = get_image_gallery(visible = False ) - video_source = gr.Video(label= "Video Source", height = gallery_height, visible = True, value= ui_defaults.get("video_source", None),) - model_mode = gr.Dropdown( - choices=[ - ("Pan Right", 1), - ("Pan Left", 2), - ("Tilt Up", 3), - ("Tilt Down", 4), - ("Zoom In", 5), - ("Zoom Out", 6), - ("Translate Up (with rotation)", 7), - ("Translate Down (with rotation)", 8), - ("Arc Left (with rotation)", 9), - ("Arc Right (with rotation)", 10), - ], - value=ui_defaults.get("model_mode", 1), - label="Camera Movement Type", scale = 3, - visible= True - ) - keep_frames_video_source = gr.Text(visible=False) - else: - if test_class_i2v(model_type) or hunyuan_i2v: - # image_prompt_type_value= ui_defaults.get("image_prompt_type","SE" if flf2v else "S" ) - image_prompt_type_value= ui_defaults.get("image_prompt_type","S" ) - image_prompt_type_choices = [("Start Video with Image", "S")] - image_prompt_type_choices += [("Use both a Start and an End Image", "SE")] - if not hunyuan_i2v: - any_video_source = True - image_prompt_type_choices += [("Continue Video", "V")] - - 
image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3) - any_start_image = True - any_end_image = True - image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new videos", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) - image_end_row, image_end, image_end_extra = get_image_gallery(label= "Images as ending points for new videos", value = ui_defaults.get("image_end", None), visible= "E" in image_prompt_type_value ) - if hunyuan_i2v: - video_source = gr.Video(value=None, visible=False) + image_prompt_type_choices += [("Continue Video", "V")] + if "L" in image_prompt_types_allowed: + any_video_source = True + image_prompt_type_choices += [("Continue Last Video", "L")] + with gr.Group(visible= len(image_prompt_types_allowed)>1) as image_prompt_type_group: + with gr.Row(): + image_prompt_type_radio_allowed_values= filter_letters(image_prompt_types_allowed, "SVL") + if len(image_prompt_type_choices) > 0: + image_prompt_type_radio = gr.Radio( image_prompt_type_choices, value =filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1]), label="Location", show_label= False, visible= len(image_prompt_types_allowed)>1, scale= 3) else: - video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) - else: - image_prompt_type = gr.Radio(choices=[("", "")], value="") - image_start_row, image_start, image_start_extra = get_image_gallery(visible = False ) - image_end_row, image_end, image_end_extra = get_image_gallery(visible = False ) - video_source = gr.Video(value=None, visible=False) + image_prompt_type_radio = gr.Radio(choices=[("", "")], value="", visible= False) + if "E" in image_prompt_types_allowed: + image_prompt_type_endcheckbox = gr.Checkbox( value ="E" in image_prompt_type_value, label="End Image(s)", show_label= False, visible= any_letters(image_prompt_type_value, "SVL") and not image_outputs , scale= 1) + any_end_image = True + else: + image_prompt_type_endcheckbox = gr.Checkbox( value =False, show_label= False, visible= False , scale= 1) + image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new Videos in the Generation Queue", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) + video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) + image_end_row, image_end, image_end_extra = get_image_gallery(label= get_image_end_label(ui_defaults.get("multi_prompts_gen_type", 0)), value = ui_defaults.get("image_end", None), visible= any_letters(image_prompt_type_value, "SVL") and ("E" in image_prompt_type_value) ) + if model_mode_choices is None: model_mode = gr.Dropdown(value=None, visible=False) - keep_frames_video_source = gr.Text(visible=False) + else: + model_mode = gr.Dropdown(choices=model_mode_choices["choices"], value=ui_defaults.get("model_mode", model_mode_choices["default"]), label=model_mode_choices["label"], visible=True) + keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VL"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled 
Frames (empty=Keep All, negative truncates from End)" ) - with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column: + with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or recammaster or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","") video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False) any_control_video = True @@ -7208,12 +7183,12 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl value=filter_letters(video_prompt_type_value, "PDSLCMUV"), label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= True, show_label= True, ) - elif infinitetalk: - video_prompt_type_video_guide = gr.Dropdown(value="", choices = [("","")], visible=False) + elif recammaster: + video_prompt_type_video_guide = gr.Dropdown(value="UV", choices = [("Control Video","UV")], visible=False) else: any_control_video = False any_control_image = False - video_prompt_type_video_guide = gr.Dropdown(visible= False) + video_prompt_type_video_guide = gr.Dropdown(value="", choices = [("","")], visible=False) if infinitetalk: video_prompt_type_video_guide_alt = gr.Dropdown( @@ -7228,6 +7203,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl value=filter_letters(video_prompt_type_value, "RGUVQKI"), label="Video to Video", scale = 3, visible= True, show_label= False, ) + any_control_video = any_control_image = True else: video_prompt_type_video_guide_alt = gr.Dropdown(value="", choices = [("","")], visible=False) @@ -7761,7 +7737,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False) elif ltxv: sliding_window_size = gr.Slider(41, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=8, label="Sliding Window Size") - sliding_window_overlap = gr.Slider(9, 97, value=ui_defaults.get("sliding_window_overlap",9), step=8, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") + sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",9), step=8, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") sliding_window_color_correction_strength = gr.Slider(0, 1, visible=False, value =0) sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=8, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) @@ -7772,8 +7748,9 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", 
visible = False) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) else: # Vace, Multitalk + sliding_window_defaults = model_def.get("sliding_window_defaults", {}) sliding_window_size = gr.Slider(5, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size") - sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") + sliding_window_overlap = gr.Slider(sliding_window_defaults.get("overlap_min", 1), sliding_window_defaults.get("overlap_max", 97), value=ui_defaults.get("sliding_window_overlap",sliding_window_defaults.get("overlap_default", 5)), step=sliding_window_defaults.get("overlap_step", 4), label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",1), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)") sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20 if vace else 0), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = vace) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) @@ -7790,13 +7767,13 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai multi_prompts_gen_type = gr.Dropdown( choices=[ - ("Will create new generated Video", 0), + ("Will create a new generated Video added to the Generation Queue", 0), ("Will be used for a new Sliding Window of the same Video Generation", 1), ], value=ui_defaults.get("multi_prompts_gen_type",0), visible=True, scale = 1, - label="Text Prompts separated by a Carriage Return" + label="Images & Text Prompts separated by a Carriage Return" if (any_start_image or any_end_image) else "Text Prompts separated by a Carriage Return" ) with gr.Tab("Misc.", visible = True) as misc_tab: @@ -7962,7 +7939,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai hidden_countdown_state = gr.Number(value=-1, visible=False, elem_id="hidden_countdown_state_num") single_hidden_trigger_btn = gr.Button("trigger_countdown", visible=False, elem_id="trigger_info_single_btn") - extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, + extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, image_prompt_type_group, image_prompt_type_radio, image_prompt_type_endcheckbox, prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, save_lset_prompt_drop, advanced_row, speed_tab, audio_tab, mmaudio_col, quality_tab, sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, audio_guide_row, RIFLEx_setting_col, video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, 
audio_prompt_type_sources, audio_prompt_type_remux, audio_prompt_type_remux_row, @@ -7981,23 +7958,24 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai target_settings = gr.Text(value = "settings", interactive= False, visible= False) last_choice = gr.Number(value =-1, interactive= False, visible= False) - resolution_group.input(fn=change_resolution_group, inputs=[state, resolution_group], outputs=[resolution]) + resolution_group.input(fn=change_resolution_group, inputs=[state, resolution_group], outputs=[resolution], show_progress="hidden") resolution.change(fn=record_last_resolution, inputs=[state, resolution]) # video_length.release(fn=refresh_video_length_label, inputs=[state, video_length ], outputs = video_length, trigger_mode="always_last" ) - gr.on(triggers=[video_length.release, force_fps.change, video_guide.change, video_source.change], fn=refresh_video_length_label, inputs=[state, video_length, force_fps, video_guide, video_source] , outputs = video_length, trigger_mode="always_last" ) + gr.on(triggers=[video_length.release, force_fps.change, video_guide.change, video_source.change], fn=refresh_video_length_label, inputs=[state, video_length, force_fps, video_guide, video_source] , outputs = video_length, trigger_mode="always_last", show_progress="hidden" ) guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ]) audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type]) audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row]) - image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start_row, image_end_row, video_source, keep_frames_video_source] ) + image_prompt_type_radio.change(fn=refresh_image_prompt_type_radio, inputs=[state, image_prompt_type, image_prompt_type_radio], outputs=[image_prompt_type, image_start_row, image_end_row, video_source, keep_frames_video_source, image_prompt_type_endcheckbox], show_progress="hidden" ) + image_prompt_type_endcheckbox.change(fn=refresh_image_prompt_type_endcheckbox, inputs=[state, image_prompt_type, image_prompt_type_radio, image_prompt_type_endcheckbox], outputs=[image_prompt_type, image_end_row] ) # video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand]) video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col]) video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, 
video_mask, image_mask, mask_expand]) video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs_row, denoising_strength ]) video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand]) video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type]) - multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt]) + multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt, image_end], show_progress="hidden") video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) video_guide_outpainting_bottom.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_bottom,gr.State(1)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) video_guide_outpainting_left.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_left,gr.State(2)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) From 99fd9aea32937d039fe54542afc0d50e18239f87 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 5 Sep 2025 15:56:42 +0200 Subject: [PATCH 017/155] fixed settings update --- models/wan/wan_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 3a5dd64a6..bc2eb396e 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -360,8 +360,8 @@ def update_default_settings(base_model_type, model_def, ui_defaults): }) if text_oneframe_overlap(base_model_type): - ui_defaults.update["sliding_window_overlap"] = 1 - ui_defaults.update["color_correction_strength"]= 0 + ui_defaults["sliding_window_overlap"] = 1 + ui_defaults["color_correction_strength"]= 0 if test_multitalk(base_model_type): ui_defaults["audio_guidance_scale"] = 4 From 836777eb55ca7c9a752b9a2703b0a55b75366de9 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 5 Sep 2025 20:57:09 +0200 Subject: [PATCH 018/155] fixed End Frames checkbox not always visible --- models/wan/wan_handler.py | 2 +- wgp.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index bc2eb396e..8316d6d90 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -144,7 +144,7 @@ def query_model_def(base_model_type, model_def): if vace_class or base_model_type in ["infinitetalk"]: image_prompt_types_allowed = "TVL" elif base_model_type in ["ti2v_2_2"]: - image_prompt_types_allowed = "TSEVL" + image_prompt_types_allowed = "TSEV" elif i2v: image_prompt_types_allowed = "SEVL" else: diff --git a/wgp.py b/wgp.py index 396e273e2..1eb3fcae8 100644 --- a/wgp.py +++ b/wgp.py @@ -7105,12 +7105,13 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl with gr.Group(visible= len(image_prompt_types_allowed)>1) as image_prompt_type_group: with gr.Row(): 
image_prompt_type_radio_allowed_values= filter_letters(image_prompt_types_allowed, "SVL") + image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1]) if len(image_prompt_type_choices) > 0: - image_prompt_type_radio = gr.Radio( image_prompt_type_choices, value =filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1]), label="Location", show_label= False, visible= len(image_prompt_types_allowed)>1, scale= 3) + image_prompt_type_radio = gr.Radio( image_prompt_type_choices, value = image_prompt_type_radio_value, label="Location", show_label= False, visible= len(image_prompt_types_allowed)>1, scale= 3) else: image_prompt_type_radio = gr.Radio(choices=[("", "")], value="", visible= False) if "E" in image_prompt_types_allowed: - image_prompt_type_endcheckbox = gr.Checkbox( value ="E" in image_prompt_type_value, label="End Image(s)", show_label= False, visible= any_letters(image_prompt_type_value, "SVL") and not image_outputs , scale= 1) + image_prompt_type_endcheckbox = gr.Checkbox( value ="E" in image_prompt_type_value, label="End Image(s)", show_label= False, visible= any_letters(image_prompt_type_radio_value, "SVL") and not image_outputs , scale= 1) any_end_image = True else: image_prompt_type_endcheckbox = gr.Checkbox( value =False, show_label= False, visible= False , scale= 1) From cf65889e2e1094902e08434e87b18049b9fe76b7 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 5 Sep 2025 21:59:26 +0200 Subject: [PATCH 019/155] fixed End Frames checkbox not always visible --- models/wan/wan_handler.py | 8 ++++++++ wgp.py | 6 +++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 8316d6d90..20c3594ce 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -296,11 +296,19 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): if text_oneframe_overlap(base_model_type): ui_defaults["sliding_window_overlap"] = 1 + if settings_version < 2.32: + image_prompt_type = ui_defaults.get("image_prompt_type", "") + if test_class_i2v(base_model_type) and len(image_prompt_type) == 0: + ui_defaults["image_prompt_type"] = "S" + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ "sample_solver": "unipc", }) + if test_class_i2v(base_model_type): + ui_defaults["image_prompt_type"] = "S" + if base_model_type in ["fantasy"]: ui_defaults.update({ "audio_guidance_scale": 5.0, diff --git a/wgp.py b/wgp.py index 1eb3fcae8..124ae5e6c 100644 --- a/wgp.py +++ b/wgp.py @@ -61,7 +61,7 @@ target_mmgp_version = "3.6.0" WanGP_version = "8.4" -settings_version = 2.31 +settings_version = 2.32 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -6142,11 +6142,11 @@ def eject_video_from_gallery(state, input_file_list, choice): return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0) def has_video_file_extension(filename): - extension = os.path.splitext(filename)[-1] + extension = os.path.splitext(filename)[-1].lower() return extension in [".mp4"] def has_image_file_extension(filename): - extension = os.path.splitext(filename)[-1] + extension = os.path.splitext(filename)[-1].lower() 
return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"] def add_videos_to_gallery(state, input_file_list, choice, files_to_load): gen = get_gen_info(state) From 242ae50d8368ac2da00adefa3365f69ca01ee03a Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 5 Sep 2025 22:05:00 +0200 Subject: [PATCH 020/155] fixed End Frames checkbox not always visible --- wgp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgp.py b/wgp.py index 124ae5e6c..7db02363f 100644 --- a/wgp.py +++ b/wgp.py @@ -7105,8 +7105,8 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl with gr.Group(visible= len(image_prompt_types_allowed)>1) as image_prompt_type_group: with gr.Row(): image_prompt_type_radio_allowed_values= filter_letters(image_prompt_types_allowed, "SVL") - image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1]) if len(image_prompt_type_choices) > 0: + image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1]) image_prompt_type_radio = gr.Radio( image_prompt_type_choices, value = image_prompt_type_radio_value, label="Location", show_label= False, visible= len(image_prompt_types_allowed)>1, scale= 3) else: image_prompt_type_radio = gr.Radio(choices=[("", "")], value="", visible= False) From 0649686e9d785b734f209967128f0849e160f5f6 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 5 Sep 2025 22:14:12 +0200 Subject: [PATCH 021/155] fixed End Frames checkbox not always visible --- wgp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgp.py b/wgp.py index 7db02363f..4f1a10dd9 100644 --- a/wgp.py +++ b/wgp.py @@ -7105,8 +7105,8 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl with gr.Group(visible= len(image_prompt_types_allowed)>1) as image_prompt_type_group: with gr.Row(): image_prompt_type_radio_allowed_values= filter_letters(image_prompt_types_allowed, "SVL") + image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1] if len(image_prompt_type_choices) > 0 else "") if len(image_prompt_type_choices) > 0: - image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1]) image_prompt_type_radio = gr.Radio( image_prompt_type_choices, value = image_prompt_type_radio_value, label="Location", show_label= False, visible= len(image_prompt_types_allowed)>1, scale= 3) else: image_prompt_type_radio = gr.Radio(choices=[("", "")], value="", visible= False) From 66a07db16b841a908008f953849758cf3de6cf37 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sat, 6 Sep 2025 00:08:45 +0200 Subject: [PATCH 022/155] fixed End Frames checkbox not always visible --- models/wan/wan_handler.py | 4 +++- wgp.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 20c3594ce..5d5c27c8e 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -144,7 +144,9 @@ def query_model_def(base_model_type, model_def): if vace_class or base_model_type in ["infinitetalk"]: image_prompt_types_allowed = "TVL" elif base_model_type in ["ti2v_2_2"]: - image_prompt_types_allowed = "TSEV" + image_prompt_types_allowed = "TSVL" + elif test_multitalk(base_model_type) or 
base_model_type in ["fantasy"]: + image_prompt_types_allowed = "SVL" elif i2v: image_prompt_types_allowed = "SEVL" else: diff --git a/wgp.py b/wgp.py index 4f1a10dd9..9bbf1ff89 100644 --- a/wgp.py +++ b/wgp.py @@ -6623,7 +6623,8 @@ def refresh_image_prompt_type_radio(state, image_prompt_type, image_prompt_type_ image_prompt_type = del_in_sequence(image_prompt_type, "VLTS") image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio) any_video_source = len(filter_letters(image_prompt_type, "VL"))>0 - end_visible = any_letters(image_prompt_type, "SVL") + model_def = get_model_def(state["model_type"]) + end_visible = any_letters(image_prompt_type, "SVL") and "E" in model_def.get("image_prompt_types_allowed","") return image_prompt_type, gr.update(visible = "S" in image_prompt_type ), gr.update(visible = end_visible and ("E" in image_prompt_type) ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source), gr.update(visible = end_visible) def refresh_image_prompt_type_endcheckbox(state, image_prompt_type, image_prompt_type_radio, end_checkbox): From f9f63cbc79c364ac146d833292f0ab89e199228c Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 9 Sep 2025 21:41:35 +0200 Subject: [PATCH 023/155] intermediate commit --- defaults/flux_dev_kontext.json | 2 - defaults/flux_dev_uso.json | 2 - defaults/qwen_image_edit_20B.json | 4 +- models/flux/flux_handler.py | 37 +- models/flux/flux_main.py | 69 +-- models/hyvideo/hunyuan.py | 5 - models/hyvideo/hunyuan_handler.py | 21 + models/ltx_video/ltxv.py | 3 - models/ltx_video/ltxv_handler.py | 9 + models/qwen/pipeline_qwenimage.py | 65 ++- models/qwen/qwen_handler.py | 20 +- models/qwen/qwen_main.py | 10 +- models/wan/any2video.py | 41 +- models/wan/df_handler.py | 2 +- models/wan/wan_handler.py | 72 ++- preprocessing/matanyone/app.py | 69 ++- requirements.txt | 2 +- shared/gradio/gallery.py | 141 +++-- shared/utils/utils.py | 53 +- wgp.py | 833 +++++++++++++++++------------- 20 files changed, 893 insertions(+), 567 deletions(-) diff --git a/defaults/flux_dev_kontext.json b/defaults/flux_dev_kontext.json index 894591880..20b6bc4d3 100644 --- a/defaults/flux_dev_kontext.json +++ b/defaults/flux_dev_kontext.json @@ -7,8 +7,6 @@ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors" ], - "image_outputs": true, - "reference_image": true, "flux-model": "flux-dev-kontext" }, "prompt": "add a hat", diff --git a/defaults/flux_dev_uso.json b/defaults/flux_dev_uso.json index 8b5dbb603..0cd7b828d 100644 --- a/defaults/flux_dev_uso.json +++ b/defaults/flux_dev_uso.json @@ -6,8 +6,6 @@ "modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]], "URLs": "flux", "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"], - "image_outputs": true, - "reference_image": true, "flux-model": "flux-dev-uso" }, "prompt": "the man is wearing a hat", diff --git a/defaults/qwen_image_edit_20B.json b/defaults/qwen_image_edit_20B.json index 2b24c72da..79b8b245d 100644 --- a/defaults/qwen_image_edit_20B.json +++ b/defaults/qwen_image_edit_20B.json @@ -9,9 +9,7 @@ ], "attention": { "<89": "sdpa" - }, - "reference_image": true, - "image_outputs": true + } }, "prompt": "add a hat", "resolution": "1280x720", diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py index b8b7b9b42..c468d5a9a 
100644 --- a/models/flux/flux_handler.py +++ b/models/flux/flux_handler.py @@ -13,28 +13,41 @@ def query_model_def(base_model_type, model_def): flux_schnell = flux_model == "flux-schnell" flux_chroma = flux_model == "flux-chroma" flux_uso = flux_model == "flux-dev-uso" - model_def_output = { + flux_kontext = flux_model == "flux-dev-kontext" + + extra_model_def = { "image_outputs" : True, "no_negative_prompt" : not flux_chroma, } if flux_chroma: - model_def_output["guidance_max_phases"] = 1 + extra_model_def["guidance_max_phases"] = 1 elif not flux_schnell: - model_def_output["embedded_guidance"] = True + extra_model_def["embedded_guidance"] = True if flux_uso : - model_def_output["any_image_refs_relative_size"] = True - model_def_output["no_background_removal"] = True - - model_def_output["image_ref_choices"] = { + extra_model_def["any_image_refs_relative_size"] = True + extra_model_def["no_background_removal"] = True + extra_model_def["image_ref_choices"] = { "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"), ("Up to two Images are Style Images", "KIJ")], "default": "KI", "letters_filter": "KIJ", "label": "Reference Images / Style Images" } - model_def_output["lock_image_refs_ratios"] = True + + if flux_kontext: + extra_model_def["image_ref_choices"] = { + "choices": [ + ("None", ""), + ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"), + ("Conditional Images are People / Objects", "I"), + ], + "letters_filter": "KI", + } + - return model_def_output + extra_model_def["lock_image_refs_ratios"] = True + + return extra_model_def @staticmethod def query_supported_types(): @@ -122,10 +135,12 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): def update_default_settings(base_model_type, model_def, ui_defaults): flux_model = model_def.get("flux-model", "flux-dev") flux_uso = flux_model == "flux-dev-uso" + flux_kontext = flux_model == "flux-dev-kontext" ui_defaults.update({ "embedded_guidance": 2.5, - }) - if model_def.get("reference_image", False): + }) + + if flux_kontext or flux_uso: ui_defaults.update({ "video_prompt_type": "KI", }) diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py index 9bb8e7379..4d7c67d41 100644 --- a/models/flux/flux_main.py +++ b/models/flux/flux_main.py @@ -24,44 +24,6 @@ from PIL import Image -def resize_and_centercrop_image(image, target_height_ref1, target_width_ref1): - target_height_ref1 = int(target_height_ref1 // 64 * 64) - target_width_ref1 = int(target_width_ref1 // 64 * 64) - h, w = image.shape[-2:] - if h < target_height_ref1 or w < target_width_ref1: - # 计算长宽比 - aspect_ratio = w / h - if h < target_height_ref1: - new_h = target_height_ref1 - new_w = new_h * aspect_ratio - if new_w < target_width_ref1: - new_w = target_width_ref1 - new_h = new_w / aspect_ratio - else: - new_w = target_width_ref1 - new_h = new_w / aspect_ratio - if new_h < target_height_ref1: - new_h = target_height_ref1 - new_w = new_h * aspect_ratio - else: - aspect_ratio = w / h - tgt_aspect_ratio = target_width_ref1 / target_height_ref1 - if aspect_ratio > tgt_aspect_ratio: - new_h = target_height_ref1 - new_w = new_h * aspect_ratio - else: - new_w = target_width_ref1 - new_h = new_w / aspect_ratio - # 使用 TVF.resize 进行图像缩放 - image = TVF.resize(image, (math.ceil(new_h), math.ceil(new_w))) - # 计算中心裁剪的参数 - top = (image.shape[-2] - target_height_ref1) // 2 - left = (image.shape[-1] - target_width_ref1) // 2 - # 使用 
TVF.crop 进行中心裁剪 - image = TVF.crop(image, top, left, target_height_ref1, target_width_ref1) - return image - - def stitch_images(img1, img2): # Resize img2 to match img1's height width1, height1 = img1.size @@ -171,8 +133,6 @@ def generate( device="cuda" flux_dev_uso = self.name in ['flux-dev-uso'] image_stiching = not self.name in ['flux-dev-uso'] #and False - # image_refs_relative_size = 100 - crop = False input_ref_images = [] if input_ref_images is None else input_ref_images[:] ref_style_imgs = [] if "I" in video_prompt_type and len(input_ref_images) > 0: @@ -186,36 +146,15 @@ def generate( if image_stiching: # image stiching method stiched = input_ref_images[0] - if "K" in video_prompt_type : - w, h = input_ref_images[0].size - height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - # actual rescale will happen in prepare_kontext for new_img in input_ref_images[1:]: stiched = stitch_images(stiched, new_img) input_ref_images = [stiched] else: - first_ref = 0 - if "K" in video_prompt_type: - # image latents tiling method - w, h = input_ref_images[0].size - if crop : - img = convert_image_to_tensor(input_ref_images[0]) - img = resize_and_centercrop_image(img, height, width) - input_ref_images[0] = convert_tensor_to_image(img) - else: - height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS) - first_ref = 1 - - for i in range(first_ref,len(input_ref_images)): + # latents stiching with resize + for i in range(len(input_ref_images)): w, h = input_ref_images[i].size - if crop: - img = convert_image_to_tensor(input_ref_images[i]) - img = resize_and_centercrop_image(img, int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100)) - input_ref_images[i] = convert_tensor_to_image(img) - else: - image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas) - input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) + image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas) + input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) else: input_ref_images = None diff --git a/models/hyvideo/hunyuan.py b/models/hyvideo/hunyuan.py index a38a7bd0d..181a9a774 100644 --- a/models/hyvideo/hunyuan.py +++ b/models/hyvideo/hunyuan.py @@ -861,11 +861,6 @@ def generate( freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx) else: if self.avatar: - w, h = input_ref_images.size - target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas) - if target_width != w or target_height != h: - input_ref_images = input_ref_images.resize((target_width,target_height), resample=Image.Resampling.LANCZOS) - concat_dict = {'mode': 'timecat', 'bias': -1} freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict) else: diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py index 9cbaea7d1..487e76d99 100644 --- a/models/hyvideo/hunyuan_handler.py +++ b/models/hyvideo/hunyuan_handler.py @@ -51,6 +51,23 @@ def query_model_def(base_model_type, model_def): extra_model_def["tea_cache"] = True 
extra_model_def["mag_cache"] = True + if base_model_type in ["hunyuan_custom_edit"]: + extra_model_def["guide_preprocessing"] = { + "selection": ["MV", "PMV"], + } + + extra_model_def["mask_preprocessing"] = { + "selection": ["A", "NA"], + "default" : "NA" + } + + if base_model_type in ["hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]: + extra_model_def["image_ref_choices"] = { + "choices": [("Reference Image", "I")], + "letters_filter":"I", + "visible": False, + } + if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]: @@ -141,6 +158,10 @@ def load_model(model_filename, model_type = None, base_model_type = None, model return hunyuan_model, pipe + @staticmethod + def fix_settings(base_model_type, settings_version, model_def, ui_defaults): + pass + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults["embedded_guidance_scale"]= 6.0 diff --git a/models/ltx_video/ltxv.py b/models/ltx_video/ltxv.py index e71ac4fce..db143fc6d 100644 --- a/models/ltx_video/ltxv.py +++ b/models/ltx_video/ltxv.py @@ -300,9 +300,6 @@ def generate( prefix_size, height, width = input_video.shape[-3:] else: if image_start != None: - frame_width, frame_height = image_start.size - if fit_into_canvas != None: - height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32) conditioning_media_paths.append(image_start.unsqueeze(1)) conditioning_start_frames.append(0) conditioning_control_frames.append(False) diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py index c89b69a43..e44c98399 100644 --- a/models/ltx_video/ltxv_handler.py +++ b/models/ltx_video/ltxv_handler.py @@ -26,6 +26,15 @@ def query_model_def(base_model_type, model_def): extra_model_def["sliding_window"] = True extra_model_def["image_prompt_types_allowed"] = "TSEV" + extra_model_def["guide_preprocessing"] = { + "selection": ["", "PV", "DV", "EV", "V"], + "labels" : { "V": "Use LTXV raw format"} + } + + extra_model_def["mask_preprocessing"] = { + "selection": ["", "A", "NA", "XA", "XNA"], + } + return extra_model_def @staticmethod diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py index 07bdbd404..905b86493 100644 --- a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -28,7 +28,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from diffusers import FlowMatchEulerDiscreteScheduler from PIL import Image -from shared.utils.utils import calculate_new_dimensions +from shared.utils.utils import calculate_new_dimensions, convert_image_to_tensor, convert_tensor_to_image XLA_AVAILABLE = False @@ -563,6 +563,8 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, image = None, + image_mask = None, + denoising_strength = 0, callback=None, pipeline=None, loras_slists=None, @@ -694,14 +696,33 @@ def __call__( image_width = image_width // multiple_of * multiple_of image_height = image_height // multiple_of * multiple_of ref_height, ref_width = 1568, 672 - if height * width < ref_height * ref_width: ref_height , ref_width = height , width - if image_height * image_width > ref_height * ref_width: - image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) - image = 
image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS) + if image_mask is None: + if height * width < ref_height * ref_width: ref_height , ref_width = height , width + if image_height * image_width > ref_height * ref_width: + image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) + if (image_width,image_height) != image.size: + image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS) + image_mask_latents = None + else: + # _, image_width, image_height = min( + # (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS + # ) + image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of) + # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) + height, width = image_height, image_width + image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS)) + image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1] + image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0) + convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") + image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device) + prompt_image = image - image = self.image_processor.preprocess(image, image_height, image_width) - image = image.unsqueeze(2) + if image.size != (image_width, image_height): + image = image.resize((image_width, image_height), resample=Image.Resampling.LANCZOS) + + image.save("nnn.png") + image = convert_image_to_tensor(image).unsqueeze(0).unsqueeze(2) has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None @@ -744,6 +765,8 @@ def __call__( generator, latents, ) + original_image_latents = None if image_latents is None else image_latents.clone() + if image is not None: img_shapes = [ [ @@ -788,6 +811,15 @@ def __call__( negative_txt_seq_lens = ( negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) + morph = False + if image_mask_latents is not None and denoising_strength <= 1.: + first_step = int(len(timesteps) * (1. - denoising_strength)) + if not morph: + latent_noise_factor = timesteps[first_step]/1000 + latents = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor + timesteps = timesteps[first_step:] + self.scheduler.timesteps = timesteps + self.scheduler.sigmas= self.scheduler.sigmas[first_step:] # 6. Denoising loop self.scheduler.set_begin_index(0) @@ -797,10 +829,15 @@ def __call__( update_loras_slists(self.transformer, loras_slists, updated_num_steps) callback(-1, None, True, override_num_inference_steps = updated_num_steps) + for i, t in enumerate(timesteps): if self.interrupt: continue + if image_mask_latents is not None and denoising_strength <1. 
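# Editorial sketch (not part of the patch): the hunk above starts Qwen image-edit
# denoising partway through the schedule when a mask and a denoising_strength are
# supplied. This standalone snippet only illustrates the arithmetic; it assumes a
# simple linearly spaced 1000 -> 0 timestep list, whereas the real pipeline takes
# its timesteps and sigmas from FlowMatchEulerDiscreteScheduler.
num_inference_steps = 30
denoising_strength = 0.7
timesteps = [1000 - i * (1000 / num_inference_steps) for i in range(num_inference_steps)]

first_step = int(num_inference_steps * (1. - denoising_strength))   # 9 for the values above
latent_noise_factor = timesteps[first_step] / 1000                  # fraction of noise mixed in
remaining_steps = num_inference_steps - first_step                  # 21 steps actually run

# The latents are then seeded as
#   latents = image_latents * (1 - latent_noise_factor) + noise * latent_noise_factor
# so a strength of 1.0 starts from pure noise (first_step == 0), while a strength
# close to 0 keeps the encoded input image almost unchanged and runs few steps.
print(first_step, round(latent_noise_factor, 2), remaining_steps)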
and i == first_step and morph: + latent_noise_factor = t/1000 + latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor + self._current_timestep = t # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -865,6 +902,12 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if image_mask_latents is not None: + next_t = timesteps[i+1] if i List[Any]: def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) : # Mirror the selected index into state and the gallery (server-side selected_index) + + st = get_state(state) + last_time = st.get("last_time", None) + if last_time is not None and abs(time.time()- last_time)< 0.5: # crappy trick to detect if onselect is unwanted (buggy gallery) + # print(f"ignored:{time.time()}, real {st['selected']}") + return gr.update(selected_index=st["selected"]), st + idx = None if evt is not None and hasattr(evt, "index"): ix = evt.index @@ -220,17 +232,28 @@ def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) : idx = ix[0] * max(1, int(self.columns)) + ix[1] else: idx = ix[0] - st = get_state(state) n = len(get_list(gallery)) sel = idx if (idx is not None and 0 <= idx < n) else None + # print(f"image selected evt index:{sel}/{evt.selected}") st["selected"] = sel - # return gr.update(selected_index=sel), st - # return gr.update(), st - return st + return gr.update(), st + + def _on_upload(self, value: List[Any], state: Dict[str, Any]) : + # Fires when users upload via the Gallery itself. + # items_filtered = self._filter_items_by_mode(list(value or [])) + items_filtered = list(value or []) + st = get_state(state) + new_items = self._paths_from_payload(items_filtered) + st["items"] = new_items + new_sel = len(new_items) - 1 + st["selected"] = new_sel + record_last_action(st,"add") + return gr.update(selected_index=new_sel), st def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) : # Fires when users add/drag/drop/delete via the Gallery itself. - items_filtered = self._filter_items_by_mode(list(value or [])) + # items_filtered = self._filter_items_by_mode(list(value or [])) + items_filtered = list(value or []) st = get_state(state) st["items"] = items_filtered # Keep selection if still valid, else default to last @@ -240,10 +263,9 @@ def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) : else: new_sel = old_sel st["selected"] = new_sel - # return gr.update(value=items_filtered, selected_index=new_sel), st - # return gr.update(value=items_filtered), st - - return gr.update(), st + st["last_action"] ="gallery_change" + # print(f"gallery change: set sel {new_sel}") + return gr.update(selected_index=new_sel), st def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery): """ @@ -252,7 +274,8 @@ def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery): and re-selects the last inserted item. 
""" # New items (respect image/video mode) - new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload)) + # new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload)) + new_items = self._paths_from_payload(files_payload) st = get_state(state) cur: List[Any] = get_list(gallery) @@ -298,30 +321,6 @@ def key_of(it: Any) -> Optional[str]: if k is not None: seen_new.add(k) - # Remove any existing occurrences of the incoming items from current list, - # BUT keep the currently selected item even if it's also in incoming. - cur_clean: List[Any] = [] - # sel_item = cur[sel] if (sel is not None and 0 <= sel < len(cur)) else None - # for idx, it in enumerate(cur): - # k = key_of(it) - # if it is sel_item: - # cur_clean.append(it) - # continue - # if k is not None and k in seen_new: - # continue # drop duplicate; we'll reinsert at the target spot - # cur_clean.append(it) - - # # Compute insertion position: right AFTER the (possibly shifted) selected item - # if sel_item is not None: - # # find sel_item's new index in cur_clean - # try: - # pos_sel = cur_clean.index(sel_item) - # except ValueError: - # # Shouldn't happen, but fall back to end - # pos_sel = len(cur_clean) - 1 - # insert_pos = pos_sel + 1 - # else: - # insert_pos = len(cur_clean) # no selection -> append at end insert_pos = min(sel, len(cur) -1) cur_clean = cur # Build final list and selection @@ -330,6 +329,8 @@ def key_of(it: Any) -> Optional[str]: st["items"] = merged st["selected"] = new_sel + record_last_action(st,"add") + # print(f"gallery add: set sel {new_sel}") return gr.update(value=merged, selected_index=new_sel), st def _on_remove(self, state: Dict[str, Any], gallery) : @@ -342,8 +343,9 @@ def _on_remove(self, state: Dict[str, Any], gallery) : return gr.update(value=[], selected_index=None), st new_sel = min(sel, len(items) - 1) st["items"] = items; st["selected"] = new_sel - # return gr.update(value=items, selected_index=new_sel), st - return gr.update(value=items), st + record_last_action(st,"remove") + # print(f"gallery del: new sel {new_sel}") + return gr.update(value=items, selected_index=new_sel), st def _on_move(self, delta: int, state: Dict[str, Any], gallery) : st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None) @@ -354,11 +356,15 @@ def _on_move(self, delta: int, state: Dict[str, Any], gallery) : return gr.update(value=items, selected_index=sel), st items[sel], items[j] = items[j], items[sel] st["items"] = items; st["selected"] = j + record_last_action(st,"move") + # print(f"gallery move: set sel {j}") return gr.update(value=items, selected_index=j), st def _on_clear(self, state: Dict[str, Any]) : st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode} - return gr.update(value=[], selected_index=0), st + record_last_action(st,"clear") + # print(f"Clear all") + return gr.update(value=[], selected_index=None), st def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) : st = get_state(state); st["single"] = bool(to_single) @@ -382,30 +388,38 @@ def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) : def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False): if parent is not None: with parent: - col = self._build_ui() + col = self._build_ui(update_form) else: - col = self._build_ui() + col = self._build_ui(update_form) if not update_form: self._wire_events() return col - def _build_ui(self) -> gr.Column: + def 
_build_ui(self, update = False) -> gr.Column: with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col: self.container = col self.state = gr.State(dict(self._initial_state)) - self.gallery = gr.Gallery( - label=self.label, - value=self._initial_state["items"], - height=self.height, - columns=self.columns, - show_label=self.show_label, - preview= True, - # type="pil", - file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS), - selected_index=self._initial_state["selected"], # server-side selection - ) + if update: + self.gallery = gr.update( + value=self._initial_state["items"], + selected_index=self._initial_state["selected"], # server-side selection + label=self.label, + show_label=self.show_label, + ) + else: + self.gallery = gr.Gallery( + value=self._initial_state["items"], + label=self.label, + height=self.height, + columns=self.columns, + show_label=self.show_label, + preview= True, + # type="pil", # very slow + file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS), + selected_index=self._initial_state["selected"], # server-side selection + ) # One-line controls exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None @@ -418,10 +432,10 @@ def _build_ui(self) -> gr.Column: size="sm", min_width=1, ) - self.btn_remove = gr.Button("Remove", size="sm", min_width=1) + self.btn_remove = gr.Button(" Remove ", size="sm", min_width=1) self.btn_left = gr.Button("◀ Left", size="sm", visible=not self._initial_state["single"], min_width=1) self.btn_right = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1) - self.btn_clear = gr.Button("Clear", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1) + self.btn_clear = gr.Button(" Clear ", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1) return col @@ -430,14 +444,24 @@ def _wire_events(self): self.gallery.select( self._on_select, inputs=[self.state, self.gallery], - outputs=[self.state], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.) + self.gallery.upload( + self._on_upload, + inputs=[self.gallery, self.state], + outputs=[self.gallery, self.state], + trigger_mode="always_last", ) # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.) 
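# Editorial sketch (illustration only): _on_select above ignores select events that
# arrive within 0.5 s of the last recorded action, which the in-code comment
# describes as a workaround for spurious select events from the Gallery component.
# The patch calls record_last_action(), whose body is not shown in this hunk; the
# version below is an assumed stand-in that simply stamps the current time, and
# should_ignore_select() is a hypothetical helper, not an API of this repository.
import time

def record_last_action(st, action):
    st["last_action"] = action
    st["last_time"] = time.time()

def should_ignore_select(st, window=0.5):
    last_time = st.get("last_time")
    return last_time is not None and abs(time.time() - last_time) < window

st = {"items": [], "selected": None}
record_last_action(st, "add")
print(should_ignore_select(st))   # True right after a programmatic gallery update
time.sleep(0.6)
print(should_ignore_select(st))   # False once the debounce window has passed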
- self.gallery.change( + self.gallery.upload( self._on_gallery_change, inputs=[self.gallery, self.state], outputs=[self.gallery, self.state], + trigger_mode="always_last", ) # Add via UploadButton @@ -445,6 +469,7 @@ def _wire_events(self): self._on_add, inputs=[self.upload_btn, self.state, self.gallery], outputs=[self.gallery, self.state], + trigger_mode="always_last", ) # Remove selected @@ -452,6 +477,7 @@ def _wire_events(self): self._on_remove, inputs=[self.state, self.gallery], outputs=[self.gallery, self.state], + trigger_mode="always_last", ) # Reorder using selected index, keep same item selected @@ -459,11 +485,13 @@ def _wire_events(self): lambda st, gallery: self._on_move(-1, st, gallery), inputs=[self.state, self.gallery], outputs=[self.gallery, self.state], + trigger_mode="always_last", ) self.btn_right.click( lambda st, gallery: self._on_move(+1, st, gallery), inputs=[self.state, self.gallery], outputs=[self.gallery, self.state], + trigger_mode="always_last", ) # Clear all @@ -471,6 +499,7 @@ def _wire_events(self): self._on_clear, inputs=[self.state], outputs=[self.gallery, self.state], + trigger_mode="always_last", ) # ---------------- public API ---------------- diff --git a/shared/utils/utils.py b/shared/utils/utils.py index 7ddf1eb6e..5dbd0af17 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -19,6 +19,7 @@ import subprocess import json from functools import lru_cache +os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") from PIL import Image @@ -207,30 +208,62 @@ def get_outpainting_frame_location(final_height, final_width, outpainting_dims if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width return height, width, margin_top, margin_left -def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16): - if fit_into_canvas == None: +def rescale_and_crop(img, w, h): + ow, oh = img.size + target_ratio = w / h + orig_ratio = ow / oh + + if orig_ratio > target_ratio: + # Crop width first + nw = int(oh * target_ratio) + img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh)) + else: + # Crop height first + nh = int(ow / target_ratio) + img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2)) + + return img.resize((w, h), Image.LANCZOS) + +def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16): + if fit_into_canvas == None or fit_into_canvas == 2: # return image_height, image_width return canvas_height, canvas_width - if fit_into_canvas: + if fit_into_canvas == 1: scale1 = min(canvas_height / image_height, canvas_width / image_width) scale2 = min(canvas_width / image_height, canvas_height / image_width) scale = max(scale1, scale2) - else: + else: #0 or #2 (crop) scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2) new_height = round( image_height * scale / block_size) * block_size new_width = round( image_width * scale / block_size) * block_size return new_height, new_width -def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ): +def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fit_into_canvas, fit_crop, block_size = 16): + if fit_crop: + image = rescale_and_crop(image, canvas_width, canvas_height) + new_width, new_height = image.size + else: + image_width, image_height = image.size + new_height, new_width = calculate_new_dimensions(canvas_height, 
canvas_width, image_height, image_width, fit_into_canvas, block_size = block_size ) + image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + return image, new_height, new_width + +def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ): if rm_background: session = new_session() output_list =[] for i, img in enumerate(img_list): width, height = img.size - - if fit_into_canvas: + if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2: + if outpainting_dims is not None: + resized_image =img + elif img.size != (budget_width, budget_height): + resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS) + else: + resized_image =img + elif fit_into_canvas == 1: white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255 scale = min(budget_height / height, budget_width / width) new_height = int(height * scale) @@ -242,10 +275,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg resized_image = Image.fromarray(white_canvas) else: scale = (budget_height * budget_width / (height * width))**(1/2) - new_height = int( round(height * scale / 16) * 16) - new_width = int( round(width * scale / 16) * 16) + new_height = int( round(height * scale / block_size) * block_size) + new_width = int( round(width * scale / block_size) * block_size) resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) - if rm_background and not (ignore_first and i == 0) : + if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) : # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200, diff --git a/wgp.py b/wgp.py index 396e273e2..45cb5cfec 100644 --- a/wgp.py +++ b/wgp.py @@ -1,4 +1,5 @@ import os +os.environ["GRADIO_LANG"] = "en" # # os.environ.pop("TORCH_LOGS", None) # make sure no env var is suppressing/overriding # os.environ["TORCH_LOGS"]= "recompiles" import torch._logging as tlog @@ -21,7 +22,7 @@ import importlib from shared.utils import notification_sound from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers -from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, get_video_frame +from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image from shared.utils.audio_video import save_image_metadata, read_image_metadata from shared.match_archi import match_nvidia_architecture @@ -61,7 +62,7 @@ 
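# Editorial sketch (usage illustration for the helpers added above): rescale_and_crop()
# center-crops the source to the target aspect ratio before resizing, whereas
# calculate_new_dimensions() with fit_into_canvas == 1 scales the image to fit inside
# the canvas, fit_into_canvas == 0 roughly preserves the pixel area, and None or 2
# simply returns the canvas dimensions. The function body is copied from the patch;
# the 1920x1080 source and 832x480 target are arbitrary example values.
from PIL import Image

def rescale_and_crop(img, w, h):
    ow, oh = img.size
    target_ratio, orig_ratio = w / h, ow / oh
    if orig_ratio > target_ratio:          # source too wide: trim the sides
        nw = int(oh * target_ratio)
        img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh))
    else:                                  # source too tall: trim top and bottom
        nh = int(ow / target_ratio)
        img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2))
    return img.resize((w, h), Image.LANCZOS)

src = Image.new("RGB", (1920, 1080))       # 16:9 source
out = rescale_and_crop(src, 832, 480)      # center-cropped to the target ratio, then resized
print(out.size)                            # (832, 480): exact target size, no padding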
target_mmgp_version = "3.6.0" WanGP_version = "8.4" -settings_version = 2.31 +settings_version = 2.33 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -220,7 +221,7 @@ def process_prompt_and_add_tasks(state, model_choice): return get_queue_table(queue) model_def = get_model_def(model_type) model_handler = get_model_handler(model_type) - image_outputs = inputs["image_mode"] == 1 + image_outputs = inputs["image_mode"] > 0 any_steps_skipping = model_def.get("tea_cache", False) or model_def.get("mag_cache", False) model_type = get_base_model_type(model_type) inputs["model_filename"] = model_filename @@ -370,7 +371,7 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("Mag Cache maximum number of steps is 50") return - if image_mode == 1: + if image_mode > 0: audio_prompt_type = "" if "B" in audio_prompt_type or "X" in audio_prompt_type: @@ -477,7 +478,8 @@ def process_prompt_and_add_tasks(state, model_choice): image_mask = None if "G" in video_prompt_type: - gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(num_inference_steps * (1. - denoising_strength))} ") + if image_mode == 0: + gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(num_inference_steps * (1. - denoising_strength))} ") else: denoising_strength = 1.0 if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]: @@ -1334,9 +1336,11 @@ def update_queue_data(queue): data = get_queue_table(queue) if len(data) == 0: - return gr.DataFrame(visible=False) + return gr.DataFrame(value=[], max_height=1) + elif len(data) == 1: + return gr.DataFrame(value=data, max_height= 83) else: - return gr.DataFrame(value=data, visible= True) + return gr.DataFrame(value=data, max_height= 1000) def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True): bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom" @@ -1371,6 +1375,12 @@ def _parse_args(): help="save proprocessed audio track with extract speakers for debugging or editing" ) + parser.add_argument( + "--debug-gen-form", + action="store_true", + help="View form generation / refresh time" + ) + parser.add_argument( "--vram-safety-coefficient", type=float, @@ -2070,8 +2080,6 @@ def fix_settings(model_type, ui_defaults): if image_prompt_type != None : if not isinstance(image_prompt_type, str): image_prompt_type = "S" if image_prompt_type == 0 else "SE" - # if model_type == "flf2v_720p" and not "E" in image_prompt_type: - # image_prompt_type = "SE" if settings_version <= 2: image_prompt_type = image_prompt_type.replace("G","") ui_defaults["image_prompt_type"] = image_prompt_type @@ -2091,10 +2099,7 @@ def fix_settings(model_type, ui_defaults): video_prompt_type = ui_defaults.get("video_prompt_type", "") - any_reference_image = model_def.get("reference_image", False) - if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B"] or any_reference_image: - if not "I" in video_prompt_type: # workaround for settings corruption - video_prompt_type += "I" + if base_model_type in ["hunyuan"]: video_prompt_type = video_prompt_type.replace("I", "") @@ -2133,10 +2138,27 @@ def fix_settings(model_type, ui_defaults): del ui_defaults["tea_cache_start_step_perc"] ui_defaults["skip_steps_start_step_perc"] = tea_cache_start_step_perc + image_prompt_type = 
ui_defaults.get("image_prompt_type", "") + if len(image_prompt_type) > 0: + image_prompt_types_allowed = model_def.get("image_prompt_types_allowed","") + image_prompt_type = filter_letters(image_prompt_type, image_prompt_types_allowed) + ui_defaults["image_prompt_type"] = image_prompt_type + + video_prompt_type = ui_defaults.get("video_prompt_type", "") + image_ref_choices_list = model_def.get("image_ref_choices", {}).get("choices", []) + if len(image_ref_choices_list)==0: + video_prompt_type = del_in_sequence(video_prompt_type, "IK") + else: + first_choice = image_ref_choices_list[0][1] + if "I" in first_choice and not "I" in video_prompt_type: video_prompt_type += "I" + if len(image_ref_choices_list)==1 and "K" in first_choice and not "K" in video_prompt_type: video_prompt_type += "K" + ui_defaults["video_prompt_type"] = video_prompt_type + model_handler = get_model_handler(base_model_type) if hasattr(model_handler, "fix_settings"): model_handler.fix_settings(base_model_type, settings_version, model_def, ui_defaults) + def get_default_settings(model_type): def get_default_prompt(i2v): if i2v: @@ -3323,8 +3345,8 @@ def check(src, cond): if not all_letters(src, pos): return False if neg is not None and any_letters(src, neg): return False return True - image_outputs = configs.get("image_mode",0) == 1 - map_video_prompt = {"V" : "Control Video", ("VA", "U") : "Mask Video", "I" : "Reference Images"} + image_outputs = configs.get("image_mode",0) > 0 + map_video_prompt = {"V" : "Control Image" if image_outputs else "Control Video", ("VA", "U") : "Mask Image" if image_outputs else "Mask Video", "I" : "Reference Images"} map_image_prompt = {"V" : "Source Video", "L" : "Last Video", "S" : "Start Image", "E" : "End Image"} map_audio_prompt = {"A" : "Audio Source", "B" : "Audio Source #2"} video_other_prompts = [ v for s,v in map_image_prompt.items() if all_letters(video_image_prompt_type,s)] \ @@ -3571,9 +3593,27 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis end_time = time.time() # print(f"duration:{end_time-start_time:.1f}") - return results + return results -def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): +def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, block_size= 16, expand_scale = 2): + frame_width, frame_height = input_image.size + + if fit_canvas != None: + height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size) + input_image = input_image.resize((width, height), resample=Image.Resampling.LANCZOS) + if input_mask is not None: + input_mask = input_mask.resize((width, height), resample=Image.Resampling.LANCZOS) + + if expand_scale != 0 and input_mask is not None: + kernel_size = abs(expand_scale) + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) + op_expand = cv2.dilate if expand_scale > 0 else cv2.erode + input_mask = np.array(input_mask) + input_mask = op_expand(input_mask, kernel, iterations=3) + input_mask = Image.fromarray(input_mask) + return input_image, input_mask + +def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, 
fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions def mask_to_xyxy_box(mask): @@ -3615,6 +3655,9 @@ def mask_to_xyxy_box(mask): if len(video) == 0 or any_mask and len(mask_video) == 0: return None, None + if fit_crop and outpainting_dims != None: + fit_crop = False + fit_canvas = 0 if fit_canvas is not None else None frame_height, frame_width, _ = video[0].shape @@ -3629,7 +3672,7 @@ def mask_to_xyxy_box(mask): if outpainting_dims != None: final_height, final_width = height, width - height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8) + height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) if any_mask: num_frames = min(len(video), len(mask_video)) @@ -3646,14 +3689,20 @@ def mask_to_xyxy_box(mask): # for frame_idx in range(num_frames): def prep_prephase(frame_idx): frame = Image.fromarray(video[frame_idx].cpu().numpy()) #.asnumpy() - frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS) + if fit_crop: + frame = rescale_and_crop(frame, width, height) + else: + frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS) frame = np.array(frame) if any_mask: if any_identity_mask: mask = np.full( (height, width, 3), 0, dtype= np.uint8) else: mask = Image.fromarray(mask_video[frame_idx].cpu().numpy()) #.asnumpy() - mask = mask.resize((width, height), resample=Image.Resampling.LANCZOS) + if fit_crop: + mask = rescale_and_crop(mask, width, height) + else: + mask = mask.resize((width, height), resample=Image.Resampling.LANCZOS) mask = np.array(mask) if len(mask.shape) == 3 and mask.shape[2] == 3: @@ -3750,14 +3799,14 @@ def prep_prephase(frame_idx): return torch.stack(masked_frames), torch.stack(masks) if any_mask else None -def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, target_fps = 16, block_size = 16): +def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16): frames_list = get_resampled_video(video_in, start_frame, max_frames, target_fps) if len(frames_list) == 0: return None - if fit_canvas == None: + if fit_canvas == None or fit_crop: new_height = height new_width = width else: @@ -3775,7 +3824,10 @@ def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_can processed_frames_list = [] for frame in frames_list: frame = Image.fromarray(np.clip(frame.cpu().numpy(), 0, 255).astype(np.uint8)) - frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) + if fit_crop: + frame = rescale_and_crop(frame, new_width, new_height) + else: + frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) processed_frames_list.append(frame) np_frames = [np.array(frame) for frame in processed_frames_list] @@ -4115,9 +4167,10 @@ def process_prompt_enhancer(prompt_enhancer, original_prompts, image_start, ori prompt_images = [] if "I" in prompt_enhancer: if image_start != None: - prompt_images.append(image_start) + prompt_images += image_start if original_image_refs != None: - prompt_images += 
original_image_refs[:1] + prompt_images += original_image_refs[:1] + prompt_images = [Image.open(img) if isinstance(img,str) else img for img in prompt_images] if len(original_prompts) == 0 and not "T" in prompt_enhancer: return None else: @@ -4223,7 +4276,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri original_image_refs = inputs["image_refs"] if original_image_refs is not None: original_image_refs = [ convert_image(tup[0]) for tup in original_image_refs ] - is_image = inputs["image_mode"] == 1 + is_image = inputs["image_mode"] > 0 seed = inputs["seed"] seed = set_seed(seed) enhanced_prompts = [] @@ -4367,7 +4420,7 @@ def remove_temp_filenames(temp_filenames_list): model_def = get_model_def(model_type) - is_image = image_mode == 1 + is_image = image_mode > 0 if is_image: if min_frames_if_references >= 1000: video_length = min_frames_if_references - 1000 @@ -4377,19 +4430,22 @@ def remove_temp_filenames(temp_filenames_list): batch_size = 1 temp_filenames_list = [] - if image_guide is not None and isinstance(image_guide, Image.Image): - video_guide = convert_image_to_video(image_guide) - temp_filenames_list.append(video_guide) - image_guide = None + convert_image_guide_to_video = model_def.get("convert_image_guide_to_video", False) + if convert_image_guide_to_video: + if image_guide is not None and isinstance(image_guide, Image.Image): + video_guide = convert_image_to_video(image_guide) + temp_filenames_list.append(video_guide) + image_guide = None - if image_mask is not None and isinstance(image_mask, Image.Image): - video_mask = convert_image_to_video(image_mask) - temp_filenames_list.append(video_mask) - image_mask = None + if image_mask is not None and isinstance(image_mask, Image.Image): + video_mask = convert_image_to_video(image_mask) + temp_filenames_list.append(video_mask) + image_mask = None + if model_def.get("no_background_removal", False): remove_background_images_ref = 0 + base_model_type = get_base_model_type(model_type) model_family = get_model_family(base_model_type) - fit_canvas = server_config.get("fit_canvas", 0) model_handler = get_model_handler(base_model_type) block_size = model_handler.get_vae_block_size(base_model_type) if hasattr(model_handler, "get_vae_block_size") else 16 @@ -4415,7 +4471,7 @@ def remove_temp_filenames(temp_filenames_list): return width, height = resolution.split("x") - width, height = int(width), int(height) + width, height = int(width) // block_size * block_size, int(height) // block_size * block_size default_image_size = (height, width) if slg_switch == 0: @@ -4530,39 +4586,25 @@ def remove_temp_filenames(temp_filenames_list): original_image_refs = image_refs # image_refs = None # nb_frames_positions= 0 - frames_to_inject = [] - any_background_ref = False - outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")] # Output Video Ratio Priorities: # Source Video or Start Image > Control Video > Image Ref (background or positioned frames only) > UI Width, Height # Image Ref (non background and non positioned frames) are boxed in a white canvas in order to keep their own width/height ratio - - if image_refs is not None and len(image_refs) > 0: + frames_to_inject = [] + if image_refs is not None: frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions is not None and len(frames_positions)> 0 
else [] frames_positions_list = frames_positions_list[:len(image_refs)] nb_frames_positions = len(frames_positions_list) - if nb_frames_positions > 0: - frames_to_inject = [None] * (max(frames_positions_list) + 1) - for i, pos in enumerate(frames_positions_list): - frames_to_inject[pos] = image_refs[i] - if video_guide == None and video_source == None and not "L" in image_prompt_type and (nb_frames_positions > 0 or "K" in video_prompt_type) : - from shared.utils.utils import get_outpainting_full_area_dimensions - w, h = image_refs[0].size - if outpainting_dims != None: - h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims) - default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas) - fit_canvas = None - # if there is a source video and a background image ref, the height/width ratio will need to be processed later by the code for the model (we dont know the source video dimensions at this point) - if len(image_refs) > nb_frames_positions: - any_background_ref = "K" in video_prompt_type - if remove_background_images_ref > 0: - send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) - os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") - from shared.utils.utils import resize_and_remove_background - # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested - image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (any_background_ref or model_def.get("lock_image_refs_ratios", False)) ) # no fit for vace ref images as it is done later - update_task_thumbnails(task, locals()) - send_cmd("output") + any_background_ref = 0 + if "K" in video_prompt_type: + any_background_ref = 2 if model_def.get("all_image_refs_are_background_ref", False) else 1 + + outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")] + fit_canvas = server_config.get("fit_canvas", 0) + fit_crop = fit_canvas == 2 + if fit_crop and outpainting_dims is not None: + fit_crop = False + fit_canvas = 0 + joint_pass = boost ==1 #and profile != 1 and profile != 3 skip_steps_cache = None if len(skip_steps_cache_type) == 0 else DynamicClass(cache_type = skip_steps_cache_type) @@ -4632,6 +4674,9 @@ def remove_temp_filenames(temp_filenames_list): length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) current_video_length = min(current_video_length, length) + if image_guide is not None: + image_guide, image_mask = preprocess_image_with_mask(image_guide, image_mask, height, width, fit_canvas = None, block_size= block_size, expand_scale = mask_expand) + seed = set_seed(seed) torch.set_grad_enabled(False) @@ -4723,22 +4768,22 @@ def remove_temp_filenames(temp_filenames_list): return_latent_slice = None if reuse_frames > 0: return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) - refresh_preview = {"image_guide" : None, "image_mask" : None} + refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {} src_ref_images = image_refs image_start_tensor = image_end_tensor = None if window_no == 1 and (video_source is not None or 
image_start is not None): if image_start is not None: - new_height, new_width = calculate_new_dimensions(height, width, image_start.height, image_start.width, sample_fit_canvas, block_size = block_size) - image_start_tensor = image_start.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + image_start_tensor, new_height, new_width = calculate_dimensions_and_resize_image(image_start, height, width, sample_fit_canvas, fit_crop, block_size = block_size) + if fit_crop: refresh_preview["image_start"] = image_start_tensor image_start_tensor = convert_image_to_tensor(image_start_tensor) pre_video_guide = prefix_video = image_start_tensor.unsqueeze(1) else: - if "L" in image_prompt_type: - refresh_preview["video_source"] = get_video_frame(video_source, 0) - prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, target_fps = fps, block_size = block_size ) + prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size = block_size ) prefix_video = prefix_video.permute(3, 0, 1, 2) prefix_video = prefix_video.float().div_(127.5).sub_(1.) # c, f, h, w + if fit_crop or "L" in image_prompt_type: refresh_preview["video_source"] = convert_tensor_to_image(prefix_video, 0) + new_height, new_width = prefix_video.shape[-2:] pre_video_guide = prefix_video[:, -reuse_frames:] pre_video_frame = convert_tensor_to_image(prefix_video[:, -1]) @@ -4752,10 +4797,11 @@ def remove_temp_filenames(temp_filenames_list): image_end_list= image_end if isinstance(image_end, list) else [image_end] if len(image_end_list) >= window_no: new_height, new_width = image_size - image_end_tensor =image_end_list[window_no-1].resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + image_end_tensor, _, _ = calculate_dimensions_and_resize_image(image_end_list[window_no-1], new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size) + # image_end_tensor =image_end_list[window_no-1].resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + refresh_preview["image_end"] = image_end_tensor image_end_tensor = convert_image_to_tensor(image_end_tensor) image_end_list= None - window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count) guide_end_frame = guide_start_frame + current_video_length - (source_video_overlap_frames_count if window_no == 1 else reuse_frames) alignment_shift = source_video_frames_count if reset_control_aligment else 0 @@ -4768,7 +4814,8 @@ def remove_temp_filenames(temp_filenames_list): from models.wan.multitalk.multitalk import get_window_audio_embeddings # special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding) audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length) - + if vace: + video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None if video_guide is not None: keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count 
-source_video_overlap_frames_count + requested_frames_to_generate) @@ -4776,12 +4823,44 @@ def remove_temp_filenames(temp_filenames_list): raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ] - if ltxv: + if vace: + context_scale = [ control_net_weight] + if "V" in video_prompt_type: + process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) + preprocess_type, preprocess_type2 = "raw", None + for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PDSLCMU")): + if process_num == 0: + preprocess_type = process_map_video_guide.get(process_letter, "raw") + else: + preprocess_type2 = process_map_video_guide.get(process_letter, None) + status_info = "Extracting " + processes_names[preprocess_type] + extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) + if len(extra_process_list) == 1: + status_info += " and " + processes_names[extra_process_list[0]] + elif len(extra_process_list) == 2: + status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] + if preprocess_type2 is not None: + context_scale = [ control_net_weight /2, control_net_weight2 /2] + send_cmd("progress", [0, get_latest_status(state, status_info)]) + video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1 ) + if preprocess_type2 != None: + video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) + + if video_guide_processed != None: + if sample_fit_canvas != None: + image_size = video_guide_processed.shape[-3: -1] + sample_fit_canvas = None + refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy()) + if video_guide_processed2 != None: + refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())] + if video_mask_processed != None: + refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy()) + elif ltxv: preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") status_info = "Extracting " + processes_names[preprocess_type] send_cmd("progress", [0, get_latest_status(state, status_info)]) # start one frame ealier to facilitate latents merging later - src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), 
start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size ) + src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size ) if src_video != None: src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ] refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) @@ -4798,15 +4877,14 @@ def remove_temp_filenames(temp_filenames_list): progress_args = [0, get_latest_status(state,"Extracting Video and Mask")] send_cmd("progress", progress_args) - src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0) + src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0) refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) if src_mask != None: refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy()) elif "R" in video_prompt_type: # sparse video to video src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True) - new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size) - src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + src_image, _, _ = calculate_dimensions_and_resize_image(src_image, new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size) refresh_preview["video_guide"] = src_image src_video = convert_image_to_tensor(src_image).unsqueeze(1) if sample_fit_canvas != None: @@ -4814,7 +4892,7 @@ def remove_temp_filenames(temp_filenames_list): sample_fit_canvas = None else: # video to video - video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps) + video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop 
= fit_crop, target_fps = fps) if video_guide_processed is None: src_video = pre_video_guide else: @@ -4824,42 +4902,54 @@ def remove_temp_filenames(temp_filenames_list): src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2) if pre_video_guide != None: src_video = torch.cat( [pre_video_guide, src_video], dim=1) - - if vace : - image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications - context_scale = [ control_net_weight] - video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None - if "V" in video_prompt_type: - process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) - preprocess_type, preprocess_type2 = "raw", None - for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PDSLCMU")): - if process_num == 0: - preprocess_type = process_map_video_guide.get(process_letter, "raw") + elif image_guide is not None: + image_guide, new_height, new_width = calculate_dimensions_and_resize_image(image_guide, height, width, sample_fit_canvas, fit_crop, block_size = block_size) + image_size = (new_height, new_width) + refresh_preview["image_guide"] = image_guide + sample_fit_canvas = None + if image_mask is not None: + image_mask, _, _ = calculate_dimensions_and_resize_image(image_mask, new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size) + refresh_preview["image_mask"] = image_mask + + if window_no == 1 and image_refs is not None and len(image_refs) > 0: + if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) : + from shared.utils.utils import get_outpainting_full_area_dimensions + w, h = image_refs[0].size + if outpainting_dims != None: + h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims) + image_size = calculate_new_dimensions(height, width, h, w, fit_canvas) + sample_fit_canvas = None + if repeat_no == 1: + if fit_crop: + if any_background_ref == 2: + end_ref_position = len(image_refs) + elif any_background_ref == 1: + end_ref_position = nb_frames_positions + 1 else: - preprocess_type2 = process_map_video_guide.get(process_letter, None) - status_info = "Extracting " + processes_names[preprocess_type] - extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) - if len(extra_process_list) == 1: - status_info += " and " + processes_names[extra_process_list[0]] - elif len(extra_process_list) == 2: - status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] - if preprocess_type2 is not None: - context_scale = [ control_net_weight /2, control_net_weight2 /2] - send_cmd("progress", [0, get_latest_status(state, status_info)]) - video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1 ) - if preprocess_type2 != None: - video_guide_processed2, video_mask_processed2 = 
preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) + end_ref_position = nb_frames_positions + for i, img in enumerate(image_refs[:end_ref_position]): + image_refs[i] = rescale_and_crop(img, default_image_size[1], default_image_size[0]) + refresh_preview["image_refs"] = image_refs + + if len(image_refs) > nb_frames_positions: + if remove_background_images_ref > 0: + send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) + # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested + image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0], + remove_background_images_ref > 0, any_background_ref, + fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1, + block_size=block_size, + outpainting_dims =outpainting_dims ) + refresh_preview["image_refs"] = image_refs + + if nb_frames_positions > 0: + frames_to_inject = [None] * (max(frames_positions_list) + 1) + for i, pos in enumerate(frames_positions_list): + frames_to_inject[pos] = image_refs[i] - if video_guide_processed != None: - if sample_fit_canvas != None: - image_size = video_guide_processed.shape[-3: -1] - sample_fit_canvas = None - refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy()) - if video_guide_processed2 != None: - refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())] - if video_mask_processed != None: - refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy()) + if vace : frames_to_inject_parsed = frames_to_inject[aligned_guide_start_frame: aligned_guide_end_frame] + image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2], [video_mask_processed] if video_guide_processed2 == None else [video_mask_processed, video_mask_processed2], @@ -4868,7 +4958,6 @@ def remove_temp_filenames(temp_filenames_list): keep_video_guide_frames=keep_frames_parsed, start_frame = aligned_guide_start_frame, pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide], - fit_into_canvas = sample_fit_canvas, inject_frames= frames_to_inject_parsed, outpainting_dims = outpainting_dims, any_background_ref = any_background_ref @@ -4931,9 +5020,9 @@ def set_header_text(txt): prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames, frame_num= (current_video_length // latent_size)* latent_size + 1, batch_size = batch_size, - height = height, - width = width, - fit_into_canvas = fit_canvas == 1, + height = image_size[0], + width = image_size[1], + fit_into_canvas = fit_canvas, shift=flow_shift, sample_solver=sample_solver, 
sampling_steps=num_inference_steps, @@ -4990,6 +5079,8 @@ def set_header_text(txt): pre_video_frame = pre_video_frame, original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [], image_refs_relative_size = image_refs_relative_size, + image_guide= image_guide, + image_mask= image_mask, ) except Exception as e: if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0: @@ -5931,7 +6022,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if target == "settings": return inputs - image_outputs = inputs.get("image_mode",0) == 1 + image_outputs = inputs.get("image_mode",0) > 0 pop=[] if "force_fps" in inputs and len(inputs["force_fps"])== 0: @@ -5947,13 +6038,13 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None pop += ["MMAudio_setting", "MMAudio_prompt", "MMAudio_neg_prompt"] video_prompt_type = inputs["video_prompt_type"] - if not base_model_type in ["t2v"]: + if not "G" in video_prompt_type: pop += ["denoising_strength"] if not (server_config.get("enhancer_enabled", 0) > 0 and server_config.get("enhancer_mode", 0) == 0): pop += ["prompt_enhancer"] - if not recammaster and not diffusion_forcing and not flux: + if model_def.get("model_modes", None) is None: pop += ["model_mode"] if not vace and not phantom and not hunyuan_video_custom: @@ -6075,6 +6166,18 @@ def image_to_ref_image_set(state, input_file_list, choice, target, target_name): gr.Info(f"Selected Image was copied to {target_name}") return file_list[choice] +def image_to_ref_image_guide(state, input_file_list, choice): + file_list, file_settings_list = get_file_list(state, input_file_list) + if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update(), gr.update() + ui_settings = get_current_model_settings(state) + gr.Info(f"Selected Image was copied to Control Image") + new_image = file_list[choice] + if ui_settings["image_mode"]==2: + return new_image, new_image + else: + return new_image, None + + def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation): gen = get_gen_info(state) @@ -6142,11 +6245,11 @@ def eject_video_from_gallery(state, input_file_list, choice): return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0) def has_video_file_extension(filename): - extension = os.path.splitext(filename)[-1] + extension = os.path.splitext(filename)[-1].lower() return extension in [".mp4"] def has_image_file_extension(filename): - extension = os.path.splitext(filename)[-1] + extension = os.path.splitext(filename)[-1].lower() return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"] def add_videos_to_gallery(state, input_file_list, choice, files_to_load): gen = get_gen_info(state) @@ -6308,7 +6411,7 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw return configs, any_image_or_video def record_image_mode_tab(state, evt:gr.SelectData): - state["image_mode_tab"] = 0 if evt.index ==0 else 1 + state["image_mode_tab"] = evt.index def switch_image_mode(state): image_mode = state.get("image_mode_tab", 0) @@ -6316,7 +6419,18 @@ def switch_image_mode(state): ui_defaults = get_model_settings(state, model_type) ui_defaults["image_mode"] = image_mode - + video_prompt_type = 
ui_defaults.get("video_prompt_type", "") + model_def = get_model_def( model_type) + inpaint_support = model_def.get("inpaint_support", False) + if inpaint_support: + if image_mode == 1: + video_prompt_type = del_in_sequence(video_prompt_type, "VAG") + video_prompt_type = add_to_sequence(video_prompt_type, "KI") + elif image_mode == 2: + video_prompt_type = add_to_sequence(video_prompt_type, "VAG") + video_prompt_type = del_in_sequence(video_prompt_type, "KI") + ui_defaults["video_prompt_type"] = video_prompt_type + return str(time.time()) def load_settings_from_file(state, file_path): @@ -6349,6 +6463,7 @@ def load_settings_from_file(state, file_path): def save_inputs( target, + image_mask_guide, lset_name, image_mode, prompt, @@ -6434,13 +6549,18 @@ def save_inputs( state, ): - - # if state.get("validate_success",0) != 1: - # return + model_filename = state["model_filename"] model_type = state["model_type"] + if image_mask_guide is not None and image_mode == 2: + if "background" in image_mask_guide: + image_guide = image_mask_guide["background"] + if "layers" in image_mask_guide and len(image_mask_guide["layers"])>0: + image_mask = image_mask_guide["layers"][0] + image_mask_guide = None inputs = get_function_arguments(save_inputs, locals()) inputs.pop("target") + inputs.pop("image_mask_guide") cleaned_inputs = prepare_inputs_dict(target, inputs) if target == "settings": defaults_filename = get_settings_file_name(model_type) @@ -6544,11 +6664,16 @@ def change_model(state, model_choice): return header -def fill_inputs(state): +def get_current_model_settings(state): model_type = state["model_type"] - ui_defaults = get_model_settings(state, model_type) + ui_defaults = get_model_settings(state, model_type) if ui_defaults == None: ui_defaults = get_default_settings(model_type) + set_model_settings(state, model_type, ui_defaults) + return ui_defaults + +def fill_inputs(state): + ui_defaults = get_current_model_settings(state) return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults) @@ -6623,7 +6748,9 @@ def refresh_image_prompt_type_radio(state, image_prompt_type, image_prompt_type_ image_prompt_type = del_in_sequence(image_prompt_type, "VLTS") image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio) any_video_source = len(filter_letters(image_prompt_type, "VL"))>0 - end_visible = any_letters(image_prompt_type, "SVL") + model_def = get_model_def(state["model_type"]) + image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "") + end_visible = "E" in image_prompt_types_allowed and any_letters(image_prompt_type, "SVL") return image_prompt_type, gr.update(visible = "S" in image_prompt_type ), gr.update(visible = end_visible and ("E" in image_prompt_type) ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source), gr.update(visible = end_visible) def refresh_image_prompt_type_endcheckbox(state, image_prompt_type, image_prompt_type_radio, end_checkbox): @@ -6654,7 +6781,7 @@ def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_ visible= "A" in video_prompt_type model_type = state["model_type"] model_def = get_model_def(model_type) - image_outputs = image_mode == 1 + image_outputs = image_mode > 0 return video_prompt_type, gr.update(visible= visible and not image_outputs), gr.update(visible= visible and image_outputs), gr.update(visible= visible ) def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_type_video_guide): @@ -6663,20 +6790,22 @@ def 
refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t return video_prompt_type def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode): - video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMGUV") + video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUV") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) visible = "V" in video_prompt_type model_type = state["model_type"] base_model_type = get_base_model_type(model_type) mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type model_def = get_model_def(model_type) - image_outputs = image_mode == 1 + image_outputs = image_mode > 0 vace= test_vace_module(model_type) keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False) return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible) def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt): - video_prompt_type = del_in_sequence(video_prompt_type, "RGUVQKI") + model_def = get_model_def(state["model_type"]) + guide_custom_choices = model_def.get("guide_custom_choices",{}) + video_prompt_type = del_in_sequence(video_prompt_type, guide_custom_choices.get("letters_filter","")) video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt) control_video_visible = "V" in video_prompt_type ref_images_visible = "I" in video_prompt_type @@ -6711,7 +6840,7 @@ def get_image_end_label(multi_prompts_gen_type): return "Images as ending points for new Videos in the Generation Queue" if multi_prompts_gen_type == 0 else "Images as ending points for each new Window of the same Video Generation" def refresh_prompt_labels(multi_prompts_gen_type, image_mode): - prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type, image_mode == 1) + prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type, image_mode > 0) return gr.update(label=prompt_label), gr.update(label = wizard_prompt_label), gr.update(label=get_image_end_label(multi_prompts_gen_type)) def show_preview_column_modal(state, column_no): @@ -7032,7 +7161,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non diffusion_forcing = "diffusion_forcing" in model_filename ltxv = "ltxv" in model_filename lock_inference_steps = model_def.get("lock_inference_steps", False) - model_reference_image = model_def.get("reference_image", False) any_tea_cache = model_def.get("tea_cache", False) any_mag_cache = model_def.get("mag_cache", False) recammaster = base_model_type in ["recam_1.3B"] @@ -7075,18 +7203,22 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl if not v2i_switch_supported and not image_outputs: image_mode_value = 0 else: - image_outputs = image_mode_value == 1 + image_outputs = image_mode_value > 0 + inpaint_support = model_def.get("inpaint_support", False) image_mode = 
gr.Number(value =image_mode_value, visible = False) - - with gr.Tabs(visible = v2i_switch_supported, selected= "t2i" if image_mode_value == 1 else "t2v" ) as image_mode_tabs: - with gr.Tab("Text to Video", id = "t2v", elem_classes="compact_tab"): + image_mode_tab_selected= "t2i" if image_mode_value == 1 else ("inpaint" if image_mode_value == 2 else "t2v") + with gr.Tabs(visible = v2i_switch_supported or inpaint_support, selected= image_mode_tab_selected ) as image_mode_tabs: + with gr.Tab("Text to Video", id = "t2v", elem_classes="compact_tab", visible = v2i_switch_supported) as tab_t2v: pass with gr.Tab("Text to Image", id = "t2i", elem_classes="compact_tab"): pass + with gr.Tab("Image Inpainting", id = "inpaint", elem_classes="compact_tab", visible=inpaint_support) as tab_inpaint: + pass image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "") model_mode_choices = model_def.get("model_modes", None) with gr.Column(visible= len(image_prompt_types_allowed)> 0 or model_mode_choices is not None) as image_prompt_column: + # Video Continue / Start Frame / End Frame image_prompt_type_value= ui_defaults.get("image_prompt_type","") image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False) image_prompt_type_choices = [] @@ -7123,192 +7255,167 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl model_mode = gr.Dropdown(choices=model_mode_choices["choices"], value=ui_defaults.get("model_mode", model_mode_choices["default"]), label=model_mode_choices["label"], visible=True) keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VL"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) - with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or recammaster or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column: + any_control_video = any_control_image = False + guide_preprocessing = model_def.get("guide_preprocessing", None) + mask_preprocessing = model_def.get("mask_preprocessing", None) + guide_custom_choices = model_def.get("guide_custom_choices", None) + image_ref_choices = model_def.get("image_ref_choices", None) + + # with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or recammaster or (flux or qwen ) and model_reference_image and image_mode_value >=1) as video_prompt_column: + with gr.Column(visible= guide_preprocessing is not None or mask_preprocessing is not None or guide_custom_choices is not None or image_ref_choices is not None) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","") video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False) - any_control_video = True - any_control_image = image_outputs - with gr.Row(): - if t2v: - video_prompt_type_video_guide = gr.Dropdown( - choices=[ - ("Use Text Prompt Only", ""), - ("Image to Image guided by Text Prompt" if image_outputs else "Video to Video guided by Text Prompt", "GUV"), - ], - value=filter_letters(video_prompt_type_value, "GUV"), - label="Video to Video", scale = 2, show_label= False, visible= True - ) - elif vace : + with gr.Row(visible = image_mode_value!=2) as guide_selection_row: + # Control Video 
Preprocessing + if guide_preprocessing is None: + video_prompt_type_video_guide = gr.Dropdown(choices=[("","")], value="", label="Control Video", scale = 2, visible= False, show_label= True, ) + else: pose_label = "Pose" if image_outputs else "Motion" + guide_preprocessing_labels_all = { + "": "No Control Video", + "UV": "Keep Control Video Unchanged", + "PV": f"Transfer Human {pose_label}", + "PMV": f"Transfer Human {pose_label}", + "DV": "Transfer Depth", + "EV": "Transfer Canny Edges", + "SV": "Transfer Shapes", + "LV": "Transfer Flow", + "CV": "Recolorize", + "MV": "Perform Inpainting", + "V": "Use Vace raw format", + "PDV": f"Transfer Human {pose_label} & Depth", + "PSV": f"Transfer Human {pose_label} & Shapes", + "PLV": f"Transfer Human {pose_label} & Flow" , + "DSV": "Transfer Depth & Shapes", + "DLV": "Transfer Depth & Flow", + "SLV": "Transfer Shapes & Flow", + } + guide_preprocessing_choices = [] + guide_preprocessing_labels = guide_preprocessing.get("labels", {}) + for process_type in guide_preprocessing["selection"]: + process_label = guide_preprocessing_labels.get(process_type, None) + process_label = guide_preprocessing_labels_all.get(process_type,process_type) if process_label is None else process_label + if image_outputs: process_label = process_label.replace("Video", "Image") + guide_preprocessing_choices.append( (process_label, process_type) ) + + video_prompt_type_video_guide_label = guide_preprocessing.get("label", "Control Video Process") + if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image") video_prompt_type_video_guide = gr.Dropdown( - choices=[ - ("No Control Image" if image_outputs else "No Control Video", ""), - ("Keep Control Image Unchanged" if image_outputs else "Keep Control Video Unchanged", "UV"), - (f"Transfer Human {pose_label}" , "PV"), - ("Transfer Depth", "DV"), - ("Transfer Shapes", "SV"), - ("Transfer Flow", "LV"), - ("Recolorize", "CV"), - ("Perform Inpainting", "MV"), - ("Use Vace raw format", "V"), - (f"Transfer Human {pose_label} & Depth", "PDV"), - (f"Transfer Human {pose_label} & Shapes", "PSV"), - (f"Transfer Human {pose_label} & Flow", "PLV"), - ("Transfer Depth & Shapes", "DSV"), - ("Transfer Depth & Flow", "DLV"), - ("Transfer Shapes & Flow", "SLV"), - ], - value=filter_letters(video_prompt_type_value, "PDSLCMGUV"), - label="Control Image Process" if image_outputs else "Control Video Process", scale = 2, visible= True, show_label= True, - ) - elif ltxv: - video_prompt_type_video_guide = gr.Dropdown( - choices=[ - ("No Control Video", ""), - ("Transfer Human Motion", "PV"), - ("Transfer Depth", "DV"), - ("Transfer Canny Edges", "EV"), - ("Use LTXV raw format", "V"), - ], - value=filter_letters(video_prompt_type_value, "PDEV"), - label="Control Video Process", scale = 2, visible= True, show_label= True, + guide_preprocessing_choices, + value=filter_letters(video_prompt_type_value, "PDESLCMUV", guide_preprocessing.get("default", "") ), + label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True, ) + any_control_video = True + any_control_image = image_outputs - elif hunyuan_video_custom_edit: - video_prompt_type_video_guide = gr.Dropdown( - choices=[ - ("Inpaint Control Image" if image_outputs else "Inpaint Control Video", "MV"), - ("Transfer Human Motion", "PMV"), - ], - value=filter_letters(video_prompt_type_value, "PDSLCMUV"), - label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= 
True, show_label= True, - ) - elif recammaster: - video_prompt_type_video_guide = gr.Dropdown(value="UV", choices = [("Control Video","UV")], visible=False) + # Alternate Control Video Preprocessing / Options + if guide_custom_choices is None: + video_prompt_type_video_guide_alt = gr.Dropdown(choices=[("","")], value="", label="Control Video", visible= False, scale = 2 ) else: - any_control_video = False - any_control_image = False - video_prompt_type_video_guide = gr.Dropdown(value="", choices = [("","")], visible=False) - - if infinitetalk: + video_prompt_type_video_guide_alt_label = guide_custom_choices.get("label", "Control Video Process") + if image_outputs: video_prompt_type_video_guide_alt_label = video_prompt_type_video_guide_alt_label.replace("Video", "Image") + video_prompt_type_video_guide_alt_choices = [(label.replace("Video", "Image") if image_outputs else label, value) for label,value in guide_custom_choices["choices"] ] video_prompt_type_video_guide_alt = gr.Dropdown( - choices=[ - ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"), - ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"), - ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"), - ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"), - ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"), - ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"), - ], - value=filter_letters(video_prompt_type_value, "RGUVQKI"), - label="Video to Video", scale = 3, visible= True, show_label= False, - ) - any_control_video = any_control_image = True - else: - video_prompt_type_video_guide_alt = gr.Dropdown(value="", choices = [("","")], visible=False) - - # video_prompt_video_guide_trigger = gr.Text(visible=False, value="") - if t2v: - video_prompt_type_video_mask = gr.Dropdown(value = "", choices = [""], visible = False) - elif hunyuan_video_custom_edit: - video_prompt_type_video_mask = gr.Dropdown( - choices=[ - ("Masked Area", "A"), - ("Non Masked Area", "NA"), - ], - value= filter_letters(video_prompt_type_value, "NA"), - visible= "V" in video_prompt_type_value, - label="Area Processed", scale = 2, show_label= True, - ) - elif ltxv: - video_prompt_type_video_mask = gr.Dropdown( - choices=[ - ("Whole Frame", ""), - ("Masked Area", "A"), - ("Non Masked Area", "NA"), - ("Masked Area, rest Inpainted", "XA"), - ("Non Masked Area, rest Inpainted", "XNA"), - ], - value= filter_letters(video_prompt_type_value, "XNA"), - visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value, - label="Area Processed", scale = 2, show_label= True, + choices= video_prompt_type_video_guide_alt_choices, + value=filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"], guide_custom_choices.get("default", "") ), + visible = guide_custom_choices.get("visible", True), + label= video_prompt_type_video_guide_alt_label, show_label= guide_custom_choices.get("show_label", True), scale = 2 ) + any_control_video = True + any_control_image = image_outputs + + # Control Mask Preprocessing + if mask_preprocessing is None: + video_prompt_type_video_mask = gr.Dropdown(choices=[("","")], value="", label="Video Mask", scale = 2, visible= False, show_label= True, 
) else: + mask_preprocessing_labels_all = { + "": "Whole Frame", + "A": "Masked Area", + "NA": "Non Masked Area", + "XA": "Masked Area, rest Inpainted", + "XNA": "Non Masked Area, rest Inpainted", + "YA": "Masked Area, rest Depth", + "YNA": "Non Masked Area, rest Depth", + "WA": "Masked Area, rest Shapes", + "WNA": "Non Masked Area, rest Shapes", + "ZA": "Masked Area, rest Flow", + "ZNA": "Non Masked Area, rest Flow" + } + + mask_preprocessing_choices = [] + mask_preprocessing_labels = guide_preprocessing.get("labels", {}) + for process_type in mask_preprocessing["selection"]: + process_label = mask_preprocessing_labels.get(process_type, None) + process_label = mask_preprocessing_labels_all.get(process_type, process_type) if process_label is None else process_label + mask_preprocessing_choices.append( (process_label, process_type) ) + + video_prompt_type_video_mask_label = guide_preprocessing.get("label", "Area Processed") video_prompt_type_video_mask = gr.Dropdown( - choices=[ - ("Whole Frame", ""), - ("Masked Area", "A"), - ("Non Masked Area", "NA"), - ("Masked Area, rest Inpainted", "XA"), - ("Non Masked Area, rest Inpainted", "XNA"), - ("Masked Area, rest Depth", "YA"), - ("Non Masked Area, rest Depth", "YNA"), - ("Masked Area, rest Shapes", "WA"), - ("Non Masked Area, rest Shapes", "WNA"), - ("Masked Area, rest Flow", "ZA"), - ("Non Masked Area, rest Flow", "ZNA"), - ], - value= filter_letters(video_prompt_type_value, "XYZWNA"), - visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom and not ltxv, - label="Area Processed", scale = 2, show_label= True, - ) - image_ref_choices = model_def.get("image_ref_choices", None) - if image_ref_choices is not None: - video_prompt_type_image_refs = gr.Dropdown( - choices= image_ref_choices["choices"], - value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]), - visible = True, - label=image_ref_choices["label"], show_label= True, scale = 2 - ) - elif t2v: - video_prompt_type_image_refs = gr.Dropdown(value="", label="Ref Image", choices=[""], visible =False) - elif vace: - video_prompt_type_image_refs = gr.Dropdown( - choices=[ - ("None", ""), - ("Inject only People / Objects", "I"), - ("Inject Landscape and then People / Objects", "KI"), - ("Inject Frames and then People / Objects", "FI"), - ], - value=filter_letters(video_prompt_type_value, "KFI"), - visible = True, - label="Reference Images", show_label= True, scale = 2 - ) - elif standin: # and not vace - video_prompt_type_image_refs = gr.Dropdown( - choices=[ - ("No Reference Image", ""), - ("Reference Image is a Person Face", "I"), - ], - value=filter_letters(video_prompt_type_value, "I"), - visible = True, - show_label=False, - label="Reference Image", scale = 2 + mask_preprocessing_choices, + value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")), + label= video_prompt_type_video_mask_label , scale = 2, visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and mask_preprocessing.get("visible", True), + show_label= True, ) - elif (flux or qwen) and model_reference_image: + + # Image Refs Selection + if image_ref_choices is None: video_prompt_type_image_refs = gr.Dropdown( - choices=[ - ("None", ""), - ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"), - ("Conditional Images are People / Objects", "I"), - ], - value=filter_letters(video_prompt_type_value, "KI"), - visible = True, - 
show_label=False, - label="Reference Images Combination Method", scale = 2 + # choices=[ ("None", ""),("Start", "KI"),("Ref Image", "I")], + choices=[ ("None", ""),], + value=filter_letters(video_prompt_type_value, ""), + visible = False, + label="Start / Reference Images", scale = 2 ) + any_reference_image = False else: + any_reference_image = True video_prompt_type_image_refs = gr.Dropdown( - choices=[ ("None", ""),("Start", "KI"),("Ref Image", "I")], - value=filter_letters(video_prompt_type_value, "KI"), - visible = False, - label="Start / Reference Images", scale = 2 + choices= image_ref_choices["choices"], + value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]), + visible = image_ref_choices.get("visible", True), + label=image_ref_choices.get("label", "Ref. Images Type"), show_label= True, scale = 2 ) - image_guide = gr.Image(label= "Control Image", height = gallery_height, type ="pil", visible= image_outputs and "V" in video_prompt_type_value, value= ui_defaults.get("image_guide", None)) + + image_guide = gr.Image(label= "Control Image", height = gallery_height, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value, value= ui_defaults.get("image_guide", None)) video_guide = gr.Video(label= "Control Video", height = gallery_height, visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None)) + if image_mode_value == 2 and inpaint_support: + image_guide_value = ui_defaults.get("image_guide", None) + image_mask_value = ui_defaults.get("image_mask", None) + if image_guide_value is None: + image_mask_guide_value = None + else: + def rgb_bw_to_rgba_mask(img, thresh=127): + a = img.convert('L').point(lambda p: 255 if p > thresh else 0) # alpha + out = Image.new('RGBA', img.size, (255, 255, 255, 0)) # white, transparent + out.putalpha(a) # white where alpha=255 + return out + + image_mask_value = rgb_bw_to_rgba_mask(image_mask_value) + image_mask_guide_value = { "background" : image_guide_value, "composite" : None, "layers": [image_mask_value] } + + image_mask_guide = gr.ImageEditor( + label="Control Image to be Inpainted", + value = image_mask_guide_value, + type='pil', + sources=["upload", "webcam"], + image_mode='RGB', + layers=False, + brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"), + # fixed_canvas= True, + width=800, + height=800, + # transforms=None, + # interactive=True, + elem_id="img_editor", + visible= True + ) + any_control_image = True + else: + image_mask_guide = gr.ImageEditor(value = None, visible = False, elem_id="img_editor") - denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label="Denoising Strength (the Lower the Closer to the Control Video)", visible = "G" in video_prompt_type_value, show_reset_button= False) + + denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label=f"Denoising Strength (the Lower the Closer to the Control {'Image' if image_outputs else 'Video'})", visible = "G" in video_prompt_type_value, show_reset_button= False) keep_frames_video_guide_visible = not image_outputs and "V" in video_prompt_type_value and not model_def.get("keep_frames_video_guide_not_supported", False) keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= keep_frames_video_guide_visible , scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last @@ -7325,11 +7432,10 @@ 
def get_image_gallery(label ="", value = None, single_image_mode = False, visibl video_guide_outpainting_left = gr.Slider(0, 100, value= video_guide_outpainting_list[2], step=5, label="Left %", show_reset_button= False) video_guide_outpainting_right = gr.Slider(0, 100, value= video_guide_outpainting_list[3], step=5, label="Right %", show_reset_button= False) any_image_mask = image_outputs and vace - image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_outputs and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("image_mask", None)) + image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("image_mask", None)) video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("video_mask", None)) mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value ) - any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or infinitetalk or (flux or qwen) and model_reference_image image_refs_single_image_mode = model_def.get("one_image_ref_needed", False) image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "") @@ -7424,10 +7530,13 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl visible= True, show_label= not on_demand_prompt_enhancer, ) with gr.Row(): - if server_config.get("fit_canvas", 0) == 1: - label = "Max Resolution (As it maybe less depending on video width / height ratio)" + fit_canvas = server_config.get("fit_canvas", 0) + if fit_canvas == 1: + label = "Outer Box Resolution (one dimension may be less to preserve video W/H ratio)" + elif fit_canvas == 2: + label = "Output Resolution (Input Images wil be Cropped if the W/H ratio is different)" else: - label = "Max Resolution (Pixels will be reallocated depending on the output width / height ratio)" + label = "Resolution Budget (Pixels will be reallocated to preserve Inputs W/H ratio)" current_resolution_choice = ui_defaults.get("resolution","832x480") if update_form or last_resolution is None else last_resolution model_resolutions = model_def.get("resolutions", None) resolution_choices, current_resolution_choice = get_resolution_choices(current_resolution_choice, model_resolutions) @@ -7751,7 +7860,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai sliding_window_defaults = model_def.get("sliding_window_defaults", {}) sliding_window_size = gr.Slider(5, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size") sliding_window_overlap = gr.Slider(sliding_window_defaults.get("overlap_min", 1), 
sliding_window_defaults.get("overlap_max", 97), value=ui_defaults.get("sliding_window_overlap",sliding_window_defaults.get("overlap_default", 5)), step=sliding_window_defaults.get("overlap_step", 4), label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") - sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",1), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)") + sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",1), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)", visible = True) sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20 if vace else 0), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = vace) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) @@ -7902,7 +8011,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai add_to_queue_trigger = gr.Text(visible = False) with gr.Column(visible= False) as current_gen_column: - with gr.Accordion("Preview", open=False) as queue_accordion: + with gr.Accordion("Preview", open=False): preview = gr.Image(label="Preview", height=200, show_label= False) preview_trigger = gr.Text(visible= False) gen_info = gr.HTML(visible=False, min_height=1) @@ -7947,8 +8056,8 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn, - NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, - min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] + image_start_extra + image_end_extra + image_refs_extra # presets_column, + NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, guide_selection_row, image_mode_tabs, + min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn, tab_inpaint, tab_t2v] + image_start_extra + image_end_extra + image_refs_extra # presets_column, if update_form: locals_dict = locals() gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs @@ -7970,10 +8079,10 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai image_prompt_type_radio.change(fn=refresh_image_prompt_type_radio, inputs=[state, image_prompt_type, image_prompt_type_radio], outputs=[image_prompt_type, image_start_row, image_end_row, video_source, keep_frames_video_source, image_prompt_type_endcheckbox], show_progress="hidden" ) 
image_prompt_type_endcheckbox.change(fn=refresh_image_prompt_type_endcheckbox, inputs=[state, image_prompt_type, image_prompt_type_radio, image_prompt_type_endcheckbox], outputs=[image_prompt_type, image_end_row] ) # video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand]) - video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col]) - video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand]) - video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs_row, denoising_strength ]) - video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand]) + video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col], show_progress="hidden") + video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand], show_progress="hidden") + video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs_row, denoising_strength ], show_progress="hidden") + video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand], show_progress="hidden") video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type]) multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt, image_end], show_progress="hidden") video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) @@ -7984,8 +8093,8 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , 
film_grai show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name]).then( fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]) queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container]) - gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab]) - preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview]) + gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab], show_progress="hidden") + preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview], show_progress="hidden") PP_MMAudio_setting.change(fn = lambda value : [gr.update(visible = value == 1), gr.update(visible = value == 0)] , inputs = [PP_MMAudio_setting], outputs = [PP_MMAudio_row, PP_custom_audio_row] ) def refresh_status_async(state, progress=gr.Progress()): gen = get_gen_info(state) @@ -8017,7 +8126,9 @@ def activate_status(state): output_trigger.change(refresh_gallery, inputs = [state], - outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, current_gen_buttons_row, queue_df, abort_btn, onemorewindow_btn]) + outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, current_gen_buttons_row, queue_df, abort_btn, onemorewindow_btn], + show_progress="hidden" + ) preview_column_no.input(show_preview_column_modal, inputs=[state, preview_column_no], outputs=[preview_column_no, modal_image_display, modal_container]) @@ -8033,7 +8144,8 @@ def activate_status(state): gr.on( triggers=[video_info_extract_settings_btn.click, video_info_extract_image_settings_btn.click], fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8042,7 +8154,8 @@ def activate_status(state): prompt_enhancer_btn.click(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8050,7 +8163,8 @@ def activate_status(state): saveform_trigger.change(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8065,14 +8179,14 @@ def activate_status(state): video_info_to_video_source_btn.click(fn=video_to_source_video, inputs =[state, output, last_choice], outputs = [video_source] ) video_info_to_start_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_start, gr.State("Start Image")], outputs = 
[image_start] ) video_info_to_end_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_end, gr.State("End Image")], outputs = [image_end] ) - video_info_to_image_guide_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_guide, gr.State("Control Image")], outputs = [image_guide] ) + video_info_to_image_guide_btn.click(fn=image_to_ref_image_guide, inputs =[state, output, last_choice], outputs = [image_guide, image_mask_guide] ) video_info_to_image_mask_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_mask, gr.State("Image Mask")], outputs = [image_mask] ) video_info_to_reference_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_refs, gr.State("Ref Image")], outputs = [image_refs] ) video_info_postprocessing_btn.click(fn=apply_post_processing, inputs =[state, output, last_choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation], outputs = [mode, generate_trigger, add_to_queue_trigger ] ) video_info_remux_audio_btn.click(fn=remux_audio, inputs =[state, output, last_choice, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation, PP_custom_audio], outputs = [mode, generate_trigger, add_to_queue_trigger ] ) save_lset_btn.click(validate_save_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop]) delete_lset_btn.click(validate_delete_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) - confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then( + confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt], show_progress="hidden",).then( fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None).then( @@ -8087,7 +8201,8 @@ def activate_status(state): lset_name.select(fn=update_lset_type, inputs=[state, lset_name], outputs=save_lset_prompt_drop) export_settings_from_file_btn.click(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8104,7 +8219,8 @@ def activate_status(state): image_mode_tabs.select(fn=record_image_mode_tab, inputs=[state], outputs= None ).then(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8112,7 +8228,8 @@ def activate_status(state): settings_file.upload(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8126,17 +8243,20 @@ def activate_status(state): refresh_form_trigger.change(fn= fill_inputs, inputs=[state], - outputs=gen_inputs 
+ extra_inputs + outputs=gen_inputs + extra_inputs, + show_progress= "full" if args.debug_gen_form else "hidden", ).then(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ) model_family.input(fn=change_model_family, inputs=[state, model_family], outputs= [model_choice]) model_choice.change(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8145,7 +8265,8 @@ def activate_status(state): outputs= [header] ).then(fn= fill_inputs, inputs=[state], - outputs=gen_inputs + extra_inputs + outputs=gen_inputs + extra_inputs, + show_progress="full" if args.debug_gen_form else "hidden", ).then(fn= preload_model_when_switching, inputs=[state], outputs=[gen_status]) @@ -8154,13 +8275,15 @@ def activate_status(state): generate_trigger.change(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None ).then(fn=process_prompt_and_add_tasks, inputs = [state, model_choice], - outputs= queue_df + outputs= queue_df, + show_progress="hidden", ).then(fn=prepare_generate_video, inputs= [state], outputs= [generate_btn, add_to_queue_btn, current_gen_column, current_gen_buttons_row] @@ -8170,10 +8293,12 @@ def activate_status(state): ).then( fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(), inputs=[state], - outputs=[queue_accordion] + outputs=[queue_accordion], + show_progress="hidden", ).then(fn=process_tasks, inputs= [state], outputs= [preview_trigger, output_trigger], + show_progress="hidden", ).then(finalize_generation, inputs= [state], outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info] @@ -8280,17 +8405,20 @@ def activate_status(state): # gr.on(triggers=[add_to_queue_btn.click, add_to_queue_trigger.change],fn=validate_wizard_prompt, add_to_queue_trigger.change(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None ).then(fn=process_prompt_and_add_tasks, inputs = [state, model_choice], - outputs=queue_df + outputs=queue_df, + show_progress="hidden", ).then( fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(), inputs=[state], - outputs=[queue_accordion] + outputs=[queue_accordion], + show_progress="hidden", ).then( fn=update_status, inputs = [state], @@ -8302,8 +8430,8 @@ def activate_status(state): outputs=[modal_container] ) - return ( state, loras_choices, lset_name, resolution, - video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger + return ( state, loras_choices, lset_name, resolution, refresh_form_trigger, + # video_guide, image_guide, video_mask, image_mask, image_refs, ) @@ -8339,8 +8467,9 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice fit_canvas_choice = gr.Dropdown( choices=[ - ("Dimensions correspond to the 
Pixels Budget (as the Prompt Image/Video will be resized to match this pixels budget, output video height or width may exceed the requested dimensions )", 0), - ("Dimensions correspond to the Maximum Width and Height (as the Prompt Image/Video will be resized to fit into these dimensions, the output video may be smaller)", 1), + ("Dimensions correspond to the Pixels Budget (as the Prompt Image/Video will be Resized to match this pixels Budget, output video height or width may exceed the requested dimensions )", 0), + ("Dimensions correspond to the Maximum Width and Height (as the Prompt Image/Video will be Resized to fit into these dimensions, the output video may be smaller)", 1), + ("Dimensions correspond to the Output Width and Height (as the Prompt Image/Video will be Cropped to fit exactly these dimensions)", 2), ], value= server_config.get("fit_canvas", 0), label="Generated Video Dimensions when Prompt contains an Image or a Video", @@ -9231,9 +9360,17 @@ def create_ui(): console.log('Events dispatched for column:', index); } }; - console.log('sendColIndex function attached to window'); - } + + // cancel wheel usage inside image editor + const hit = n => n?.id === "img_editor" || n?.classList?.contains("wheel-pass"); + addEventListener("wheel", e => { + const path = e.composedPath?.() || (() => { let a=[],n=e.target; for(;n;n=n.parentNode||n.host) a.push(n); return a; })(); + if (path.some(hit)) e.stopImmediatePropagation(); + }, { capture: true, passive: true }); + + } + """ if server_config.get("display_stats", 0) == 1: from shared.utils.stats import SystemStatsApp @@ -9264,13 +9401,13 @@ def create_ui(): stats_element = stats_app.get_gradio_element() with gr.Row(): - ( state, loras_choices, lset_name, resolution, - video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger + ( state, loras_choices, lset_name, resolution, refresh_form_trigger + # video_guide, image_guide, video_mask, image_mask, image_refs, ) = generate_video_tab(model_family=model_family, model_choice=model_choice, header=header, main = main, main_tabs =main_tabs) with gr.Tab("Guides", id="info") as info_tab: generate_info_tab() with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator: - matanyone_app.display(main_tabs, tab_state, server_config, video_guide, image_guide, video_mask, image_mask, image_refs) + matanyone_app.display(main_tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings) #, video_guide, image_guide, video_mask, image_mask, image_refs) if not args.lock_config: with gr.Tab("Downloads", id="downloads") as downloads_tab: generate_download_tab(lset_name, loras_choices, state) From e7c08d12c846ddf4817013e1b112b1ecf183ada3 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Wed, 10 Sep 2025 01:47:55 +0200 Subject: [PATCH 024/155] fixed unwanted discontinuity with at the end of first sliding window with InfiniteTalk --- models/wan/any2video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/wan/any2video.py b/models/wan/any2video.py index bb91dc66b..1e62c3511 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -539,7 +539,7 @@ def generate(self, new_shot = "Q" in video_prompt_type else: if pre_video_frame is None: - new_shot = True + new_shot = "Q" in video_prompt_type else: if input_ref_images is None: input_ref_images, new_shot = [pre_video_frame], False From 9fa267087b2dfdba651fd173325537f031edf91d Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 11 Sep 2025 21:23:05 +0200 
Subject: [PATCH 025/155] Flux Festival --- README.md | 9 +++- defaults/flux_dev_umo.json | 24 ++++++++++ defaults/flux_dev_uso.json | 2 +- defaults/flux_srpo.json | 15 ++++++ defaults/flux_srpo_uso.json | 17 +++++++ models/flux/flux_handler.py | 22 +++++++++ models/flux/flux_main.py | 77 +++++++++++++++++++++++++------ models/flux/model.py | 15 ++++++ models/flux/sampling.py | 57 +++++++++++++++++++++-- models/flux/util.py | 32 +++++++++++++ models/qwen/pipeline_qwenimage.py | 19 +++++--- models/qwen/qwen_handler.py | 1 + models/wan/any2video.py | 2 +- models/wan/wan_handler.py | 1 + wgp.py | 62 ++++++++++++++++--------- 15 files changed, 305 insertions(+), 50 deletions(-) create mode 100644 defaults/flux_dev_umo.json create mode 100644 defaults/flux_srpo.json create mode 100644 defaults/flux_srpo_uso.json diff --git a/README.md b/README.md index fc3d76c84..d33b6dc4b 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : -### September 5 2025: WanGP v8.5 - Wanna be a Cropper or a Painter ? +### September 11 2025: WanGP v8.5/8.55 - Wanna be a Cropper or a Painter ? I have done some intensive internal refactoring of the generation pipeline to ease support of existing models or add new models. Nothing really visible but this makes WanGP is little more future proof. @@ -38,6 +38,13 @@ Doing more sophisticated thing Vace Image Editor works very well too: try Image For the best quality I recommend to set in *Quality Tab* the option: "*Generate a 9 Frames Long video...*" +**update 8.55**: Flux Festival +- **Inpainting Mode** also added for *Flux Kontext* +- **Flux SRPO** : new finetune with x3 better quality vs Flux Dev according to its authors. I have also created a *Flux SRPO USO* finetune which is certainly the best open source *Style Transfer* tool available +- **Flux UMO**: model specialized in combining multiple reference objects / people together. Works quite well at 768x768 + +Good luck with finding your way through all the Flux models names ! + ### September 5 2025: WanGP v8.4 - Take me to Outer Space You have probably seen these short AI generated movies created using *Nano Banana* and the *First Frame - Last Frame* feature of *Kling 2.0*. The idea is to generate an image, modify a part of it with Nano Banana and give the these two images to Kling that will generate the Video between these two images, use now the previous Last Frame as the new First Frame, rinse and repeat and you get a full movie. diff --git a/defaults/flux_dev_umo.json b/defaults/flux_dev_umo.json new file mode 100644 index 000000000..57164bb4c --- /dev/null +++ b/defaults/flux_dev_umo.json @@ -0,0 +1,24 @@ +{ + "model": { + "name": "Flux 1 Dev UMO 12B", + "architecture": "flux", + "description": "FLUX.1 Dev UMO is a model that can Edit Images with a specialization in combining multiple image references (resized internally at 512x512 max) to produce an Image output. 
Best Image preservation at 768x768 Resolution Output.", + "URLs": "flux", + "flux-model": "flux-dev-umo", + "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-UMO_dit_lora_bf16.safetensors"], + "resolutions": [ ["1024x1024 (1:1)", "1024x1024"], + ["768x1024 (3:4)", "768x1024"], + ["1024x768 (4:3)", "1024x768"], + ["512x1024 (1:2)", "512x1024"], + ["1024x512 (2:1)", "1024x512"], + ["768x768 (1:1)", "768x768"], + ["768x512 (3:2)", "768x512"], + ["512x768 (2:3)", "512x768"]] + }, + "prompt": "the man is wearing a hat", + "embedded_guidance_scale": 4, + "resolution": "768x768", + "batch_size": 1 +} + + \ No newline at end of file diff --git a/defaults/flux_dev_uso.json b/defaults/flux_dev_uso.json index 0cd7b828d..806dd7e02 100644 --- a/defaults/flux_dev_uso.json +++ b/defaults/flux_dev_uso.json @@ -2,7 +2,7 @@ "model": { "name": "Flux 1 Dev USO 12B", "architecture": "flux", - "description": "FLUX.1 Dev USO is a model specialized to Edit Images with a specialization in Style Transfers (up to two).", + "description": "FLUX.1 Dev USO is a model that can Edit Images with a specialization in Style Transfers (up to two).", "modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]], "URLs": "flux", "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"], diff --git a/defaults/flux_srpo.json b/defaults/flux_srpo.json new file mode 100644 index 000000000..59f07c617 --- /dev/null +++ b/defaults/flux_srpo.json @@ -0,0 +1,15 @@ +{ + "model": { + "name": "Flux 1 SRPO Dev 12B", + "architecture": "flux", + "description": "By fine-tuning the FLUX.1.dev model with optimized denoising and online reward adjustment, SRPO improves its human-evaluated realism and aesthetic quality by over 3x.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_quanto_bf16_int8.safetensors" + ], + "flux-model": "flux-dev" + }, + "prompt": "draw a hat", + "resolution": "1024x1024", + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/flux_srpo_uso.json b/defaults/flux_srpo_uso.json new file mode 100644 index 000000000..ddfe50d63 --- /dev/null +++ b/defaults/flux_srpo_uso.json @@ -0,0 +1,17 @@ +{ + "model": { + "name": "Flux 1 SRPO USO 12B", + "architecture": "flux", + "description": "FLUX.1 SRPO USO is a model that can Edit Images with a specialization in Style Transfers (up to two). 
It leverages the improved Image quality brought by the SRPO process", + "modules": [ "flux_dev_uso"], + "URLs": "flux_srpo", + "loras": "flux_dev_uso", + "flux-model": "flux-dev-uso" + }, + "prompt": "the man is wearing a hat", + "embedded_guidance_scale": 4, + "resolution": "1024x1024", + "batch_size": 1 +} + + \ No newline at end of file diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py index c468d5a9a..808369ff6 100644 --- a/models/flux/flux_handler.py +++ b/models/flux/flux_handler.py @@ -13,6 +13,7 @@ def query_model_def(base_model_type, model_def): flux_schnell = flux_model == "flux-schnell" flux_chroma = flux_model == "flux-chroma" flux_uso = flux_model == "flux-dev-uso" + flux_umo = flux_model == "flux-dev-umo" flux_kontext = flux_model == "flux-dev-kontext" extra_model_def = { @@ -35,6 +36,7 @@ def query_model_def(base_model_type, model_def): } if flux_kontext: + extra_model_def["inpaint_support"] = True extra_model_def["image_ref_choices"] = { "choices": [ ("None", ""), @@ -43,6 +45,15 @@ def query_model_def(base_model_type, model_def): ], "letters_filter": "KI", } + extra_model_def["background_removal_label"]= "Remove Backgrounds only behind People / Objects except main Subject / Landscape" + elif flux_umo: + extra_model_def["image_ref_choices"] = { + "choices": [ + ("Conditional Images are People / Objects", "I"), + ], + "letters_filter": "I", + "visible": False + } extra_model_def["lock_image_refs_ratios"] = True @@ -131,10 +142,14 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): video_prompt_type = video_prompt_type.replace("I", "KI") ui_defaults["video_prompt_type"] = video_prompt_type + if settings_version < 2.34: + ui_defaults["denoising_strength"] = 1. + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): flux_model = model_def.get("flux-model", "flux-dev") flux_uso = flux_model == "flux-dev-uso" + flux_umo = flux_model == "flux-dev-umo" flux_kontext = flux_model == "flux-dev-kontext" ui_defaults.update({ "embedded_guidance": 2.5, @@ -143,5 +158,12 @@ def update_default_settings(base_model_type, model_def, ui_defaults): if flux_kontext or flux_uso: ui_defaults.update({ "video_prompt_type": "KI", + "denoising_strength": 1., + }) + elif flux_umo: + ui_defaults.update({ + "video_prompt_type": "I", + "remove_background_images_ref": 0, }) + diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py index 4d7c67d41..686371129 100644 --- a/models/flux/flux_main.py +++ b/models/flux/flux_main.py @@ -23,6 +23,35 @@ ) from PIL import Image +def preprocess_ref(raw_image: Image.Image, long_size: int = 512): + # 获取原始图像的宽度和高度 + image_w, image_h = raw_image.size + + # 计算长边和短边 + if image_w >= image_h: + new_w = long_size + new_h = int((long_size / image_w) * image_h) + else: + new_h = long_size + new_w = int((long_size / image_h) * image_w) + + # 按新的宽高进行等比例缩放 + raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS) + target_w = new_w // 16 * 16 + target_h = new_h // 16 * 16 + + # 计算裁剪的起始坐标以实现中心裁剪 + left = (new_w - target_w) // 2 + top = (new_h - target_h) // 2 + right = left + target_w + bottom = top + target_h + + # 进行中心裁剪 + raw_image = raw_image.crop((left, top, right, bottom)) + + # 转换为 RGB 模式 + raw_image = raw_image.convert("RGB") + return raw_image def stitch_images(img1, img2): # Resize img2 to match img1's height @@ -67,7 +96,7 @@ def __init__( # self.name= "flux-schnell" source = model_def.get("source", None) self.model = load_flow_model(self.name, model_filename[0] if source 
is None else source, torch_device) - + self.model_def = model_def self.vae = load_ae(self.name, device=torch_device) siglip_processor = siglip_model = feature_embedder = None @@ -109,10 +138,12 @@ def __init__( def generate( self, seed: int | None = None, - input_prompt: str = "replace the logo with the text 'Black Forest Labs'", + input_prompt: str = "replace the logo with the text 'Black Forest Labs'", n_prompt: str = None, sampling_steps: int = 20, input_ref_images = None, + image_guide= None, + image_mask= None, width= 832, height=480, embedded_guidance_scale: float = 2.5, @@ -123,7 +154,8 @@ def generate( batch_size = 1, video_prompt_type = "", joint_pass = False, - image_refs_relative_size = 100, + image_refs_relative_size = 100, + denoising_strength = 1., **bbargs ): if self._interrupt: @@ -132,8 +164,16 @@ def generate( if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors" device="cuda" flux_dev_uso = self.name in ['flux-dev-uso'] - image_stiching = not self.name in ['flux-dev-uso'] #and False + flux_dev_umo = self.name in ['flux-dev-umo'] + latent_stiching = self.name in ['flux-dev-uso', 'flux-dev-umo'] + + lock_dimensions= False + input_ref_images = [] if input_ref_images is None else input_ref_images[:] + if flux_dev_umo: + ref_long_side = 512 if len(input_ref_images) <= 1 else 320 + input_ref_images = [preprocess_ref(img, ref_long_side) for img in input_ref_images] + lock_dimensions = True ref_style_imgs = [] if "I" in video_prompt_type and len(input_ref_images) > 0: if flux_dev_uso : @@ -143,22 +183,26 @@ def generate( elif len(input_ref_images) > 1 : ref_style_imgs = input_ref_images[-1:] input_ref_images = input_ref_images[:-1] - if image_stiching: + + if latent_stiching: + # latents stiching with resize + if not lock_dimensions : + for i in range(len(input_ref_images)): + w, h = input_ref_images[i].size + image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, 0) + input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) + else: # image stiching method stiched = input_ref_images[0] for new_img in input_ref_images[1:]: stiched = stitch_images(stiched, new_img) input_ref_images = [stiched] - else: - # latents stiching with resize - for i in range(len(input_ref_images)): - w, h = input_ref_images[i].size - image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas) - input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) + elif image_guide is not None: + input_ref_images = [image_guide] else: input_ref_images = None - if flux_dev_uso : + if self.name in ['flux-dev-uso', 'flux-dev-umo'] : inp, height, width = prepare_multi_ip( ae=self.vae, img_cond_list=input_ref_images, @@ -177,6 +221,7 @@ def generate( bs=batch_size, seed=seed, device=device, + img_mask=image_mask, ) inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt)) @@ -198,13 +243,19 @@ def unpack_latent(x): return unpack(x.float(), height, width) # denoise initial noise - x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = 
joint_pass) + x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass, denoising_strength = denoising_strength) if x==None: return None # decode latents to pixel space x = unpack_latent(x) with torch.autocast(device_type=device, dtype=torch.bfloat16): x = self.vae.decode(x) + if image_mask is not None: + from shared.utils.utils import convert_image_to_tensor + img_msk_rebuilt = inp["img_msk_rebuilt"] + img= convert_image_to_tensor(image_guide) + x = img.squeeze(2) * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt + x = x.clamp(-1, 1) x = x.transpose(0, 1) return x diff --git a/models/flux/model.py b/models/flux/model.py index c4642d0bc..c5f7a24b4 100644 --- a/models/flux/model.py +++ b/models/flux/model.py @@ -190,6 +190,21 @@ def swap_scale_shift(weight): v = swap_scale_shift(v) k = k.replace("norm_out.linear", "final_layer.adaLN_modulation.1") new_sd[k] = v + # elif not first_key.startswith("diffusion_model.") and not first_key.startswith("transformer."): + # for k,v in sd.items(): + # if "double" in k: + # k = k.replace(".processor.proj_lora1.", ".img_attn.proj.lora_") + # k = k.replace(".processor.proj_lora2.", ".txt_attn.proj.lora_") + # k = k.replace(".processor.qkv_lora1.", ".img_attn.qkv.lora_") + # k = k.replace(".processor.qkv_lora2.", ".txt_attn.qkv.lora_") + # else: + # k = k.replace(".processor.qkv_lora.", ".linear1_qkv.lora_") + # k = k.replace(".processor.proj_lora.", ".linear2.lora_") + + # k = "diffusion_model." + k + # new_sd[k] = v + # from mmgp import safetensors2 + # safetensors2.torch_write_file(new_sd, "fff.safetensors") else: new_sd = sd return new_sd diff --git a/models/flux/sampling.py b/models/flux/sampling.py index f43ae15e0..1b4813a5c 100644 --- a/models/flux/sampling.py +++ b/models/flux/sampling.py @@ -138,10 +138,12 @@ def prepare_kontext( target_width: int | None = None, target_height: int | None = None, bs: int = 1, - + img_mask = None, ) -> tuple[dict[str, Tensor], int, int]: # load and encode the conditioning image + res_match_output = img_mask is not None + img_cond_seq = None img_cond_seq_ids = None if img_cond_list == None: img_cond_list = [] @@ -150,9 +152,11 @@ def prepare_kontext( for cond_no, img_cond in enumerate(img_cond_list): width, height = img_cond.size aspect_ratio = width / height - - # Kontext is trained on specific resolutions, using one of them is recommended - _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) + if res_match_output: + width, height = target_width, target_height + else: + # Kontext is trained on specific resolutions, using one of them is recommended + _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) width = 2 * int(width / 16) height = 2 * int(height / 16) @@ -193,6 +197,19 @@ def prepare_kontext( "img_cond_seq": img_cond_seq, "img_cond_seq_ids": img_cond_seq_ids, } + if img_mask is not None: + from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image + # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) + image_mask_latents = convert_image_to_tensor(img_mask.resize((target_width // 16, target_height // 16), resample=Image.Resampling.LANCZOS)) + image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. 
)[0:1] + image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0) + convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") + image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device) + return_dict.update({ + "img_msk_latents": image_mask_latents, + "img_msk_rebuilt": image_mask_rebuilt, + }) + img = get_noise( bs, target_height, @@ -264,6 +281,9 @@ def denoise( loras_slists=None, unpack_latent = None, joint_pass= False, + img_msk_latents = None, + img_msk_rebuilt = None, + denoising_strength = 1, ): kwargs = {'pipeline': pipeline, 'callback': callback, "img_len" : img.shape[1], "siglip_embedding": siglip_embedding, "siglip_embedding_ids": siglip_embedding_ids} @@ -271,6 +291,21 @@ def denoise( if callback != None: callback(-1, None, True) + original_image_latents = None if img_cond_seq is None else img_cond_seq.clone() + + morph, first_step = False, 0 + if img_msk_latents is not None: + randn = torch.randn_like(original_image_latents) + if denoising_strength < 1.: + first_step = int(len(timesteps) * (1. - denoising_strength)) + if not morph: + latent_noise_factor = timesteps[first_step] + latents = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor + img = latents.to(img) + latents = None + timesteps = timesteps[first_step:] + + updated_num_steps= len(timesteps) -1 if callback != None: from shared.utils.loras_mutipliers import update_loras_slists @@ -280,10 +315,14 @@ def denoise( # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): - offload.set_step_no_for_lora(model, i) + offload.set_step_no_for_lora(model, first_step + i) if pipeline._interrupt: return None + if img_msk_latents is not None and denoising_strength <1. 
and i == first_step and morph: + latent_noise_factor = t_curr/1000 + img = original_image_latents * (1.0 - latent_noise_factor) + img * latent_noise_factor + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) img_input = img img_input_ids = img_ids @@ -333,6 +372,14 @@ def denoise( pred = neg_pred + real_guidance_scale * (pred - neg_pred) img += (t_prev - t_curr) * pred + + if img_msk_latents is not None: + latent_noise_factor = t_prev + # noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor + noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor + img = noisy_image * (1-img_msk_latents) + img_msk_latents * img + noisy_image = None + if callback is not None: preview = unpack_latent(img).transpose(0,1) callback(i, preview, False) diff --git a/models/flux/util.py b/models/flux/util.py index 0f96103de..af75f6269 100644 --- a/models/flux/util.py +++ b/models/flux/util.py @@ -640,6 +640,38 @@ class ModelSpec: shift_factor=0.1159, ), ), + "flux-dev-umo": ModelSpec( + repo_id="", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + params=FluxParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + eso= True, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), } diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py index 20838f595..0897ee419 100644 --- a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -714,14 +714,14 @@ def __call__( image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS)) image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1] image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0) - convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") + # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device) prompt_image = image if image.size != (image_width, image_height): image = image.resize((image_width, image_height), resample=Image.Resampling.LANCZOS) - image.save("nnn.png") + # image.save("nnn.png") image = convert_image_to_tensor(image).unsqueeze(0).unsqueeze(2) has_neg_prompt = negative_prompt is not None or ( @@ -811,12 +811,15 @@ def __call__( negative_txt_seq_lens = ( negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) - morph = False - if image_mask_latents is not None and denoising_strength <= 1.: - first_step = int(len(timesteps) * (1. - denoising_strength)) + morph, first_step = False, 0 + if image_mask_latents is not None: + randn = torch.randn_like(original_image_latents) + if denoising_strength < 1.: + first_step = int(len(timesteps) * (1. 
- denoising_strength)) if not morph: latent_noise_factor = timesteps[first_step]/1000 - latents = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor + # latents = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor + latents = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor timesteps = timesteps[first_step:] self.scheduler.timesteps = timesteps self.scheduler.sigmas= self.scheduler.sigmas[first_step:] @@ -831,6 +834,7 @@ def __call__( for i, t in enumerate(timesteps): + offload.set_step_no_for_lora(self.transformer, first_step + i) if self.interrupt: continue @@ -905,7 +909,8 @@ def __call__( if image_mask_latents is not None: next_t = timesteps[i+1] if i sliding_window_size: + if model_type in ["t2v"] and not "G" in video_prompt_type : + gr.Info(f"You have requested to Generate Sliding Windows with a Text to Video model. Unless you use the Video to Video feature this is useless as a t2v model doesn't see past frames and it will generate the same video in each new window.") + return full_video_length = video_length if video_source is None else video_length + sliding_window_overlap -1 extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation" no_windows = compute_sliding_window_no(full_video_length, sliding_window_size, sliding_window_discard_last_frames, sliding_window_overlap) gr.Info(f"The Number of Frames to generate ({video_length}{extra}) is greater than the Sliding Window Size ({sliding_window_size}), {no_windows} Windows will be generated") - if "recam" in model_filename: if video_guide == None: gr.Info("You must provide a Control Video") @@ -7019,28 +7020,38 @@ def categorize_resolution(resolution_str): return group return "1440p" -def group_resolutions(resolutions, selected_resolution): - - grouped_resolutions = {} - for resolution in resolutions: - group = categorize_resolution(resolution[1]) - if group not in grouped_resolutions: - grouped_resolutions[group] = [] - grouped_resolutions[group].append(resolution) - - available_groups = [group for group in group_thresholds if group in grouped_resolutions] +def group_resolutions(model_def, resolutions, selected_resolution): + + model_resolutions = model_def.get("resolutions", None) + if model_resolutions is not None: + selected_group ="Locked" + available_groups = [selected_group ] + selected_group_resolutions = model_resolutions + else: + grouped_resolutions = {} + for resolution in resolutions: + group = categorize_resolution(resolution[1]) + if group not in grouped_resolutions: + grouped_resolutions[group] = [] + grouped_resolutions[group].append(resolution) + + available_groups = [group for group in group_thresholds if group in grouped_resolutions] - selected_group = categorize_resolution(selected_resolution) - selected_group_resolutions = grouped_resolutions.get(selected_group, []) - available_groups.reverse() + selected_group = categorize_resolution(selected_resolution) + selected_group_resolutions = grouped_resolutions.get(selected_group, []) + available_groups.reverse() return available_groups, selected_group_resolutions, selected_group def change_resolution_group(state, selected_group): model_type = state["model_type"] model_def = get_model_def(model_type) model_resolutions = model_def.get("resolutions", None) - resolution_choices, _ = get_resolution_choices(None, model_resolutions) - 
group_resolution_choices = [ resolution for resolution in resolution_choices if categorize_resolution(resolution[1]) == selected_group ] + resolution_choices, _ = get_resolution_choices(None, model_resolutions) + if model_resolutions is None: + group_resolution_choices = [ resolution for resolution in resolution_choices if categorize_resolution(resolution[1]) == selected_group ] + else: + last_resolution = group_resolution_choices[0][1] + return gr.update(choices= group_resolution_choices, value= last_resolution) last_resolution_per_group = state["last_resolution_per_group"] last_resolution = last_resolution_per_group.get(selected_group, "") @@ -7051,6 +7062,11 @@ def change_resolution_group(state, selected_group): def record_last_resolution(state, resolution): + + model_type = state["model_type"] + model_def = get_model_def(model_type) + model_resolutions = model_def.get("resolutions", None) + if model_resolutions is not None: return server_config["last_resolution_choice"] = resolution selected_group = categorize_resolution(resolution) last_resolution_per_group = state["last_resolution_per_group"] @@ -7482,11 +7498,13 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" ) image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs) - no_background_removal = model_def.get("no_background_removal", False) + no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None + background_removal_label = model_def.get("background_removal_label", "Remove Backgrounds behind People / Objects") + remove_background_images_ref = gr.Dropdown( choices=[ ("Keep Backgrounds behind all Reference Images", 0), - ("Remove Backgrounds only behind People / Objects except main Subject / Landscape" if (flux or qwen) else ("Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames" if vace else "Remove Backgrounds behind People / Objects") , 1), + (background_removal_label, 1), ], value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1), label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal @@ -7578,7 +7596,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl current_resolution_choice = ui_defaults.get("resolution","832x480") if update_form or last_resolution is None else last_resolution model_resolutions = model_def.get("resolutions", None) resolution_choices, current_resolution_choice = get_resolution_choices(current_resolution_choice, model_resolutions) - available_groups, selected_group_resolutions, selected_group = group_resolutions(resolution_choices, current_resolution_choice) + available_groups, selected_group_resolutions, selected_group = group_resolutions(model_def,resolution_choices, current_resolution_choice) resolution_group = gr.Dropdown( choices = available_groups, value= selected_group, From 7bcd7246217954956536515d8b220f349cb5f288 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Mon, 15 Sep 2025 
19:28:18 +0200 Subject: [PATCH 026/155] attack of the clones --- README.md | 228 +--------- defaults/qwen_image_edit_20B.json | 1 + defaults/vace_fun_14B_2_2.json | 24 ++ defaults/vace_fun_14B_cocktail_2_2.json | 28 ++ docs/CHANGELOG.md | 144 ++++++- models/flux/flux_handler.py | 2 +- models/flux/sampling.py | 2 +- models/hyvideo/hunyuan_handler.py | 10 +- models/ltx_video/ltxv.py | 4 +- models/ltx_video/ltxv_handler.py | 2 +- models/qwen/pipeline_qwenimage.py | 16 +- models/qwen/qwen_handler.py | 27 +- models/qwen/qwen_main.py | 14 +- models/wan/any2video.py | 86 ++-- models/wan/df_handler.py | 2 +- models/wan/multitalk/multitalk.py | 18 +- models/wan/wan_handler.py | 9 +- preprocessing/extract_vocals.py | 69 ++++ preprocessing/speakers_separator.py | 12 +- requirements.txt | 5 +- wgp.py | 526 +++++++++++++++--------- 21 files changed, 755 insertions(+), 474 deletions(-) create mode 100644 defaults/vace_fun_14B_2_2.json create mode 100644 defaults/vace_fun_14B_cocktail_2_2.json create mode 100644 preprocessing/extract_vocals.py diff --git a/README.md b/README.md index d33b6dc4b..2411cea3d 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,19 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : +### September 15 2025: WanGP v8.6 - Attack of the Clones + +- The long awaited **Vace for Wan 2.2** is at last here or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fan Cocktail**) + +- **First Frame / Last Frame for Vace** : Vace model are so powerful that they could do *First frame / Last frame* since day one using the *Injected Frames* feature. However this required to compute by hand the locations of each end frame since this feature expects frames positions. I made it easier to compute these locations by using the "L" alias : + +For a video Gen from scratch *"1 L L L"* means the 4 Injected Frames will be injected like this: frame no 1 at the first position, the next frame at the end of the first window, then the following frame at the end of the next window, and so on .... +If you *Continue a Video* , you just need *"L L L"* since the the first frame is the last frame of the *Source Video*. In any case remember that numeral frames positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab/ Control Video, Injected Frames aligment*. + +- **Qwen Inpainting** exist now in two versions: the original version of the previous release and a Lora based version. Each version has its pros and cons. For instance the Lora version supports also **Outpainting** ! However it tends to change slightly the original image even outside the outpainted area. + +- **Better Lipsync with all the Audio to Video models**: you probably noticed that *Multitalk*, *InfiniteTalk* or *Hunyuan Avatar* had so so lipsync when the audio provided contained some background music. The problem should be solved now thanks to an automated background music removal all done by IA. Don't worry you will still hear the music as it is added back in the generated Video. + ### September 11 2025: WanGP v8.5/8.55 - Wanna be a Cropper or a Painter ? 
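Aside on the v8.6 *Injected Frames* syntax described above: the sketch below shows how a position string such as `"1 L L L"` could be expanded into absolute frame indices, one `L` per sliding-window end. The helper name and the exact window arithmetic (overlap subtracted once per extra window) are illustrative assumptions, not WanGP's actual implementation.

```python
# Illustrative sketch only: expand an Injected Frames position string where plain
# numbers are absolute frame positions and each "L" stands for the last frame of
# the next sliding window. The window arithmetic is an assumption, not WanGP code.
def expand_frame_positions(spec: str, video_length: int, window_size: int, overlap: int) -> list[int]:
    positions, window_end = [], window_size
    for token in spec.split():
        if token.upper() == "L":
            positions.append(min(window_end, video_length))
            window_end += window_size - overlap  # each later window starts on the overlap
        else:
            positions.append(int(token))
    return positions

# "1 L L L" -> [1, 81, 153, 225] with 81-frame windows and a 9-frame overlap
print(expand_frame_positions("1 L L L", video_length=225, window_size=81, overlap=9))
```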
I have done some intensive internal refactoring of the generation pipeline to ease support of existing models or add new models. Nothing really visible but this makes WanGP is little more future proof. @@ -74,221 +87,6 @@ You will find below a 33s movie I have created using these two methods. Quality *update 8.31: one shouldnt talk about bugs if one doesn't want to attract bugs* -### August 29 2025: WanGP v8.21 - Here Goes Your Weekend - -- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is internally only image is used by Sliding Window. However thanks to the new *Smooth Transition* mode, each new clip is connected to the previous and all the camera work is done by InfiniteTalk. If you dont get any transition, increase the number of frames of a Sliding Window (81 frames recommended) - -- **StandIn**: very light model specialized in Identity Transfer. I have provided two versions of Standin: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will be also used for StandIn - -- **Flux ESO**: a new Flux dervied *Image Editing tool*, but this one is specialized both in *Identity Transfer* and *Style Transfer*. Style has to be understood in its wide meaning: give a reference picture of a person and another one of Sushis and you will turn this person into Sushis - -### August 24 2025: WanGP v8.1 - the RAM Liberator - -- **Reserved RAM entirely freed when switching models**, you should get much less out of memory related to RAM. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP -- **InfiniteTalk** support: improved version of Multitalk that supposedly supports very long video generations based on an audio track. Exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesnt seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated to the same audio: each Reference frame you provide you will be associated to a new Sliding Window. If only Reference frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\ -If you are not into audio, you can use still this model to generate infinite long image2video, just select "no speaker". Last but not least, Infinitetalk works works with all the Loras accelerators. -- **Flux Chroma 1 HD** support: uncensored flux based model and lighter than Flux (8.9B versus 12B) and can fit entirely in VRAM with only 16 GB of VRAM. Unfortunalely it is not distilled and you will need CFG at minimum 20 steps - -### August 21 2025: WanGP v8.01 - the killer of seven - -- **Qwen Image Edit** : Flux Kontext challenger (prompt driven image edition). Best results (including Identity preservation) will be obtained at 720p. Beyond you may get image outpainting and / or lose identity preservation. Below 720p prompt adherence will be worse. Qwen Image Edit works with Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions but identity preservation won't be as good. -- **On demand Prompt Enhancer** (needs to be enabled in Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. 
You can refine the Enhanced Prompt or change the original Prompt. -- Choice of a **Non censored Prompt Enhancer**. Beware this is one is VRAM hungry and will require 12 GB of VRAM to work -- **Memory Profile customizable per model** : useful to set for instance Profile 3 (preload the model entirely in VRAM) with only Image Generation models, if you have 24 GB of VRAM. In that case Generation will be much faster because with Image generators (contrary to Video generators) as a lot of time is wasted in offloading -- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Ligthning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is a 8 steps process although the lora lightning is 4 steps. This expert guidance mode is also available with Wan 2.1. - -*WanGP 8.01 update, improved Qwen Image Edit Identity Preservation* -### August 12 2025: WanGP v7.7777 - Lucky Day(s) - -This is your lucky day ! thanks to new configuration options that will let you store generated Videos and Images in lossless compressed formats, you will find they in fact they look two times better without doing anything ! - -Just kidding, they will be only marginally better, but at least this opens the way to professionnal editing. - -Support: -- Video: x264, x264 lossless, x265 -- Images: jpeg, png, webp, wbp lossless -Generation Settings are stored in each of the above regardless of the format (that was the hard part). - -Also you can now choose different output directories for images and videos. - -unexpected luck: fixed lightning 8 steps for Qwen, and lightning 4 steps for Wan 2.2, now you just need 1x multiplier no weird numbers. -*update 7.777 : oops got a crash a with FastWan ? Luck comes and goes, try a new update, maybe you will have a better chance this time* -*update 7.7777 : Sometime good luck seems to last forever. For instance what if Qwen Lightning 4 steps could also work with WanGP ?* -- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps) -- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps) - - -### August 10 2025: WanGP v7.76 - Faster than the VAE ... -We have a funny one here today: FastWan 2.2 5B, the Fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that VAE is twice as slow... -Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune. - -*WanGP 7.76: fixed the messed up I did to i2v models (loras path was wrong for Wan2.2 and Clip broken)* - -### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2 -Added support for Qwen Lightning lora for a 8 steps generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). Lora is not normalized and you can use a multiplier around 0.1. - -Mag Cache support for all the Wan2.2 models Don't forget to set guidance to 1 and 8 denoising steps , your gen will be 7x faster ! - -### August 8 2025: WanGP v7.73 - Qwen Rebirth -Ever wondered what impact not using Guidance has on a model that expects it ? Just look at Qween Image in WanGP 7.71 whose outputs were erratic. 
Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt. And in WanGP 7.72 there is at last one for him. - -As Qwen is not so picky after all I have added also quantized text encoder which reduces the RAM requirements of Qwen by 10 GB (the text encoder quantized version produced garbage before) - -Unfortunately still the Sage bug for older GPU architectures. Added Sdpa fallback for these architectures. - -*7.73 update: still Sage / Sage2 bug for GPUs before RTX40xx. I have added a detection mechanism that forces Sdpa attention if that's the case* - - -### August 6 2025: WanGP v7.71 - Picky, picky - -This release comes with two new models : -- Qwen Image: a Commercial grade Image generator capable to inject full sentences in the generated Image while still offering incredible visuals -- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 needed if you want to complete your Wan 2.2 collection (loras for this folder can be stored in "\loras\5B" ) - -There is catch though, they are very picky if you want to get good generations: first they both need lots of steps (50 ?) to show what they have to offer. Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else, you will get garbage. Likewise Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p. - -*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM during a whole gen.* - - -### August 4 2025: WanGP v7.6 - Remuxed - -With this new version you won't have any excuse if there is no sound in your video. - -*Continue Video* now works with any video that has already some sound (hint: Multitalk ). - -Also, on top of MMaudio and the various sound driven models I have added the ability to use your own soundtrack. - -As a result you can apply a different sound source on each new video segment when doing a *Continue Video*. - -For instance: -- first video part: use Multitalk with two people speaking -- second video part: you apply your own soundtrack which will gently follow the multitalk conversation -- third video part: you use Vace effect and its corresponding control audio will be concatenated to the rest of the audio - -To multiply the combinations I have also implemented *Continue Video* with the various image2video models. - -Also: -- End Frame support added for LTX Video models -- Loras can now be targetted specifically at the High noise or Low noise models with Wan 2.2, check the Loras and Finetune guides -- Flux Krea Dev support - -### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2 -Here is now Wan 2.2 image2video a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go ... - -Please note that although it is an image2video model it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). Given that Wan 2.1 image2video loras don't work too well (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder. - -I have also optimized RAM management with Wan 2.2 so that loras and modules will be loaded only once in RAM and Reserved RAM, this saves up to 5 GB of RAM which can make a difference... - -And this time I really removed Vace Cocktail Light which gave a blurry vision. - -### July 29 2025: WanGP v7.4: Just another release ... Wan 2.2 Preview -Wan 2.2 is here. 
The good news is that WanGP wont require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to leverage entirely this new model since it has twice has many parameters. - -So here is a preview version of Wan 2.2 that is without the 5B model and Wan 2.2 image to video for the moment. - -However as I felt bad to deliver only half of the wares, I gave you instead .....** Wan 2.2 Vace Experimental Cocktail** ! - -Very good surprise indeed, the loras and Vace partially work with Wan 2.2. We will need to wait for the official Vace 2.2 release since some Vace features are broken like identity preservation - -Bonus zone: Flux multi images conditions has been added, or maybe not if I broke everything as I have been distracted by Wan... - -7.4 update: I forgot to update the version number. I also removed Vace Cocktail light which didnt work well. - -### July 27 2025: WanGP v7.3 : Interlude -While waiting for Wan 2.2, you will appreciate the model selection hierarchy which is very useful to collect even more models. You will also appreciate that WanGP remembers which model you used last in each model family. - -### July 26 2025: WanGP v7.2 : Ode to Vace -I am really convinced that Vace can do everything the other models can do and in a better way especially as Vace can be combined with Multitalk. - -Here are some new Vace improvements: -- I have provided a default finetune named *Vace Cocktail* which is a model created on the fly using the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition in *defaults/vace_14B_cocktail.json* in the *finetunes/* folder to change the Cocktail composition. Cocktail contains already some Loras acccelerators so no need to add on top a Lora Accvid, Causvid or Fusionix, ... . The whole point of Cocktail is to be able to build you own FusioniX (which originally is a combination of 4 loras) but without the inconvenient of FusioniX. -- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video which is shame for our Vace photoshop. But there is a solution : I have added an Advanced Quality option, that tells WanGP to generate a little more than a frame (it will still keep only the first frame). It will be a little slower but you will be amazed how Vace Cocktail combined with this option will preserve identities (bye bye *Phantom*). -- As in practise I have observed one switches frequently between *Vace text2video* and *Vace text2image* I have put them in the same place they are now just one tab away, no need to reload the model. Likewise *Wan text2video* and *Wan tex2image* have been merged. -- Color fixing when using Sliding Windows. A new postprocessing *Color Correction* applied automatically by default (you can disable it in the *Advanced tab Sliding Window*) will try to match the colors of the new window with that of the previous window. It doesnt fix all the unwanted artifacts of the new window but at least this makes the transition smoother. Thanks to the multitalk team for the original code. - -Also you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ... ). Many thanks to **Redtash1** for providing the framework for this new feature ! You need to go in the Config tab to enable real time stats. 
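Aside on the *Color Correction* between sliding windows mentioned in the v7.2 notes above: a minimal, generic way to do this is per-channel mean/std matching of the new window against the previous one. The sketch below illustrates that classic technique only; it is not the code adopted from the Multitalk team.

```python
import numpy as np

def match_window_colors(new_frames: np.ndarray, ref_frames: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Per-channel mean/std transfer: shift and scale the colors of new_frames so their
    statistics match ref_frames. Both are float arrays in [0, 1] with shape (T, H, W, 3)."""
    out = new_frames.astype(np.float32).copy()
    for c in range(3):
        ref_mean, ref_std = ref_frames[..., c].mean(), ref_frames[..., c].std()
        new_mean, new_std = out[..., c].mean(), out[..., c].std()
        out[..., c] = (out[..., c] - new_mean) * (ref_std / (new_std + eps)) + ref_mean
    return np.clip(out, 0.0, 1.0)

# Typical use: pass the overlapping frames of the previous window as ref_frames
# and the freshly generated window as new_frames.
```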
- - -### July 21 2025: WanGP v7.12 -- Flux Family Reunion : *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added. - -- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video !) in one go without a sliding window. With the distilled model it will take only 5 minutes with a RTX 4090 (you will need 22 GB of VRAM though). I have added options to select higher humber frames if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App) - -- LTX Video ControlNet : it is a Control Net that allows you for instance to transfer a Human motion or Depth from a control video. It is not as powerful as Vace but can produce interesting things especially as now you can generate quickly a 1 min video. Under the scene IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them. - -- LTX IC-Lora support: these are special Loras that consumes a conditional image or video -Beside the pose, depth and canny IC-Loras transparently loaded there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as control net choice to use it. - -- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update) - -- Easier way to select video resolution - - -### July 15 2025: WanGP v7.0 is an AI Powered Photoshop -This release turns the Wan models into Image Generators. This goes way more than allowing to generate a video made of single frame : -- Multiple Images generated at the same time so that you can choose the one you like best.It is Highly VRAM optimized so that you can generate for instance 4 720p Images at the same time with less than 10 GB -- With the *image2image* the original text2video WanGP becomes an image upsampler / restorer -- *Vace image2image* comes out of the box with image outpainting, person / object replacement, ... -- You can use in one click a newly Image generated as Start Image or Reference Image for a Video generation - -And to complete the full suite of AI Image Generators, Ladies and Gentlemen please welcome for the first time in WanGP : **Flux Kontext**.\ -As a reminder Flux Kontext is an image editor : give it an image and a prompt and it will do the change for you.\ -This highly optimized version of Flux Kontext will make you feel that you have been cheated all this time as WanGP Flux Kontext requires only 8 GB of VRAM to generate 4 images at the same time with no need for quantization. - -WanGP v7 comes with *Image2image* vanilla and *Vace FusinoniX*. However you can build your own finetune where you will combine a text2video or Vace model with any combination of Loras. - -Also in the news: -- You can now enter the *Bbox* for each speaker in *Multitalk* to precisely locate who is speaking. And to save some headaches the *Image Mask generator* will give you the *Bbox* coordinates of an area you have selected. -- *Film Grain* post processing to add a vintage look at your video -- *First Last Frame to Video* model should work much better now as I have discovered rencently its implementation was not complete -- More power for the finetuners, you can now embed Loras directly in the finetune definition. You can also override the default models (titles, visibility, ...) with your own finetunes. 
Check the doc that has been updated. - - -### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me -Maybe you knew that already but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that it is *CFG* is set to 1). This helps to get much faster generations but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase. - -So WanGP 6.7 gives you NAG, but not any NAG, a *Low VRAM* implementation, the default one ends being VRAM greedy. You will find NAG in the *General* advanced tab for most Wan models. - -Use NAG especially when Guidance is set to 1. To turn it on set the **NAG scale** to something around 10. There are other NAG parameters **NAG tau** and **NAG alpha** which I recommend to change only if you don't get good results by just playing with the NAG scale. Don't hesitate to share on this discord server the best combinations for these 3 parameters. - -The authors of NAG claim that NAG can also be used when using a Guidance (CFG > 1) and to improve the prompt adherence. - -### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** : -**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). It is accelerated by the **Fusionix** model and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation). - -Of course you will get as well *Multitalk* vanilla and also *Multitalk 720p* as a bonus. - -And since I am mister nice guy I have enclosed as an exclusivity an *Audio Separator* that will save you time to isolate each voice when using Multitalk with two people. - -As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the *Share Your Best Video* channel your *Master Pieces*. The best ones will be added to the *Announcements Channel* and will bring eternal fame to its authors. - -But wait, there is more: -- Sliding Windows support has been added anywhere with Wan models, so imagine with text2video recently upgraded in 6.5 into a video2video, you can now upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you) -- I have added also the capability to transfer the audio of the original control video (Misc. advanced tab) and an option to preserve the fps into the generated video, so from now on you will be to upsample / restore your old families video and keep the audio at their original pace. Be aware that the duration will be limited to 1000 frames as I still need to add streaming support for unlimited video sizes. 
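Aside on the text2video-as-video2video trick mentioned above, which is also the idea behind the `denoising_strength` additions to the Flux and Qwen pipelines in patch 025: skip the first denoising steps and start from the source latents re-noised to the matching level. The sketch below is a model-agnostic illustration; the `step_fn` callback and the division by 1000 are assumptions that depend on the scheduler's timestep convention.

```python
import torch

def denoise_from_source(source_latents: torch.Tensor, timesteps, denoising_strength: float, step_fn):
    # denoising_strength = 1.0 -> plain generation from noise, 0.0 -> keep the source as is
    first_step = int(len(timesteps) * (1.0 - denoising_strength))
    if first_step >= len(timesteps) - 1:
        return source_latents  # nothing left to denoise
    noise = torch.randn_like(source_latents)
    sigma = float(timesteps[first_step]) / 1000.0        # assumes timesteps expressed in [0, 1000]
    x = source_latents * (1.0 - sigma) + noise * sigma   # re-noise the source to the starting level
    for t_curr, t_prev in zip(timesteps[first_step:-1], timesteps[first_step + 1:]):
        x = step_fn(x, t_curr, t_prev)                   # one denoising step of the underlying model
    return x
```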
- -Also, of interest too: -- Extract video info from Videos that have not been generated by WanGP, even better you can also apply post processing (Upsampling / MMAudio) on non WanGP videos -- Force the generated video fps to your liking, works wery well with Vace when using a Control Video -- Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time) - -### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features: -- View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations -- In one click use the newly generated video as a Control Video or Source Video to be continued -- Manage multiple settings for the same model and switch between them using a dropdown box -- WanGP will keep the last generated videos in the Gallery and will remember the last model you used if you restart the app but kept the Web page open -- Custom resolutions : add a file in the WanGP folder with the list of resolutions you want to see in WanGP (look at the instruction readme in this folder) - -Taking care of your life is not enough, you want new stuff to play with ? -- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the *Extensions* tab of the WanGP *Configuration* to enable MMAudio -- Forgot to upsample your video during the generation ? want to try another MMAudio variation ? Fear not you can also apply upsampling or add an MMAudio track once the video generation is done. Even better you can ask WangGP for multiple variations of MMAudio to pick the one you like best -- MagCache support: a new step skipping approach, supposed to be better than TeaCache. Makes a difference if you usually generate with a high number of steps -- SageAttention2++ support : not just the compatibility but also a slightly reduced VRAM usage -- Video2Video in Wan Text2Video : this is the paradox, a text2video can become a video2video if you start the denoising process later on an existing video -- FusioniX upsampler: this is an illustration of Video2Video in Text2Video. Use the FusioniX text2video model with an output resolution of 1080p and a denoising strength of 0.25 and you will get one of the best upsamplers (in only 2/3 steps, you will need lots of VRAM though). 
Increase the denoising strength and you will get one of the best Video Restorer -- Choice of Wan Samplers / Schedulers -- More Lora formats support - -**If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one** See full changelog: **[Changelog](docs/CHANGELOG.md)** diff --git a/defaults/qwen_image_edit_20B.json b/defaults/qwen_image_edit_20B.json index 79b8b245d..04fc573fb 100644 --- a/defaults/qwen_image_edit_20B.json +++ b/defaults/qwen_image_edit_20B.json @@ -7,6 +7,7 @@ "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_quanto_bf16_int8.safetensors" ], + "preload_URLs": ["https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_inpainting.safetensors"], "attention": { "<89": "sdpa" } diff --git a/defaults/vace_fun_14B_2_2.json b/defaults/vace_fun_14B_2_2.json new file mode 100644 index 000000000..8f22d3456 --- /dev/null +++ b/defaults/vace_fun_14B_2_2.json @@ -0,0 +1,24 @@ +{ + "model": { + "name": "Wan2.2 Vace Fun 14B", + "architecture": "vace_14B", + "description": "This is the Fun Vace 2.2 version, that is not the official Vace 2.2", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_quanto_mfp16_int8.safetensors" + ], + "URLs2": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_quanto_mfp16_int8.safetensors" + ], + "group": "wan2_2" + }, + "guidance_phases": 2, + "num_inference_steps": 30, + "guidance_scale": 1, + "guidance2_scale": 1, + "flow_shift": 2, + "switch_threshold": 875 +} \ No newline at end of file diff --git a/defaults/vace_fun_14B_cocktail_2_2.json b/defaults/vace_fun_14B_cocktail_2_2.json new file mode 100644 index 000000000..c587abdd4 --- /dev/null +++ b/defaults/vace_fun_14B_cocktail_2_2.json @@ -0,0 +1,28 @@ +{ + "model": { + "name": "Wan2.2 Vace Fun Cocktail 14B", + "architecture": "vace_14B", + "description": "This model has been created on the fly using the Wan text 2.2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. 
This is the Fun Vace 2.2, that is not the official Vace 2.2", + "URLs": "vace_fun_14B_2_2", + "URLs2": "vace_fun_14B_2_2", + "loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors" + ], + "loras_multipliers": [ + 1, + 0.2, + 0.5, + 0.5 + ], + "group": "wan2_2" + }, + "guidance_phases": 2, + "num_inference_steps": 10, + "guidance_scale": 1, + "guidance2_scale": 1, + "flow_shift": 2, + "switch_threshold": 875 +} \ No newline at end of file diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 5a89d9389..b0eeae34b 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,20 +1,154 @@ # Changelog ## 🔥 Latest News -### July 21 2025: WanGP v7.1 +### August 29 2025: WanGP v8.21 - Here Goes Your Weekend + +- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is, internally only one image is used per Sliding Window. However, thanks to the new *Smooth Transition* mode, each new clip is connected to the previous one and all the camera work is done by InfiniteTalk. If you don't get any transition, increase the number of frames of a Sliding Window (81 frames recommended) + +- **StandIn**: a very light model specialized in Identity Transfer. I have provided two versions of StandIn: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will also be used for StandIn + +- **Flux ESO**: a new Flux derived *Image Editing tool*, but this one is specialized both in *Identity Transfer* and *Style Transfer*. Style has to be understood in its wide meaning: give a reference picture of a person and another one of Sushis and you will turn this person into Sushis + +### August 24 2025: WanGP v8.1 - the RAM Liberator + +- **Reserved RAM entirely freed when switching models**: you should get far fewer RAM related out of memory errors. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP +- **InfiniteTalk** support: an improved version of Multitalk that supposedly supports very long video generations based on an audio track. It exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesn't seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated to the same audio: each Reference Frame you provide will be associated to a new Sliding Window. If only one Reference Frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\ +If you are not into audio, you can still use this model to generate infinitely long image2video, just select "no speaker". Last but not least, InfiniteTalk works with all the Loras accelerators. +- **Flux Chroma 1 HD** support: an uncensored flux based model, lighter than Flux (8.9B versus 12B), that can fit entirely in VRAM with only 16 GB of VRAM.
Unfortunately it is not distilled and you will need CFG and a minimum of 20 steps + +### August 21 2025: WanGP v8.01 - the killer of seven + +- **Qwen Image Edit** : Flux Kontext challenger (prompt driven image editing). Best results (including Identity preservation) will be obtained at 720p. Beyond that you may get image outpainting and / or lose identity preservation. Below 720p prompt adherence will be worse. Qwen Image Edit works with Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions but identity preservation won't be as good. +- **On demand Prompt Enhancer** (needs to be enabled in Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt. +- Choice of a **Non censored Prompt Enhancer**. Beware, this one is VRAM hungry and will require 12 GB of VRAM to work +- **Memory Profile customizable per model** : useful to set for instance Profile 3 (preload the model entirely in VRAM) with only Image Generation models, if you have 24 GB of VRAM. In that case Generation will be much faster because with Image generators (contrary to Video generators) a lot of time is otherwise wasted in offloading +- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Lightning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is an 8 steps process although the lora lightning is 4 steps. This expert guidance mode is also available with Wan 2.1. + +*WanGP 8.01 update, improved Qwen Image Edit Identity Preservation* +### August 12 2025: WanGP v7.7777 - Lucky Day(s) + +This is your lucky day ! Thanks to new configuration options that will let you store generated Videos and Images in lossless compressed formats, you will find that in fact they look two times better without doing anything ! + +Just kidding, they will be only marginally better, but at least this opens the way to professional editing. + +Support: +- Video: x264, x264 lossless, x265 +- Images: jpeg, png, webp, webp lossless +Generation Settings are stored in each of the above regardless of the format (that was the hard part). + +Also you can now choose different output directories for images and videos. + +Unexpected luck: fixed lightning 8 steps for Qwen, and lightning 4 steps for Wan 2.2, now you just need a 1x multiplier, no weird numbers. +*update 7.777 : oops, got a crash with FastWan ? Luck comes and goes, try a new update, maybe you will have a better chance this time* +*update 7.7777 : Sometimes good luck seems to last forever. For instance what if Qwen Lightning 4 steps could also work with WanGP ?* +- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps) +- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps) + + +### August 10 2025: WanGP v7.76 - Faster than the VAE ... +We have a funny one here today: FastWan 2.2 5B, the Fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that the VAE is twice as slow... +Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune.
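For readers wondering what "building a finetune from an extracted Lora" amounts to in practice, here is a minimal sketch in the spirit of the `defaults/vace_fun_14B_cocktail_2_2.json` definition added earlier in this patch. The model name, output file name and Lora URL below are placeholders, not the actual FastWan assets.

```python
import json
import os

# Hypothetical finetune definition modelled on the cocktail JSONs above:
# "URLs" points at an existing model definition so its weights are reused,
# and an accelerator Lora (with its multiplier) is layered on top.
finetune = {
    "model": {
        "name": "My accelerated finetune",          # placeholder name
        "architecture": "vace_14B",                 # reuse an existing architecture
        "URLs": "vace_fun_14B_2_2",                 # borrow the weights of another definition
        "loras": ["https://example.com/extracted_accelerator_lora.safetensors"],  # placeholder URL
        "loras_multipliers": [1],
        "group": "wan2_2",
    },
    "num_inference_steps": 10,
    "guidance_scale": 1,
}

# Drop the definition into the finetunes/ folder mentioned in the v7.2 note below.
os.makedirs("finetunes", exist_ok=True)
with open(os.path.join("finetunes", "my_accelerated_finetune.json"), "w") as f:
    json.dump(finetune, f, indent=2)
```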
+ +*WanGP 7.76: fixed the mess I made with i2v models (loras path was wrong for Wan2.2 and Clip broken)* + +### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2 +Added support for the Qwen Lightning lora for an 8 steps generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). The Lora is not normalized and you can use a multiplier around 0.1. + +Mag Cache support for all the Wan2.2 models. Don't forget to set guidance to 1 and 8 denoising steps, your gen will be 7x faster ! + +### August 8 2025: WanGP v7.73 - Qwen Rebirth +Ever wondered what impact not using Guidance has on a model that expects it ? Just look at Qwen Image in WanGP 7.71 whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt. And in WanGP 7.72 there is at last one for him. + +As Qwen is not so picky after all, I have also added a quantized text encoder which reduces the RAM requirements of Qwen by 10 GB (the quantized text encoder produced garbage before) + +Unfortunately the Sage bug is still there for older GPU architectures. Added an Sdpa fallback for these architectures. + +*7.73 update: still a Sage / Sage2 bug for GPUs before RTX40xx. I have added a detection mechanism that forces Sdpa attention if that's the case* + + +### August 6 2025: WanGP v7.71 - Picky, picky + +This release comes with two new models : +- Qwen Image: a Commercial grade Image generator capable of injecting full sentences into the generated Image while still offering incredible visuals +- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 model needed if you want to complete your Wan 2.2 collection (loras for this model can be stored in "\loras\5B" ) + +There is a catch though, they are very picky if you want to get good generations: first they both need lots of steps (50 ?) to show what they have to offer. Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else, you will get garbage. Likewise Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p. + +*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM usage during a whole gen.* + + +### August 4 2025: WanGP v7.6 - Remuxed + +With this new version you won't have any excuse if there is no sound in your video. + +*Continue Video* now works with any video that already has some sound (hint: Multitalk ). + +Also, on top of MMaudio and the various sound driven models I have added the ability to use your own soundtrack. + +As a result you can apply a different sound source on each new video segment when doing a *Continue Video*. + +For instance: +- first video part: use Multitalk with two people speaking +- second video part: you apply your own soundtrack which will gently follow the multitalk conversation +- third video part: you use a Vace effect and its corresponding control audio will be concatenated to the rest of the audio + +To multiply the combinations I have also implemented *Continue Video* with the various image2video models. + +Also: +- End Frame support added for LTX Video models +- Loras can now be targeted specifically at the High noise or Low noise models with Wan 2.2, check the Loras and Finetune guides +- Flux Krea Dev support + +### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2 +Here is now Wan 2.2 image2video, a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go ...
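As an aside on how the two Wan 2.2 checkpoints fit together: a rough sketch of a noise-threshold handover, assuming a `switch_threshold`-style setting like the one in the finetune definitions earlier in this patch, with noise levels running from roughly 1000 down to 0. The function and argument names are hypothetical, not WanGP's actual code.

```python
def two_expert_denoise(latents, timesteps, high_noise_model, low_noise_model,
                       switch_threshold=875):
    """Illustrative only: the high-noise expert handles the early, noisy steps
    and hands over to the low-noise expert once the noise level drops below
    `switch_threshold` (compare the HIGH / LOW checkpoints and the
    "switch_threshold": 875 entry in the definitions above)."""
    for t in timesteps:
        model = high_noise_model if t >= switch_threshold else low_noise_model
        latents = model(latents, t)  # one denoising step with the selected expert
    return latents
```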
+ +Please note that although it is an image2video model it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). Given that Wan 2.1 image2video loras don't work too well (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder. + +I have also optimized RAM management with Wan 2.2 so that loras and modules will be loaded only once in RAM and Reserved RAM, this saves up to 5 GB of RAM which can make a difference... + +And this time I really removed Vace Cocktail Light which gave blurry results. + +### July 29 2025: WanGP v7.4: Just another release ... Wan 2.2 Preview +Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to leverage this new model entirely since it has twice as many parameters. + +So here is a preview version of Wan 2.2, that is, without the 5B model and Wan 2.2 image to video for the moment. + +However as I felt bad about delivering only half of the wares, I gave you instead ..... **Wan 2.2 Vace Experimental Cocktail** ! + +Very good surprise indeed, the loras and Vace partially work with Wan 2.2. We will need to wait for the official Vace 2.2 release since some Vace features are broken, like identity preservation. + +Bonus zone: Flux multi image conditions have been added, or maybe not if I broke everything as I have been distracted by Wan... + +7.4 update: I forgot to update the version number. I also removed Vace Cocktail light which didn't work well. + +### July 27 2025: WanGP v7.3 : Interlude +While waiting for Wan 2.2, you will appreciate the model selection hierarchy which is very useful to collect even more models. You will also appreciate that WanGP remembers which model you used last in each model family. + +### July 26 2025: WanGP v7.2 : Ode to Vace +I am really convinced that Vace can do everything the other models can do and in a better way, especially as Vace can be combined with Multitalk. + +Here are some new Vace improvements: +- I have provided a default finetune named *Vace Cocktail* which is a model created on the fly using the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition in *defaults/vace_14B_cocktail.json* into the *finetunes/* folder to change the Cocktail composition. Cocktail already contains some Loras accelerators so there is no need to add on top a Lora Accvid, Causvid or Fusionix, ... . The whole point of Cocktail is to be able to build your own FusioniX (which originally is a combination of 4 loras) but without the inconveniences of FusioniX. +- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video, which is a shame for our Vace photoshop. But there is a solution : I have added an Advanced Quality option that tells WanGP to generate a little more than a frame (it will still keep only the first frame). It will be a little slower but you will be amazed how Vace Cocktail combined with this option will preserve identities (bye bye *Phantom*). +- As in practice I have observed that one switches frequently between *Vace text2video* and *Vace text2image*, I have put them in the same place: they are now just one tab away, no need to reload the model.
Likewise *Wan text2video* and *Wan text2image* have been merged. +- Color fixing when using Sliding Windows. A new postprocessing *Color Correction* applied automatically by default (you can disable it in the *Advanced tab Sliding Window*) will try to match the colors of the new window with that of the previous window. It doesn't fix all the unwanted artifacts of the new window but at least this makes the transition smoother. Thanks to the multitalk team for the original code. + +Also you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ... ). Many thanks to **Redtash1** for providing the framework for this new feature ! You need to go to the Config tab to enable real time stats. + + +### July 21 2025: WanGP v7.12 - Flux Family Reunion : *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added. -- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video !) in one go without a sliding window. With the distilled model it will take only 5 minutes with a RTX 4090 (you will need 22 GB of VRAM though). I have added options to select higher humber frames if you want to experiment +- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video !) in one go without a sliding window. With the distilled model it will take only 5 minutes with a RTX 4090 (you will need 22 GB of VRAM though). I have added options to select a higher number of frames if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App) - LTX Video ControlNet : it is a Control Net that allows you for instance to transfer a Human motion or Depth from a control video. It is not as powerful as Vace but can produce interesting things especially as now you can generate quickly a 1 min video. Under the scene IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them. - LTX IC-Lora support: these are special Loras that consumes a conditional image or video Beside the pose, depth and canny IC-Loras transparently loaded there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as control net choice to use it. -And Also: -- easier way to select video resolution -- started to optimize Matanyone to reduce VRAM requirements +- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update) +- Easier way to select video resolution ### July 15 2025: WanGP v7.0 is an AI Powered Photoshop This release turns the Wan models into Image Generators.
This goes way more than allowing to generate a video made of single frame : diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py index 808369ff6..d16888122 100644 --- a/models/flux/flux_handler.py +++ b/models/flux/flux_handler.py @@ -107,7 +107,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder ] @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None): from .flux_main import model_factory flux_model = model_factory( diff --git a/models/flux/sampling.py b/models/flux/sampling.py index 1b4813a5c..92ee59099 100644 --- a/models/flux/sampling.py +++ b/models/flux/sampling.py @@ -203,7 +203,7 @@ def prepare_kontext( image_mask_latents = convert_image_to_tensor(img_mask.resize((target_width // 16, target_height // 16), resample=Image.Resampling.LANCZOS)) image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1] image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0) - convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") + # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device) return_dict.update({ "img_msk_latents": image_mask_latents, diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py index 67e9e99da..435234102 100644 --- a/models/hyvideo/hunyuan_handler.py +++ b/models/hyvideo/hunyuan_handler.py @@ -68,7 +68,13 @@ def query_model_def(base_model_type, model_def): "visible": False, } - if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True + if base_model_type in ["hunyuan_avatar"]: + extra_model_def["image_ref_choices"] = { + "choices": [("Start Image", "KI")], + "letters_filter":"KI", + "visible": False, + } + extra_model_def["no_background_removal"] = True if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]: extra_model_def["one_image_ref_needed"] = True @@ -123,7 +129,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder } @staticmethod - def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None): from .hunyuan import HunyuanVideoSampler from mmgp import offload diff --git a/models/ltx_video/ltxv.py b/models/ltx_video/ltxv.py index db143fc6d..080860c14 100644 --- a/models/ltx_video/ltxv.py +++ b/models/ltx_video/ltxv.py @@ -476,14 +476,14 @@ def generate( images = images.sub_(0.5).mul_(2).squeeze(0) return images - def 
get_loras_transformer(self, get_model_recursive_prop, video_prompt_type, **kwargs): + def get_loras_transformer(self, get_model_recursive_prop, model_type, video_prompt_type, **kwargs): map = { "P" : "pose", "D" : "depth", "E" : "canny", } loras = [] - preloadURLs = get_model_recursive_prop(self.model_type, "preload_URLs") + preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") lora_file_name = "" for letter, signature in map.items(): if letter in video_prompt_type: diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py index e44c98399..2845fdb0b 100644 --- a/models/ltx_video/ltxv_handler.py +++ b/models/ltx_video/ltxv_handler.py @@ -74,7 +74,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None): from .ltxv import LTXV ltxv_model = LTXV( diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py index 0897ee419..be982aabc 100644 --- a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -569,6 +569,8 @@ def __call__( pipeline=None, loras_slists=None, joint_pass= True, + lora_inpaint = False, + outpainting_dims = None, ): r""" Function invoked when calling the pipeline for generation. @@ -704,7 +706,7 @@ def __call__( image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) if (image_width,image_height) != image.size: image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS) - else: + elif not lora_inpaint: # _, image_width, image_height = min( # (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS # ) @@ -721,8 +723,16 @@ def __call__( if image.size != (image_width, image_height): image = image.resize((image_width, image_height), resample=Image.Resampling.LANCZOS) + image = convert_image_to_tensor(image) + if lora_inpaint: + image_mask_rebuilt = torch.where(convert_image_to_tensor(image_mask)>-0.5, 1., 0. 
)[0:1] + image_mask_latents = None + green = torch.tensor([-1.0, 1.0, -1.0]).to(image) + green_image = green[:, None, None] .expand_as(image) + image = torch.where(image_mask_rebuilt > 0, green_image, image) + prompt_image = convert_tensor_to_image(image) + image = image.unsqueeze(0).unsqueeze(2) # image.save("nnn.png") - image = convert_image_to_tensor(image).unsqueeze(0).unsqueeze(2) has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None @@ -940,7 +950,7 @@ def __call__( ) latents = latents / latents_std + latents_mean output_image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] - if image_mask is not None: + if image_mask is not None and not lora_inpaint : #not (lora_inpaint and outpainting_dims is not None): output_image = image.squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(image) * image_mask_rebuilt diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py index 80a909f84..010298ead 100644 --- a/models/qwen/qwen_handler.py +++ b/models/qwen/qwen_handler.py @@ -1,4 +1,6 @@ import torch +import gradio as gr + def get_qwen_text_encoder_filename(text_encoder_quantization): text_encoder_filename = "ckpts/Qwen2.5-VL-7B-Instruct/Qwen2.5-VL-7B-Instruct_bf16.safetensors" @@ -29,6 +31,16 @@ def query_model_def(base_model_type, model_def): "letters_filter": "KI", } extra_model_def["background_removal_label"]= "Remove Backgrounds only behind People / Objects except main Subject / Landscape" + extra_model_def["video_guide_outpainting"] = [2] + extra_model_def["model_modes"] = { + "choices": [ + ("Lora Inpainting: Inpainted area completely unrelated to occulted content", 1), + ("Masked Denoising : Inpainted area may reuse some content that has been occulted", 0), + ], + "default": 1, + "label" : "Inpainting Method", + "image_modes" : [2], + } return extra_model_def @@ -58,7 +70,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder } @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None): from .qwen_main import model_factory from mmgp import offload @@ -99,5 +111,18 @@ def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ "video_prompt_type": "KI", "denoising_strength" : 1., + "model_mode" : 0, }) + def validate_generative_settings(base_model_type, model_def, inputs): + if base_model_type in ["qwen_image_edit_20B"]: + model_mode = inputs["model_mode"] + denoising_strength= inputs["denoising_strength"] + video_guide_outpainting= inputs["video_guide_outpainting"] + from wgp import get_outpainting_dims + outpainting_dims = get_outpainting_dims(video_guide_outpainting) + + if denoising_strength < 1 and model_mode == 1: + gr.Info("Denoising Strength will be ignored while using Lora Inpainting") + if outpainting_dims is not None and model_mode == 0 : + return "Outpainting is not supported with Masked Denoising " diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py index e4c19eddb..da84e0d0e 100644 --- a/models/qwen/qwen_main.py +++ 
b/models/qwen/qwen_main.py @@ -44,7 +44,7 @@ def __init__( save_quantized = False, dtype = torch.bfloat16, VAE_dtype = torch.float32, - mixed_precision_transformer = False + mixed_precision_transformer = False, ): @@ -117,6 +117,8 @@ def generate( joint_pass = True, sample_solver='default', denoising_strength = 1., + model_mode = 0, + outpainting_dims = None, **bbargs ): # Generate with different aspect ratios @@ -205,8 +207,16 @@ def generate( loras_slists=loras_slists, joint_pass = joint_pass, denoising_strength=denoising_strength, - generator=torch.Generator(device="cuda").manual_seed(seed) + generator=torch.Generator(device="cuda").manual_seed(seed), + lora_inpaint = image_mask is not None and model_mode == 1, + outpainting_dims = outpainting_dims, ) if image is None: return None return image.transpose(0, 1) + def get_loras_transformer(self, get_model_recursive_prop, model_type, model_mode, **kwargs): + if model_mode == 0: return [], [] + preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") + return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1] + + diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 40ed42fa5..dde1a6535 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -64,6 +64,7 @@ def __init__( config, checkpoint_dir, model_filename = None, + submodel_no_list = None, model_type = None, model_def = None, base_model_type = None, @@ -126,50 +127,65 @@ def __init__( forcedConfigPath = base_config_file if len(model_filename) > 1 else None # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json" # model_filename[1] = xmodel_filename - + self.model = self.model2 = None source = model_def.get("source", None) + source2 = model_def.get("source2", None) module_source = model_def.get("module_source", None) + module_source2 = model_def.get("module_source2", None) if module_source is not None: - model_filename = [] + model_filename - model_filename[1] = module_source - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) - elif source is not None: + self.model = offload.fast_load_transformers_model(model_filename[:1] + [module_source], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + if module_source2 is not None: + self.model2 = offload.fast_load_transformers_model(model_filename[1:2] + [module_source2], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + if source is not None: self.model = offload.fast_load_transformers_model(source, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file) - elif self.transformer_switch: - shared_modules= {} - self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules) - self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, 
defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) - shared_modules = None + if source2 is not None: + self.model2 = offload.fast_load_transformers_model(source2, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file) + + if self.model is not None or self.model2 is not None: + from wgp import save_model + from mmgp.safetensors2 import torch_load_file else: - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + if self.transformer_switch: + if 0 in submodel_no_list[2:] and 1 in submodel_no_list: + raise Exception("Shared and non shared modules at the same time across multipe models is not supported") + + if 0 in submodel_no_list[2:]: + shared_modules= {} + self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + shared_modules = None + else: + modules_for_1 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==1 ] + modules_for_2 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==2 ] + self.model = offload.fast_load_transformers_model(model_filename[:1], modules = modules_for_1, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = modules_for_2, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + + else: + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) - # self.model = offload.load_model_data(self.model, xmodel_filename ) - # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth") - self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) - offload.change_dtype(self.model, dtype, True) + if self.model is not None: + self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) + offload.change_dtype(self.model, dtype, True) + self.model.eval().requires_grad_(False) if self.model2 is not None: self.model2.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) offload.change_dtype(self.model2, dtype, True) - - # offload.save_model(self.model, "wan2.1_text2video_1.3B_mbf16.safetensors", do_quantize= False, config_file_path=base_config_file, filter_sd=sd) - # offload.save_model(self.model, "wan2.2_image2video_14B_low_mbf16.safetensors", config_file_path=base_config_file) - # offload.save_model(self.model, 
"wan2.2_image2video_14B_low_quanto_mbf16_int8.safetensors", do_quantize=True, config_file_path=base_config_file) - self.model.eval().requires_grad_(False) - if self.model2 is not None: self.model2.eval().requires_grad_(False) + if module_source is not None: - from wgp import save_model - from mmgp.safetensors2 import torch_load_file - filter = list(torch_load_file(module_source)) - save_model(self.model, model_type, dtype, None, is_module=True, filter=filter) - elif not source is None: - from wgp import save_model - save_model(self.model, model_type, dtype, None) + save_model(self.model, model_type, dtype, None, is_module=True, filter=list(torch_load_file(module_source)), module_source_no=1) + if module_source2 is not None: + save_model(self.model2, model_type, dtype, None, is_module=True, filter=list(torch_load_file(module_source2)), module_source_no=2) + if not source is None: + save_model(self.model, model_type, dtype, None, submodel_no= 1) + if not source2 is None: + save_model(self.model2, model_type, dtype, None, submodel_no= 2) if save_quantized: from wgp import save_quantized_model - save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) + if self.model is not None: + save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) if self.model2 is not None: save_quantized_model(self.model2, model_type, model_filename[1], dtype, base_config_file, submodel_no=2) self.sample_neg_prompt = config.sample_neg_prompt @@ -307,7 +323,7 @@ def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, full_ canvas = canvas.to(device) return ref_img.to(device), canvas - def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], start_frame = 0, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False): + def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False): image_sizes = [] trim_video_guide = len(keep_video_guide_frames) def conv_tensor(t, device): @@ -659,13 +675,15 @@ def generate(self, inject_from_start = False if input_frames != None and denoising_strength < 1 : color_reference_frame = input_frames[:, -1:].clone() - if overlapped_latents != None: - overlapped_latents_frames_num = overlapped_latents.shape[2] - overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1 + if prefix_frames_count > 0: + overlapped_frames_num = prefix_frames_count + overlapped_latents_frames_num = (overlapped_latents_frames_num -1 // 4) + 1 + # overlapped_latents_frames_num = overlapped_latents.shape[2] + # overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1 else: overlapped_latents_frames_num = overlapped_frames_num = 0 if len(keep_frames_parsed) == 0 or image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = [] - injection_denoising_step = int(sampling_steps * (1. - denoising_strength) ) + injection_denoising_step = int( round(sampling_steps * (1. 
- denoising_strength),4) ) latent_keep_frames = [] if source_latents.shape[2] < lat_frames or len(keep_frames_parsed) > 0: inject_from_start = True diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py index 39d0a70c8..7a5d3ea2d 100644 --- a/models/wan/df_handler.py +++ b/models/wan/df_handler.py @@ -78,7 +78,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder return family_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization) @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None): from .configs import WAN_CONFIGS from .wan_handler import family_handler cfg = WAN_CONFIGS['t2v-14B'] diff --git a/models/wan/multitalk/multitalk.py b/models/wan/multitalk/multitalk.py index b04f65f32..52d2dd995 100644 --- a/models/wan/multitalk/multitalk.py +++ b/models/wan/multitalk/multitalk.py @@ -214,18 +214,20 @@ def process_tts_multi(text, save_dir, voice1, voice2): return s1, s2, save_path_sum -def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0): +def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0, return_sum_only = False): wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base") # wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec") pad = int(padded_frames_for_embeddings/ fps * sr) new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad, min_audio_duration = min_audio_duration ) - audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) - audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) - full_audio_embs = [] - if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) - # if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) - if audio_guide2 != None: full_audio_embs.append(audio_embedding_2) - if audio_guide2 == None and not duration_changed: sum_human_speechs = None + if return_sum_only: + full_audio_embs = None + else: + audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) + audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) + full_audio_embs = [] + if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) + if audio_guide2 != None: full_audio_embs.append(audio_embedding_2) + if audio_guide2 == None and not duration_changed: sum_human_speechs = None return full_audio_embs, sum_human_speechs diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 253bd9236..6586176bd 100644 --- a/models/wan/wan_handler.py 
+++ b/models/wan/wan_handler.py @@ -166,7 +166,8 @@ def query_model_def(base_model_type, model_def): extra_model_def["lock_image_refs_ratios"] = True extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames" - + extra_model_def["video_guide_outpainting"] = [0,1] + if base_model_type in ["standin"]: extra_model_def["lock_image_refs_ratios"] = True extra_model_def["image_ref_choices"] = { @@ -293,7 +294,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None): from .configs import WAN_CONFIGS if test_class_i2v(base_model_type): @@ -306,6 +307,7 @@ def load_model(model_filename, model_type, base_model_type, model_def, quantizeT config=cfg, checkpoint_dir="ckpts", model_filename=model_filename, + submodel_no_list = submodel_no_list, model_type = model_type, model_def = model_def, base_model_type=base_model_type, @@ -381,7 +383,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults): if base_model_type in ["fantasy"]: ui_defaults.update({ "audio_guidance_scale": 5.0, - "sliding_window_size": 1, + "sliding_window_overlap" : 1, }) elif base_model_type in ["multitalk"]: @@ -398,6 +400,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "guidance_scale": 5.0, "flow_shift": 7, # 11 for 720p "sliding_window_overlap" : 9, + "sliding_window_size": 81, "sample_solver" : "euler", "video_prompt_type": "QKI", "remove_background_images_ref" : 0, diff --git a/preprocessing/extract_vocals.py b/preprocessing/extract_vocals.py new file mode 100644 index 000000000..656402644 --- /dev/null +++ b/preprocessing/extract_vocals.py @@ -0,0 +1,69 @@ +from pathlib import Path +import os, tempfile +import numpy as np +import soundfile as sf +import librosa +import torch +import gc + +from audio_separator.separator import Separator + +def get_vocals(src_path: str, dst_path: str, min_seconds: float = 8) -> str: + """ + If the source audio is shorter than `min_seconds`, pad with trailing silence + in a temporary file, then run separation and save only the vocals to dst_path. + Returns the full path to the vocals file. 
+ """ + + default_device = torch.get_default_device() + torch.set_default_device('cpu') + + dst = Path(dst_path) + dst.parent.mkdir(parents=True, exist_ok=True) + + # Quick duration check + duration = librosa.get_duration(path=src_path) + + use_path = src_path + temp_path = None + try: + if duration < min_seconds: + # Load (resample) and pad in memory + y, sr = librosa.load(src_path, sr=None, mono=False) + if y.ndim == 1: # ensure shape (channels, samples) + y = y[np.newaxis, :] + target_len = int(min_seconds * sr) + pad = max(0, target_len - y.shape[1]) + if pad: + y = np.pad(y, ((0, 0), (0, pad)), mode="constant") + + # Write a temp WAV for the separator + fd, temp_path = tempfile.mkstemp(suffix=".wav") + os.close(fd) + sf.write(temp_path, y.T, sr) # soundfile expects (frames, channels) + use_path = temp_path + + # Run separation: emit only the vocals, with your exact filename + sep = Separator( + output_dir=str(dst.parent), + output_format=(dst.suffix.lstrip(".") or "wav"), + output_single_stem="Vocals", + model_file_dir="ckpts/roformer/" #model_bs_roformer_ep_317_sdr_12.9755.ckpt" + ) + sep.load_model() + out_files = sep.separate(use_path, {"Vocals": dst.stem}) + + out = Path(out_files[0]) + return str(out if out.is_absolute() else (dst.parent / out)) + finally: + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + torch.cuda.empty_cache() + gc.collect() + torch.set_default_device(default_device) + +# Example: +# final = extract_vocals("in/clip.mp3", "out/vocals.wav") +# print(final) + diff --git a/preprocessing/speakers_separator.py b/preprocessing/speakers_separator.py index 79cde9b1e..d5f256335 100644 --- a/preprocessing/speakers_separator.py +++ b/preprocessing/speakers_separator.py @@ -100,7 +100,7 @@ def __init__(self, hf_token: str = None, local_model_path: str = None, self.hf_token = hf_token self._overlap_pipeline = None - def separate_audio(self, audio_path: str, output1, output2 ) -> Dict[str, str]: + def separate_audio(self, audio_path: str, output1, output2, audio_original_path: str = None ) -> Dict[str, str]: """Optimized main separation function with memory management.""" xprint("Starting optimized audio separation...") self._current_audio_path = os.path.abspath(audio_path) @@ -128,7 +128,11 @@ def separate_audio(self, audio_path: str, output1, output2 ) -> Dict[str, str]: gc.collect() # Save outputs efficiently - output_paths = self._save_outputs_optimized(waveform, final_masks, sample_rate, audio_path, output1, output2) + if audio_original_path is None: + waveform_original = waveform + else: + waveform_original, sample_rate = self.load_audio(audio_original_path) + output_paths = self._save_outputs_optimized(waveform_original, final_masks, sample_rate, audio_path, output1, output2) return output_paths @@ -835,7 +839,7 @@ def print_summary(self, audio_path: str): for turn, _, speaker in diarization.itertracks(yield_label=True): xprint(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s") -def extract_dual_audio(audio, output1, output2, verbose = False): +def extract_dual_audio(audio, output1, output2, verbose = False, audio_original = None): global verbose_output verbose_output = verbose separator = OptimizedPyannote31SpeakerSeparator( @@ -848,7 +852,7 @@ def extract_dual_audio(audio, output1, output2, verbose = False): import time start_time = time.time() - outputs = separator.separate_audio(audio, output1, output2) + outputs = separator.separate_audio(audio, output1, output2, audio_original) elapsed_time = time.time() - start_time xprint(f"\n=== SUCCESS 
(completed in {elapsed_time:.2f}s) ===") diff --git a/requirements.txt b/requirements.txt index 1f008c75f..e6c85fda7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,14 +21,15 @@ mutagen pyloudnorm librosa==0.11.0 speechbrain==1.0.3 - +audio-separator==0.36.1 + # UI & interaction gradio==5.29.0 dashscope loguru # Vision & segmentation -opencv-python>=4.9.0.80 +opencv-python>=4.12.0.88 segment-anything rembg[gpu]==2.0.65 onnxruntime-gpu diff --git a/wgp.py b/wgp.py index d085be76c..de8e2f66f 100644 --- a/wgp.py +++ b/wgp.py @@ -23,6 +23,7 @@ from shared.utils import notification_sound from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask +from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image from shared.utils.audio_video import save_image_metadata, read_image_metadata from shared.match_archi import match_nvidia_architecture @@ -61,8 +62,8 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.55" -settings_version = 2.34 +WanGP_version = "8.6" +settings_version = 2.35 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -199,9 +200,11 @@ def clean_image_list(gradio_list): def process_prompt_and_add_tasks(state, model_choice): - + def ret(): + return gr.update(), gr.update() + if state.get("validate_success",0) != 1: - return + ret() state["validate_success"] = 0 model_filename = state["model_filename"] @@ -218,7 +221,7 @@ def process_prompt_and_add_tasks(state, model_choice): if inputs == None: gr.Warning("Internal state error: Could not retrieve inputs for the model.") queue = gen.get("queue", []) - return get_queue_table(queue) + return ret() model_def = get_model_def(model_type) model_handler = get_model_handler(model_type) image_outputs = inputs["image_mode"] > 0 @@ -247,7 +250,7 @@ def process_prompt_and_add_tasks(state, model_choice): if len(temporal_upsampling) >0: prompt += ["Temporal Upsampling"] if has_image_file_extension(edit_video_source) and len(temporal_upsampling) > 0: gr.Info("Temporal Upsampling can not be used with an Image") - return + return ret() film_grain_intensity = inputs.get("film_grain_intensity",0) film_grain_saturation = inputs.get("film_grain_saturation",0.5) # if film_grain_intensity >0: prompt += [f"Film Grain: intensity={film_grain_intensity}, saturation={film_grain_saturation}"] @@ -263,7 +266,7 @@ def process_prompt_and_add_tasks(state, model_choice): else: if audio_source is None: gr.Info("You must provide a custom Audio") - return + return ret() prompt += ["Custom Audio"] repeat_generation == 1 @@ -273,32 +276,32 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("You must choose at least one Remux Method") else: gr.Info("You must choose at least one Post Processing Method") - return + return ret() inputs["prompt"] = ", ".join(prompt) add_video_task(**inputs) 
gen["prompts_max"] = 1 + gen.get("prompts_max",0) state["validate_success"] = 1 queue= gen.get("queue", []) - return update_queue_data(queue) + return ret() if hasattr(model_handler, "validate_generative_settings"): error = model_handler.validate_generative_settings(model_type, model_def, inputs) if error is not None and len(error) > 0: gr.Info(error) - return + return ret() if inputs.get("cfg_star_switch", 0) != 0 and inputs.get("apg_switch", 0) != 0: gr.Info("Adaptive Progressive Guidance and Classifier Free Guidance Star can not be set at the same time") - return + return ret() prompt = inputs["prompt"] if len(prompt) ==0: gr.Info("Prompt cannot be empty.") gen = get_gen_info(state) queue = gen.get("queue", []) - return get_queue_table(queue) + return ret() prompt, errors = prompt_parser.process_template(prompt) if len(errors) > 0: gr.Info("Error processing prompt template: " + errors) - return + return ret() model_filename = get_model_filename(model_type) prompts = prompt.replace("\r", "").split("\n") prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")] @@ -306,7 +309,7 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("Prompt cannot be empty.") gen = get_gen_info(state) queue = gen.get("queue", []) - return get_queue_table(queue) + return ret() resolution = inputs["resolution"] width, height = resolution.split("x") @@ -360,22 +363,22 @@ def process_prompt_and_add_tasks(state, model_choice): _, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases= guidance_phases) if len(errors) > 0: gr.Info(f"Error parsing Loras Multipliers: {errors}") - return + return ret() if guidance_phases == 3: if switch_threshold < switch_threshold2: gr.Info(f"Phase 1-2 Switch Noise Level ({switch_threshold}) should be Greater than Phase 2-3 Switch Noise Level ({switch_threshold2}). 
As a reminder, noise will gradually go down from 1000 to 0.") - return + return ret() else: model_switch_phase = 1 if not any_steps_skipping: skip_steps_cache_type = "" if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20: gr.Info("The minimum number of steps should be 20") - return + return ret() if skip_steps_cache_type == "mag": if num_inference_steps > 50: gr.Info("Mag Cache maximum number of steps is 50") - return + return ret() if image_mode > 0: audio_prompt_type = "" @@ -385,7 +388,7 @@ def process_prompt_and_add_tasks(state, model_choice): speakers_bboxes, error = parse_speakers_locations(speakers_locations) if len(error) > 0: gr.Info(error) - return + return ret() if MMAudio_setting != 0 and server_config.get("mmaudio_enabled", 0) != 0 and video_length <16: #should depend on the architecture gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long") @@ -393,23 +396,24 @@ def process_prompt_and_add_tasks(state, model_choice): if len(frames_positions.strip()) > 0: positions = frames_positions.split(" ") for pos_str in positions: - if not is_integer(pos_str): - gr.Info(f"Invalid Frame Position '{pos_str}'") - return - pos = int(pos_str) - if pos <1 or pos > max_source_video_frames: - gr.Info(f"Invalid Frame Position Value'{pos_str}'") - return + if not pos_str in ["L", "l"] and len(pos_str)>0: + if not is_integer(pos_str): + gr.Info(f"Invalid Frame Position '{pos_str}'") + return ret() + pos = int(pos_str) + if pos <1 or pos > max_source_video_frames: + gr.Info(f"Invalid Frame Position Value'{pos_str}'") + return ret() else: frames_positions = None if audio_source is not None and MMAudio_setting != 0: gr.Info("MMAudio and Custom Audio Soundtrack can't not be used at the same time") - return + return ret() if len(filter_letters(image_prompt_type, "VLG")) > 0 and len(keep_frames_video_source) > 0: if not is_integer(keep_frames_video_source) or int(keep_frames_video_source) == 0: gr.Info("The number of frames to keep must be a non null integer") - return + return ret() else: keep_frames_video_source = "" @@ -419,18 +423,18 @@ def process_prompt_and_add_tasks(state, model_choice): if "V" in image_prompt_type: if video_source == None: gr.Info("You must provide a Source Video file to continue") - return + return ret() else: video_source = None if "A" in audio_prompt_type: if audio_guide == None: gr.Info("You must provide an Audio Source") - return + return ret() if "B" in audio_prompt_type: if audio_guide2 == None: gr.Info("You must provide a second Audio Source") - return + return ret() else: audio_guide2 = None else: @@ -444,23 +448,23 @@ def process_prompt_and_add_tasks(state, model_choice): if model_def.get("one_image_ref_needed", False): if image_refs == None : gr.Info("You must provide an Image Reference") - return + return ret() if len(image_refs) > 1: gr.Info("Only one Image Reference (a person) is supported for the moment by this model") - return + return ret() if model_def.get("at_least_one_image_ref_needed", False): if image_refs == None : gr.Info("You must provide at least one Image Reference") - return + return ret() if "I" in video_prompt_type: if image_refs == None or len(image_refs) == 0: gr.Info("You must provide at least one Reference Image") - return + return ret() image_refs = clean_image_list(image_refs) if image_refs == None : gr.Info("A Reference Image should be an Image") - return + return ret() else: image_refs = None @@ -468,35 +472,36 @@ def process_prompt_and_add_tasks(state, 
model_choice): if image_outputs: if image_guide is None: gr.Info("You must provide a Control Image") - return + return ret() else: if video_guide is None: gr.Info("You must provide a Control Video") - return + return ret() if "A" in video_prompt_type and not "U" in video_prompt_type: if image_outputs: if image_mask is None: gr.Info("You must provide a Image Mask") - return + return ret() else: if video_mask is None: gr.Info("You must provide a Video Mask") - return + return ret() else: video_mask = None image_mask = None if "G" in video_prompt_type: - gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(num_inference_steps * (1. - denoising_strength))} ") + if denoising_strength < 1.: + gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(round(num_inference_steps * (1. - denoising_strength),4))} ") else: denoising_strength = 1.0 if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]: gr.Info("Keep Frames for Control Video is not supported with LTX Video") - return + return ret() _, error = parse_keep_frames_video_guide(keep_frames_video_guide, video_length) if len(error) > 0: gr.Info(f"Invalid Keep Frames property: {error}") - return + return ret() else: video_guide = None image_guide = None @@ -516,14 +521,14 @@ def process_prompt_and_add_tasks(state, model_choice): if "S" in image_prompt_type: if image_start == None or isinstance(image_start, list) and len(image_start) == 0: gr.Info("You must provide a Start Image") - return + return ret() image_start = clean_image_list(image_start) if image_start == None : gr.Info("Start Image should be an Image") - return + return ret() if multi_prompts_gen_type == 1 and len(image_start) > 1: gr.Info("Only one Start Image is supported") - return + return ret() else: image_start = None @@ -532,19 +537,19 @@ def process_prompt_and_add_tasks(state, model_choice): if "E" in image_prompt_type: if image_end == None or isinstance(image_end, list) and len(image_end) == 0: gr.Info("You must provide an End Image") - return + return ret() image_end = clean_image_list(image_end) if image_end == None : gr.Info("End Image should be an Image") - return + return ret() if multi_prompts_gen_type == 0: if video_source is not None: if len(image_end)> 1: gr.Info("If a Video is to be continued and the option 'Each Text Prompt Will create a new generated Video' is set, there can be only one End Image") - return + return ret() elif len(image_start or []) != len(image_end or []): gr.Info("The number of Start and End Images should be the same when the option 'Each Text Prompt Will create a new generated Video'") - return + return ret() else: image_end = None @@ -553,7 +558,7 @@ def process_prompt_and_add_tasks(state, model_choice): if video_length > sliding_window_size: if model_type in ["t2v"] and not "G" in video_prompt_type : gr.Info(f"You have requested to Generate Sliding Windows with a Text to Video model. 
Unless you use the Video to Video feature this is useless as a t2v model doesn't see past frames and it will generate the same video in each new window.") - return + return ret() full_video_length = video_length if video_source is None else video_length + sliding_window_overlap -1 extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation" no_windows = compute_sliding_window_no(full_video_length, sliding_window_size, sliding_window_discard_last_frames, sliding_window_overlap) @@ -561,22 +566,22 @@ def process_prompt_and_add_tasks(state, model_choice): if "recam" in model_filename: if video_guide == None: gr.Info("You must provide a Control Video") - return + return ret() computed_fps = get_computed_fps(force_fps, model_type , video_guide, video_source ) frames = get_resampled_video(video_guide, 0, 81, computed_fps) if len(frames)<81: gr.Info(f"Recammaster Control video should be at least 81 frames once the resampling at {computed_fps} fps has been done") - return + return ret() if "hunyuan_custom_custom_edit" in model_filename: if len(keep_frames_video_guide) > 0: gr.Info("Filtering Frames with this model is not supported") - return + return ret() if inputs["multi_prompts_gen_type"] != 0: if image_start != None and len(image_start) > 1: gr.Info("Only one Start Image must be provided if multiple prompts are used for different windows") - return + return ret() # if image_end != None and len(image_end) > 1: # gr.Info("Only one End Image must be provided if multiple prompts are used for different windows") @@ -624,7 +629,7 @@ def process_prompt_and_add_tasks(state, model_choice): if len(prompts) >= len(image_start): if len(prompts) % len(image_start) != 0: gr.Info("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") - return + return ret() rep = len(prompts) // len(image_start) new_image_start = [] new_image_end = [] @@ -638,7 +643,7 @@ def process_prompt_and_add_tasks(state, model_choice): else: if len(image_start) % len(prompts) !=0: gr.Info("If there are more input images than text prompts the number of images should be dividable by the number of text prompts") - return + return ret() rep = len(image_start) // len(prompts) new_prompts = [] for i, _ in enumerate(image_start): @@ -666,11 +671,13 @@ def process_prompt_and_add_tasks(state, model_choice): override_inputs["prompt"] = "\n".join(prompts) inputs.update(override_inputs) add_video_task(**inputs) - - gen["prompts_max"] = new_prompts_count + gen.get("prompts_max",0) + new_prompts_count += gen.get("prompts_max",0) + gen["prompts_max"] = new_prompts_count state["validate_success"] = 1 queue= gen.get("queue", []) - return update_queue_data(queue) + first_time_in_queue = state.get("first_time_in_queue", True) + state["first_time_in_queue"] = True + return update_queue_data(queue, first_time_in_queue), gr.update(open=True) if new_prompts_count > 1 else gr.update() def get_preview_images(inputs): inputs_to_query = ["image_start", "video_source", "image_end", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ] @@ -724,7 +731,6 @@ def add_video_task(**inputs): "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data] if start_image_data != None else None, "end_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in end_image_data] if end_image_data != None else None }) - return update_queue_data(queue) def 
update_task_thumbnails(task, inputs): start_image_data, end_image_data, start_labels, end_labels = get_preview_images(inputs) @@ -1341,16 +1347,12 @@ def get_queue_table(queue): "✖" ]) return data -def update_queue_data(queue): +def update_queue_data(queue, first_time_in_queue =False): update_global_queue_ref(queue) data = get_queue_table(queue) - if len(data) == 0: - return gr.DataFrame(value=[], max_height=1) - elif len(data) == 1: - return gr.DataFrame(value=data, max_height= 83) - else: - return gr.DataFrame(value=data, max_height= 1000) + return gr.DataFrame(value=data) + def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True): bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom" @@ -2004,8 +2006,10 @@ def get_model_recursive_prop(model_type, prop = "URLs", sub_prop_name = None, re raise Exception(f"Unknown model type '{model_type}'") -def get_model_filename(model_type, quantization ="int8", dtype_policy = "", module_type = None, submodel_no = 1, stack=[]): - if module_type is not None: +def get_model_filename(model_type, quantization ="int8", dtype_policy = "", module_type = None, submodel_no = 1, URLs = None, stack=[]): + if URLs is not None: + pass + elif module_type is not None: base_model_type = get_base_model_type(model_type) # model_type_handler = model_types_handlers[base_model_type] # modules_files = model_type_handler.query_modules_files() if hasattr(model_type_handler, "query_modules_files") else {} @@ -2032,7 +2036,8 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", modu if len(stack) > 10: raise Exception(f"Circular Reference in Model {key_name} dependencies: {stack}") return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, submodel_no = submodel_no, stack = stack + [URLs]) - choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ] + # choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ] + choices = URLs if len(quantization) == 0: quantization = "bf16" @@ -2042,13 +2047,13 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", modu raw_filename = choices[0] else: if quantization in ("int8", "fp8"): - sub_choices = [ name for name in choices if quantization in name or quantization.upper() in name] + sub_choices = [ name for name in choices if quantization in os.path.basename(name) or quantization.upper() in os.path.basename(name)] else: - sub_choices = [ name for name in choices if "quanto" not in name] + sub_choices = [ name for name in choices if "quanto" not in os.path.basename(name)] if len(sub_choices) > 0: dtype_str = "fp16" if dtype == torch.float16 else "bf16" - new_sub_choices = [ name for name in sub_choices if dtype_str in name or dtype_str.upper() in name] + new_sub_choices = [ name for name in sub_choices if dtype_str in os.path.basename(name) or dtype_str.upper() in os.path.basename(name)] sub_choices = new_sub_choices if len(new_sub_choices) > 0 else sub_choices raw_filename = sub_choices[0] else: @@ -2107,6 +2112,10 @@ def fix_settings(model_type, ui_defaults): audio_prompt_type ="A" ui_defaults["audio_prompt_type"] = audio_prompt_type + if settings_version < 2.35 and any_audio_track(base_model_type): + audio_prompt_type = audio_prompt_type or "" + audio_prompt_type += "V" + ui_defaults["audio_prompt_type"] = audio_prompt_type video_prompt_type = ui_defaults.get("video_prompt_type", "") @@ -2156,13 +2165,14 @@ def fix_settings(model_type, 
ui_defaults): video_prompt_type = ui_defaults.get("video_prompt_type", "") image_ref_choices_list = model_def.get("image_ref_choices", {}).get("choices", []) - if len(image_ref_choices_list)==0: - video_prompt_type = del_in_sequence(video_prompt_type, "IK") - else: - first_choice = image_ref_choices_list[0][1] - if "I" in first_choice and not "I" in video_prompt_type: video_prompt_type += "I" - if len(image_ref_choices_list)==1 and "K" in first_choice and not "K" in video_prompt_type: video_prompt_type += "K" - ui_defaults["video_prompt_type"] = video_prompt_type + if model_def.get("guide_custom_choices", None) is None: + if len(image_ref_choices_list)==0: + video_prompt_type = del_in_sequence(video_prompt_type, "IK") + else: + first_choice = image_ref_choices_list[0][1] + if "I" in first_choice and not "I" in video_prompt_type: video_prompt_type += "I" + if len(image_ref_choices_list)==1 and "K" in first_choice and not "K" in video_prompt_type: video_prompt_type += "K" + ui_defaults["video_prompt_type"] = video_prompt_type model_handler = get_model_handler(base_model_type) if hasattr(model_handler, "fix_settings"): @@ -2200,7 +2210,8 @@ def get_default_prompt(i2v): "slg_switch": 0, "slg_layers": [9], "slg_start_perc": 10, - "slg_end_perc": 90 + "slg_end_perc": 90, + "audio_prompt_type": "V", } model_handler = get_model_handler(model_type) model_handler.update_default_settings(base_model_type, model_def, ui_defaults) @@ -2348,7 +2359,7 @@ def init_model_def(model_type, model_def): lock_ui_compile = True -def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_module = False, filter = None, no_fp16_main_model = True ): +def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_module = False, filter = None, no_fp16_main_model = True, module_source_no = 1): model_def = get_model_def(model_type) # To save module and quantized modules # 1) set Transformer Model Quantization Type to 16 bits @@ -2359,10 +2370,10 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod if model_def == None: return if is_module: url_key = "modules" - source_key = "module_source" + source_key = "module_source" if module_source_no <=1 else "module_source2" else: url_key = "URLs" if submodel_no <=1 else "URLs" + str(submodel_no) - source_key = "source" + source_key = "source" if submodel_no <=1 else "source2" URLs= model_def.get(url_key, None) if URLs is None: return if isinstance(URLs, str): @@ -2377,6 +2388,9 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod print("Target Module files are missing") return URLs= URLs[0] + if isinstance(URLs, dict): + url_dict_key = "URLs" if module_source_no ==1 else "URLs2" + URLs = URLs[url_dict_key] for url in URLs: if "quanto" not in url and dtypestr in url: model_filename = os.path.basename(url) @@ -2410,8 +2424,12 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod elif not os.path.isfile(quanto_filename): offload.save_model(model, quanto_filename, config_file_path=config_file, do_quantize= True, filter_sd=filter) print(f"New quantized file '{quanto_filename}' had been created for finetune Id '{model_type}'.") - model_def[url_key][0].append(quanto_filename) - saved_finetune_def["model"][url_key][0].append(quanto_filename) + if isinstance(model_def[url_key][0],dict): + model_def[url_key][0][url_dict_key].append(quanto_filename) + saved_finetune_def["model"][url_key][0][url_dict_key].append(quanto_filename) + else: + model_def[url_key][0].append(quanto_filename) 
+ saved_finetune_def["model"][url_key][0].append(quanto_filename) update_model_def = True if update_model_def: with open(finetune_file, "w", encoding="utf-8") as writer: @@ -2466,6 +2484,14 @@ def preprocessor_wrapper(sd): return preprocessor_wrapper +def get_local_model_filename(model_filename): + if model_filename.startswith("http"): + local_model_filename = os.path.join("ckpts", os.path.basename(model_filename)) + else: + local_model_filename = model_filename + return local_model_filename + + def process_files_def(repoId, sourceFolderList, fileList): targetRoot = "ckpts/" @@ -2491,7 +2517,7 @@ def download_mmaudio(): } process_files_def(**enhancer_def) -def download_models(model_filename = None, model_type= None, module_type = None, submodel_no = 1): +def download_models(model_filename = None, model_type= None, module_type = False, submodel_no = 1): def computeList(filename): if filename == None: return [] @@ -2506,11 +2532,12 @@ def computeList(filename): shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", - "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote", "det_align", "" ], + "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "roformer", "pyannote", "det_align", "" ], "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"], ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], ["config.json", "pytorch_model.bin", "preprocessor_config.json"], + ["model_bs_roformer_ep_317_sdr_12.9755.ckpt", "model_bs_roformer_ep_317_sdr_12.9755.yaml", "download_checks.json"], ["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", "pytorch_model_segmentation-3.0.bin"], ["detface.pt"], [ "flownet.pkl" ] ] } process_files_def(**shared_def) @@ -2558,37 +2585,25 @@ def download_file(url,filename): base_model_type = get_base_model_type(model_type) model_def = get_model_def(model_type) - source = model_def.get("source", None) - module_source = model_def.get("module_source", None) + any_source = ("source2" if submodel_no ==2 else "source") in model_def + any_module_source = ("module_source2" if submodel_no ==2 else "module_source") in model_def model_type_handler = model_types_handlers[base_model_type] - - if source is not None and module_type is None or module_source is not None and module_type is not None: + local_model_filename = get_local_model_filename(model_filename) + + if any_source and not module_type or any_module_source and module_type: model_filename = None else: - if not os.path.isfile(model_filename): - if module_type is not None: - key_name = "modules" - URLs = module_type - if isinstance(module_type, str): - URLs = get_model_recursive_prop(module_type, key_name, sub_prop_name="_list", return_list= False) - else: - key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}" - URLs = get_model_recursive_prop(model_type, key_name, return_list= False) - if isinstance(URLs, str): - raise Exception("Missing model " + URLs) - use_url = model_filename - for url in URLs: - if os.path.basename(model_filename) in url: - use_url = url - break + if not os.path.isfile(local_model_filename): + url = model_filename + if not url.startswith("http"): - raise Exception(f"Model '{model_filename}' in field '{key_name}' was not found locally and no 
URL was provided to download it. Please add an URL in the model definition file.") + raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.") try: - download_file(use_url, model_filename) + download_file(url, local_model_filename) except Exception as e: - if os.path.isfile(model_filename): os.remove(model_filename) - raise Exception(f"{key_name} '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'") - + if os.path.isfile(local_model_filename): os.remove(local_model_filename) + raise Exception(f"'{url}' is invalid for Model '{local_model_filename}' : {str(e)}'") + if module_type: return model_filename = None preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True) @@ -2614,6 +2629,7 @@ def download_file(url,filename): except Exception as e: if os.path.isfile(filename): os.remove(filename) raise Exception(f"Lora URL '{url}' is invalid: {str(e)}'") + if module_type: return model_files = model_type_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization) if not isinstance(model_files, list): model_files = [model_files] for one_repo in model_files: @@ -2800,7 +2816,8 @@ def load_models(model_type, override_profile = -1): model_filename2 = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy, submodel_no=2) # !!!! else: model_filename2 = None - modules = get_model_recursive_prop(model_type, "modules", return_list= True) + modules = get_model_recursive_prop(model_type, "modules", return_list= True) + modules = [get_model_recursive_prop(module, "modules", sub_prop_name ="_list", return_list= True) if isinstance(module, str) else module for module in modules ] if save_quantized and "quanto" in model_filename: save_quantized = False print("Need to provide a non quantized model to create a quantized model to be saved") @@ -2825,27 +2842,40 @@ def load_models(model_type, override_profile = -1): if model_filename2 != None: model_file_list += [model_filename2] model_type_list += [model_type] - module_type_list += [None] + module_type_list += [False] model_submodel_no_list += [2] for module_type in modules: - model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, module_type= module_type)) - model_type_list.append(model_type) - module_type_list.append(module_type) - model_submodel_no_list.append(0) + if isinstance(module_type,dict): + URLs1 = module_type.get("URLs", None) + if URLs1 is None: raise Exception(f"No URLs defined for Module {module_type}") + model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, URLs = URLs1)) + URLs2 = module_type.get("URLs2", None) + if URLs2 is None: raise Exception(f"No URL2s defined for Module {module_type}") + model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, URLs = URLs2)) + model_type_list += [model_type] * 2 + module_type_list += [True] * 2 + model_submodel_no_list += [1,2] + else: + model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, module_type= module_type)) + model_type_list.append(model_type) + module_type_list.append(True) + model_submodel_no_list.append(0) + local_model_file_list= [] for filename, file_model_type, file_module_type, submodel_no in zip(model_file_list, model_type_list, module_type_list, 
model_submodel_no_list): download_models(filename, file_model_type, file_module_type, submodel_no) + local_model_file_list.append( get_local_model_filename(filename) ) VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float mixed_precision_transformer = server_config.get("mixed_precision","0") == "1" transformer_type = None - for submodel_no, filename in zip(model_submodel_no_list, model_file_list): - if submodel_no>=1: + for module_type, filename in zip(module_type_list, local_model_file_list): + if module_type is None: print(f"Loading Model '{filename}' ...") else: print(f"Loading Module '{filename}' ...") wan_model, pipe = model_types_handlers[base_model_type].load_model( - model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, text_encoder_quantization = text_encoder_quantization, - dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized) + local_model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, text_encoder_quantization = text_encoder_quantization, + dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized, submodel_no_list = model_submodel_no_list, ) kwargs = {} profile = init_pipe(pipe, kwargs, override_profile) @@ -2883,7 +2913,7 @@ def generate_header(model_type, compile, attention_mode): description_container = [""] get_model_name(model_type, description_container) - model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) or "" + model_filename = os.path.basename(get_model_filename(model_type, transformer_quantization, transformer_dtype_policy)) or "" description = description_container[0] header = f"
{description}
" overridden_attention = get_overridden_attention(model_type) @@ -3517,6 +3547,48 @@ def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='t # print(f"frame nos: {frame_nos}") return frames_list +# def get_resampled_video(video_in, start_frame, max_frames, target_fps): +# from torchvision.io import VideoReader +# import torch +# from shared.utils.utils import resample + +# vr = VideoReader(video_in, "video") +# meta = vr.get_metadata()["video"] + +# fps = round(float(meta["fps"][0])) +# duration_s = float(meta["duration"][0]) +# num_src_frames = int(round(duration_s * fps)) # robust length estimate + +# if max_frames < 0: +# max_frames = max(int(num_src_frames / fps * target_fps + max_frames), 0) + +# frame_nos = resample( +# fps, num_src_frames, +# max_target_frames_count=max_frames, +# target_fps=target_fps, +# start_target_frame=start_frame +# ) +# if len(frame_nos) == 0: +# return torch.empty((0,)) # nothing to return + +# target_ts = [i / fps for i in frame_nos] + +# # Read forward once, grabbing frames when we pass each target timestamp +# frames = [] +# vr.seek(target_ts[0]) +# idx = 0 +# tol = 0.5 / fps # half-frame tolerance +# for frame in vr: +# t = float(frame["pts"]) # seconds +# if idx < len(target_ts) and t + tol >= target_ts[idx]: +# frames.append(frame["data"].permute(1,2,0)) # Tensor [H, W, C] +# idx += 1 +# if idx >= len(target_ts): +# break + +# return frames + + def get_preprocessor(process_type, inpaint_color): if process_type=="pose": from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator @@ -3605,11 +3677,29 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis return results -def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, block_size= 16, expand_scale = 2): +def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, fit_crop = False, block_size= 16, expand_scale = 2, outpainting_dims = None, inpaint_color = 127): frame_width, frame_height = input_image.size + if fit_crop: + input_image = rescale_and_crop(input_image, width, height) + if input_mask is not None: + input_mask = rescale_and_crop(input_mask, width, height) + return input_image, input_mask + + if outpainting_dims != None: + if fit_canvas != None: + frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims) + else: + frame_height, frame_width = height, width + if fit_canvas != None: height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size) + + if outpainting_dims != None: + final_height, final_width = height, width + height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) + + if fit_canvas != None or outpainting_dims != None: input_image = input_image.resize((width, height), resample=Image.Resampling.LANCZOS) if input_mask is not None: input_mask = input_mask.resize((width, height), resample=Image.Resampling.LANCZOS) @@ -3621,10 +3711,23 @@ def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canva input_mask = np.array(input_mask) input_mask = op_expand(input_mask, kernel, iterations=3) input_mask = Image.fromarray(input_mask) + + if outpainting_dims != None: + inpaint_color = inpaint_color / 127.5-1 + image = convert_image_to_tensor(input_image) + full_frame= torch.full( (image.shape[0], final_height, final_width), inpaint_color, dtype= torch.float, 
device= image.device) + full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = image + input_image = convert_tensor_to_image(full_frame) + + if input_mask is not None: + mask = convert_image_to_tensor(input_mask) + full_frame= torch.full( (mask.shape[0], final_height, final_width), 1, dtype= torch.float, device= mask.device) + full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = mask + input_mask = convert_tensor_to_image(full_frame) + return input_image, input_mask def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): - from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions def mask_to_xyxy_box(mask): rows, cols = np.where(mask == 255) @@ -4592,6 +4695,7 @@ def remove_temp_filenames(temp_filenames_list): reuse_frames = min(sliding_window_size - 4, sliding_window_overlap) else: sliding_window = False + sliding_window_size = current_video_length reuse_frames = 0 _, latent_size = get_model_min_frames_and_step(model_type) @@ -4603,10 +4707,6 @@ def remove_temp_filenames(temp_filenames_list): # Source Video or Start Image > Control Video > Image Ref (background or positioned frames only) > UI Width, Height # Image Ref (non background and non positioned frames) are boxed in a white canvas in order to keep their own width/height ratio frames_to_inject = [] - if image_refs is not None: - frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions is not None and len(frames_positions)> 0 else [] - frames_positions_list = frames_positions_list[:len(image_refs)] - nb_frames_positions = len(frames_positions_list) any_background_ref = 0 if "K" in video_prompt_type: any_background_ref = 2 if model_def.get("all_image_refs_are_background_ref", False) else 1 @@ -4642,29 +4742,47 @@ def remove_temp_filenames(temp_filenames_list): output_new_audio_data = None output_new_audio_filepath = None original_audio_guide = audio_guide + original_audio_guide2 = audio_guide2 audio_proj_split = None audio_proj_full = None audio_scale = None audio_context_lens = None if (fantasy or multitalk or hunyuan_avatar or hunyuan_custom_audio) and audio_guide != None: from models.wan.fantasytalking.infer import parse_audio + from preprocessing.extract_vocals import get_vocals import librosa duration = librosa.get_duration(path=audio_guide) combination_type = "add" + clean_audio_files = "V" in audio_prompt_type if audio_guide2 is not None: duration2 = librosa.get_duration(path=audio_guide2) if "C" in audio_prompt_type: duration += duration2 else: duration = min(duration, duration2) combination_type = "para" if "P" in audio_prompt_type else "add" + if clean_audio_files: + audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, audio_guide, "_clean", ".wav")) + audio_guide2 = get_vocals(original_audio_guide2, get_available_filename(save_path, audio_guide2, "_clean2", ".wav")) + temp_filenames_list += [audio_guide, audio_guide2] else: if "X" in audio_prompt_type: + # dual speaker, voice separation from preprocessing.speakers_separator import extract_dual_audio combination_type = "para" if args.save_speakers: audio_guide, 
audio_guide2 = "speaker1.wav", "speaker2.wav" else: audio_guide, audio_guide2 = get_available_filename(save_path, audio_guide, "_tmp1", ".wav"), get_available_filename(save_path, audio_guide, "_tmp2", ".wav") - extract_dual_audio(original_audio_guide, audio_guide, audio_guide2 ) + temp_filenames_list += [audio_guide, audio_guide2] + if clean_audio_files: + clean_audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, original_audio_guide, "_clean", ".wav")) + temp_filenames_list += [clean_audio_guide] + extract_dual_audio(clean_audio_guide if clean_audio_files else original_audio_guide, audio_guide, audio_guide2) + + elif clean_audio_files: + # Single Speaker + audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, audio_guide, "_clean", ".wav")) + temp_filenames_list += [audio_guide] + output_new_audio_filepath = original_audio_guide current_video_length = min(int(fps * duration //latent_size) * latent_size + latent_size + 1, current_video_length) @@ -4676,10 +4794,10 @@ def remove_temp_filenames(temp_filenames_list): # pad audio_proj_full if aligned to beginning of window to simulate source window overlap min_audio_duration = current_video_length/fps if reset_control_aligment else video_source_duration + current_video_length/fps audio_proj_full, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = audio_guide, audio_guide2= audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0), min_audio_duration = min_audio_duration) - if output_new_audio_data is not None: output_new_audio_filepath= None # need to build original speaker track if it changed size (due to padding at the end) or if it has been combined - if not args.save_speakers and "X" in audio_prompt_type: - os.remove(audio_guide) - os.remove(audio_guide2) + if output_new_audio_data is not None: # not none if modified + if clean_audio_files: # need to rebuild the sum of audios with original audio + _, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = original_audio_guide, audio_guide2= original_audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0), min_audio_duration = min_audio_duration, return_sum_only= True) + output_new_audio_filepath= None # need to build original speaker track if it changed size (due to padding at the end) or if it has been combined if hunyuan_custom_edit and video_guide != None: import cv2 @@ -4687,8 +4805,6 @@ def remove_temp_filenames(temp_filenames_list): length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) current_video_length = min(current_video_length, length) - if image_guide is not None: - image_guide, image_mask = preprocess_image_with_mask(image_guide, image_mask, height, width, fit_canvas = None, block_size= block_size, expand_scale = mask_expand) seed = set_seed(seed) @@ -4727,7 +4843,7 @@ def remove_temp_filenames(temp_filenames_list): if repeat_no >= total_generation: break repeat_no +=1 gen["repeat_no"] = repeat_no - src_video, src_mask, src_ref_images = None, None, None + src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = None prefix_video = pre_video_frame = None source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window source_video_frames_count = 0 # number of frames 
to use in source video (processing starts source_video_overlap_frames_count frames before ) @@ -4899,7 +5015,7 @@ def remove_temp_filenames(temp_filenames_list): elif "R" in video_prompt_type: # sparse video to video src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True) - src_image, _, _ = calculate_dimensions_and_resize_image(src_image, new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size) + src_image, _, _ = calculate_dimensions_and_resize_image(src_image, image_size[0], image_size[1 ], sample_fit_canvas, fit_crop, block_size = block_size) refresh_preview["video_guide"] = src_image src_video = convert_image_to_tensor(src_image).unsqueeze(1) if sample_fit_canvas != None: @@ -4918,15 +5034,38 @@ def remove_temp_filenames(temp_filenames_list): if pre_video_guide != None: src_video = torch.cat( [pre_video_guide, src_video], dim=1) elif image_guide is not None: - image_guide, new_height, new_width = calculate_dimensions_and_resize_image(image_guide, height, width, sample_fit_canvas, fit_crop, block_size = block_size) - image_size = (new_height, new_width) - refresh_preview["image_guide"] = image_guide - sample_fit_canvas = None - if image_mask is not None: - image_mask, _, _ = calculate_dimensions_and_resize_image(image_mask, new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size) - refresh_preview["image_mask"] = image_mask + new_image_guide, new_image_mask = preprocess_image_with_mask(image_guide, image_mask, image_size[0], image_size[1], fit_canvas = sample_fit_canvas, fit_crop= fit_crop, block_size= block_size, expand_scale = mask_expand, outpainting_dims=outpainting_dims) + if sample_fit_canvas is not None: + image_size = (new_image_guide.size[1], new_image_guide.size[0]) + sample_fit_canvas = None + refresh_preview["image_guide"] = new_image_guide + if new_image_mask is not None: + refresh_preview["image_mask"] = new_image_mask if window_no == 1 and image_refs is not None and len(image_refs) > 0: + if repeat_no == 1: + frames_positions_list = [] + if frames_positions is not None and len(frames_positions)> 0: + positions = frames_positions.split(" ") + cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) #if reset_control_aligment else 0 + last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count + joker_used = False + project_window_no = 1 + for pos in positions : + if len(pos) > 0: + if pos in ["L", "l"]: + cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length + if cur_end_pos >= last_frame_no and not joker_used: + joker_used = True + cur_end_pos = last_frame_no -1 + project_window_no += 1 + frames_positions_list.append(cur_end_pos) + cur_end_pos -= sliding_window_discard_last_frames + reuse_frames + else: + frames_positions_list.append(int(pos)-1 + alignment_shift) + frames_positions_list = frames_positions_list[:len(image_refs)] + nb_frames_positions = len(frames_positions_list) + if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) : from shared.utils.utils import get_outpainting_full_area_dimensions w, h = image_refs[0].size @@ -4963,7 +5102,7 @@ def remove_temp_filenames(temp_filenames_list): frames_to_inject[pos] = image_refs[i] if vace : - frames_to_inject_parsed = frames_to_inject[aligned_guide_start_frame: aligned_guide_end_frame] + frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_end_frame] 
image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2], @@ -4971,14 +5110,13 @@ def remove_temp_filenames(temp_filenames_list): [image_refs_copy] if video_guide_processed2 == None else [image_refs_copy, image_refs_copy], current_video_length, image_size = image_size, device ="cpu", keep_video_guide_frames=keep_frames_parsed, - start_frame = aligned_guide_start_frame, pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide], inject_frames= frames_to_inject_parsed, outpainting_dims = outpainting_dims, any_background_ref = any_background_ref ) if len(frames_to_inject_parsed) or any_background_ref: - new_image_refs = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + aligned_guide_start_frame - aligned_window_start_frame) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] if any_background_ref: new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:] else: @@ -5094,8 +5232,9 @@ def set_header_text(txt): pre_video_frame = pre_video_frame, original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [], image_refs_relative_size = image_refs_relative_size, - image_guide= image_guide, - image_mask= image_mask, + image_guide= new_image_guide, + image_mask= new_image_mask, + outpainting_dims = outpainting_dims, ) except Exception as e: if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0: @@ -6755,10 +6894,17 @@ def refresh_audio_prompt_type_remux(state, audio_prompt_type, remux): audio_prompt_type = add_to_sequence(audio_prompt_type, remux) return audio_prompt_type +def refresh_remove_background_sound(state, audio_prompt_type, remove_background_sound): + audio_prompt_type = del_in_sequence(audio_prompt_type, "V") + if remove_background_sound: + audio_prompt_type = add_to_sequence(audio_prompt_type, "V") + return audio_prompt_type + + def refresh_audio_prompt_type_sources(state, audio_prompt_type, audio_prompt_type_sources): audio_prompt_type = del_in_sequence(audio_prompt_type, "XCPAB") audio_prompt_type = add_to_sequence(audio_prompt_type, audio_prompt_type_sources) - return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in audio_prompt_type)) + return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in audio_prompt_type)), gr.update(visible= any_letters(audio_prompt_type, "ABX")) def refresh_image_prompt_type_radio(state, image_prompt_type, image_prompt_type_radio): image_prompt_type = del_in_sequence(image_prompt_type, "VLTS") @@ -6775,7 +6921,7 @@ def refresh_image_prompt_type_endcheckbox(state, image_prompt_type, image_prompt image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio) return image_prompt_type, gr.update(visible = "E" in image_prompt_type ) -def refresh_video_prompt_type_image_refs(state, video_prompt_type, 
video_prompt_type_image_refs): +def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs, image_mode): model_type = state["model_type"] model_def = get_model_def(model_type) image_ref_choices = model_def.get("image_ref_choices", None) @@ -6785,11 +6931,10 @@ def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_ video_prompt_type = del_in_sequence(video_prompt_type, "KFI") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_image_refs) visible = "I" in video_prompt_type - vace= test_vace_module(state["model_type"]) - + any_outpainting= image_mode in model_def.get("video_guide_outpainting", []) rm_bg_visible= visible and not model_def.get("no_background_removal", False) img_rel_size_visible = visible and model_def.get("any_image_refs_relative_size", False) - return video_prompt_type, gr.update(visible = visible),gr.update(visible = rm_bg_visible), gr.update(visible = img_rel_size_visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace ) + return video_prompt_type, gr.update(visible = visible),gr.update(visible = rm_bg_visible), gr.update(visible = img_rel_size_visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and any_outpainting ) def switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): if image_mode == 0: return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) @@ -6840,13 +6985,13 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt visible = "V" in video_prompt_type model_type = state["model_type"] model_def = get_model_def(model_type) - vace= test_vace_module(model_type) + any_outpainting= image_mode in model_def.get("video_guide_outpainting", []) mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type image_outputs = image_mode > 0 keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False) image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ) - return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible) + return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= 
mask_visible) def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode): @@ -7112,6 +7257,12 @@ def refresh_video_length_label(state, current_video_length, force_fps, video_gui computed_fps = get_computed_fps(force_fps, base_model_type , video_guide, video_source ) return gr.update(label= compute_video_length_label(computed_fps, current_video_length)) +def get_default_value(choices, current_value, default_value = None): + for label, value in choices: + if value == current_value: + return current_value + return default_value + def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None): global inputs_names #, advanced @@ -7277,7 +7428,10 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "") model_mode_choices = model_def.get("model_modes", None) - with gr.Column(visible= image_mode_value == 0 and (len(image_prompt_types_allowed)> 0 or model_mode_choices is not None)) as image_prompt_column: + model_modes_visibility = [0,1,2] + if model_mode_choices is not None: model_modes_visibility= model_mode_choices.get("image_modes", model_modes_visibility) + + with gr.Column(visible= image_mode_value == 0 and len(image_prompt_types_allowed)> 0 or model_mode_choices is not None and image_mode_value in model_modes_visibility ) as image_prompt_column: # Video Continue / Start Frame / End Frame image_prompt_type_value= ui_defaults.get("image_prompt_type","") image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False) @@ -7293,7 +7447,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl if "L" in image_prompt_types_allowed: any_video_source = True image_prompt_type_choices += [("Continue Last Video", "L")] - with gr.Group(visible= len(image_prompt_types_allowed)>1) as image_prompt_type_group: + with gr.Group(visible= len(image_prompt_types_allowed)>1 and image_mode_value == 0) as image_prompt_type_group: with gr.Row(): image_prompt_type_radio_allowed_values= filter_letters(image_prompt_types_allowed, "SVL") image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1] if len(image_prompt_type_choices) > 0 else "") @@ -7309,10 +7463,11 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new Videos in the Generation Queue", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) image_end_row, image_end, image_end_extra = get_image_gallery(label= get_image_end_label(ui_defaults.get("multi_prompts_gen_type", 0)), value = ui_defaults.get("image_end", None), visible= any_letters(image_prompt_type_value, "SVL") and ("E" in image_prompt_type_value) ) - if model_mode_choices is None: + if model_mode_choices is None or image_mode_value not in model_modes_visibility: model_mode = gr.Dropdown(value=None, visible=False) else: - model_mode = gr.Dropdown(choices=model_mode_choices["choices"], value=ui_defaults.get("model_mode", model_mode_choices["default"]), 
label=model_mode_choices["label"], visible=True) + model_mode_value = get_default_value(model_mode_choices["choices"], ui_defaults.get("model_mode", None), model_mode_choices["default"] ) + model_mode = gr.Dropdown(choices=model_mode_choices["choices"], value=model_mode_value, label=model_mode_choices["label"], visible=True) keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VL"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) any_control_video = any_control_image = False @@ -7374,9 +7529,11 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl video_prompt_type_video_guide_alt_label = guide_custom_choices.get("label", "Control Video Process") if image_outputs: video_prompt_type_video_guide_alt_label = video_prompt_type_video_guide_alt_label.replace("Video", "Image") video_prompt_type_video_guide_alt_choices = [(label.replace("Video", "Image") if image_outputs else label, value) for label,value in guide_custom_choices["choices"] ] + guide_custom_choices_value = get_default_value(video_prompt_type_video_guide_alt_choices, filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"]), guide_custom_choices.get("default", "") ) video_prompt_type_video_guide_alt = gr.Dropdown( choices= video_prompt_type_video_guide_alt_choices, - value=filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"], guide_custom_choices.get("default", "") ), + # value=filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"], guide_custom_choices.get("default", "") ), + value=guide_custom_choices_value, visible = guide_custom_choices.get("visible", True), label= video_prompt_type_video_guide_alt_label, show_label= guide_custom_choices.get("show_label", True), scale = 2 ) @@ -7437,7 +7594,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl label=image_ref_choices.get("label", "Ref. 
Images Type"), show_label= True, scale = 2 ) - image_guide = gr.Image(label= "Control Image", height = 800, width=800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None)) + image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None)) video_guide = gr.Video(label= "Control Video", height = gallery_height, visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None)) if image_mode_value >= 1: image_guide_value = ui_defaults.get("image_guide", None) @@ -7457,7 +7614,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl layers=False, brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"), # fixed_canvas= True, - width=800, + # width=800, height=800, # transforms=None, # interactive=True, @@ -7472,12 +7629,12 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label=f"Denoising Strength (the Lower the Closer to the Control {'Image' if image_outputs else 'Video'})", visible = "G" in video_prompt_type_value, show_reset_button= False) keep_frames_video_guide_visible = not image_outputs and "V" in video_prompt_type_value and not model_def.get("keep_frames_video_guide_not_supported", False) keep_frames_video_guide = gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= keep_frames_video_guide_visible , scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last - - with gr.Column(visible= ("V" in video_prompt_type_value or "K" in video_prompt_type_value or "F" in video_prompt_type_value) and vace) as video_guide_outpainting_col: + video_guide_outpainting_modes = model_def.get("video_guide_outpainting", []) + with gr.Column(visible= ("V" in video_prompt_type_value or "K" in video_prompt_type_value or "F" in video_prompt_type_value) and image_mode_value in video_guide_outpainting_modes) as video_guide_outpainting_col: video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#") video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False) with gr.Group(): - video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) + video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row: video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")] @@ -7495,7 +7652,7 @@ def 
get_image_gallery(label ="", value = None, single_image_mode = False, visibl image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "") image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode) - frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" ) + frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames (1=first, L=last of a window) no position for other Image Refs)" ) image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs) no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None @@ -7522,7 +7679,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl speaker_choices=[("None", "")] if any_single_speaker: speaker_choices += [("One Person Speaking Only", "A")] if any_multi_speakers:speaker_choices += [ - ("Two speakers, Auto Separation of Speakers (will work only if there is little background noise)", "XA"), + ("Two speakers, Auto Separation of Speakers (will work only if Voices are distinct)", "XA"), ("Two speakers, Speakers Audio sources are assumed to be played in a Row", "CAB"), ("Two speakers, Speakers Audio sources are assumed to be played in Parallel", "PAB") ] @@ -7537,6 +7694,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl with gr.Row(visible = any_audio_voices_support and not image_outputs) as audio_guide_row: audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= any_audio_voices_support and "A" in audio_prompt_type_value ) audio_guide2 = gr.Audio(value= ui_defaults.get("audio_guide2", None), type="filepath", label="Voice to follow #2", show_download_button= True, visible= any_audio_voices_support and "B" in audio_prompt_type_value ) + remove_background_sound = gr.Checkbox(label="Video Motion ignores Background Music (to get a better LipSync)", value="V" in audio_prompt_type_value, visible = any_audio_voices_support and any_letters(audio_prompt_type_value, "ABX") and not image_outputs) with gr.Row(visible = any_audio_voices_support and ("B" in audio_prompt_type_value or "X" in audio_prompt_type_value) and not image_outputs ) as speakers_locations_row: speakers_locations = gr.Text( ui_defaults.get("speakers_locations", "0:45 55:100"), label="Speakers Locations separated by a Space. 
Each Location = Left:Right or a BBox Left:Top:Right:Bottom", visible= True) @@ -7916,7 +8074,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai sliding_window_defaults = model_def.get("sliding_window_defaults", {}) sliding_window_size = gr.Slider(5, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size") sliding_window_overlap = gr.Slider(sliding_window_defaults.get("overlap_min", 1), sliding_window_defaults.get("overlap_max", 97), value=ui_defaults.get("sliding_window_overlap",sliding_window_defaults.get("overlap_default", 5)), step=sliding_window_defaults.get("overlap_step", 4), label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") - sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",1), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)", visible = True) + sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",0), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)", visible = True) sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20 if vace else 0), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = vace) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) @@ -7926,7 +8084,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai ("Aligned to the beginning of the First Window of the new Video Sample", "T"), ], value=filter_letters(video_prompt_type_value, "T"), - label="Control Video / Control Audio temporal alignment when any Source Video", + label="Control Video / Injected Frames / Control Audio temporal alignment when any Video to continue", visible = vace or ltxv or t2v or infinitetalk ) @@ -8112,7 +8270,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn, - NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, guide_selection_row, image_mode_tabs, + NAG_col, remove_background_sound , speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, guide_selection_row, image_mode_tabs, min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn, tab_inpaint, tab_t2v] + image_start_extra + image_end_extra + image_refs_extra # presets_column, if update_form: locals_dict = locals() @@ -8131,11 +8289,11 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai gr.on(triggers=[video_length.release, force_fps.change, 
video_guide.change, video_source.change], fn=refresh_video_length_label, inputs=[state, video_length, force_fps, video_guide, video_source] , outputs = video_length, trigger_mode="always_last", show_progress="hidden" ) guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ]) audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type]) - audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row]) + remove_background_sound.change(fn=refresh_remove_background_sound, inputs=[state, audio_prompt_type, remove_background_sound], outputs=[audio_prompt_type]) + audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row, remove_background_sound]) image_prompt_type_radio.change(fn=refresh_image_prompt_type_radio, inputs=[state, image_prompt_type, image_prompt_type_radio], outputs=[image_prompt_type, image_start_row, image_end_row, video_source, keep_frames_video_source, image_prompt_type_endcheckbox], show_progress="hidden" ) image_prompt_type_endcheckbox.change(fn=refresh_image_prompt_type_endcheckbox, inputs=[state, image_prompt_type, image_prompt_type_radio, image_prompt_type_endcheckbox], outputs=[image_prompt_type, image_end_row] ) - # video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand]) - video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col], show_progress="hidden") + video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs,image_mode], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col], show_progress="hidden") video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand], show_progress="hidden") video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode], outputs = [video_prompt_type, video_guide, image_guide, image_refs_row, denoising_strength ], show_progress="hidden") video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, 
video_prompt_type_video_mask, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_mask, image_mask_guide, image_guide, image_mask, mask_expand], show_progress="hidden") @@ -8338,7 +8496,7 @@ def activate_status(state): outputs= None ).then(fn=process_prompt_and_add_tasks, inputs = [state, model_choice], - outputs= queue_df, + outputs= [queue_df, queue_accordion], show_progress="hidden", ).then(fn=prepare_generate_video, inputs= [state], @@ -8346,11 +8504,6 @@ def activate_status(state): ).then(fn=activate_status, inputs= [state], outputs= [status_trigger], - ).then( - fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(), - inputs=[state], - outputs=[queue_accordion], - show_progress="hidden", ).then(fn=process_tasks, inputs= [state], outputs= [preview_trigger, output_trigger], @@ -8468,12 +8621,7 @@ def activate_status(state): outputs= None ).then(fn=process_prompt_and_add_tasks, inputs = [state, model_choice], - outputs=queue_df, - show_progress="hidden", - ).then( - fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(), - inputs=[state], - outputs=[queue_accordion], + outputs=[queue_df, queue_accordion], show_progress="hidden", ).then( fn=update_status, From ae93d71217c080c73c17a7ee5fd04f5d5dc65d93 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 16 Sep 2025 22:38:45 +0200 Subject: [PATCH 027/155] fixed simple annoying bugs --- README.md | 6 +++--- models/qwen/qwen_main.py | 1 + models/wan/any2video.py | 2 +- wgp.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 2411cea3d..8c5e4367a 100644 --- a/README.md +++ b/README.md @@ -24,12 +24,12 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models - The long awaited **Vace for Wan 2.2** is at last here or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fan Cocktail**) -- **First Frame / Last Frame for Vace** : Vace model are so powerful that they could do *First frame / Last frame* since day one using the *Injected Frames* feature. However this required to compute by hand the locations of each end frame since this feature expects frames positions. I made it easier to compute these locations by using the "L" alias : +- **First Frame / Last Frame for Vace** : Vace models are so powerful that they could do *First frame / Last frame* since day one using the *Injected Frames* feature. However this required to compute by hand the locations of each end frame since this feature expects frames positions. I made it easier to compute these locations by using the "L" alias : For a video Gen from scratch *"1 L L L"* means the 4 Injected Frames will be injected like this: frame no 1 at the first position, the next frame at the end of the first window, then the following frame at the end of the next window, and so on .... -If you *Continue a Video* , you just need *"L L L"* since the the first frame is the last frame of the *Source Video*. In any case remember that numeral frames positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab/ Control Video, Injected Frames aligment*. 
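A minimal sketch of how such a spec can resolve to absolute frame positions (the 81-frame window, the 5-frame overlap and the helper below are illustrative assumptions, not the actual wgp.py alignment code):

```python
# Illustrative only: expand an Injected Frames spec such as "1 L L L" into
# absolute frame positions, assuming each extra sliding window contributes
# (window_size - overlap) new frames. The real alignment logic lives in wgp.py.
def expand_positions(spec: str, window_size: int = 81, overlap: int = 5) -> list[int]:
    positions, window_end = [], window_size
    for token in spec.split():
        if token.upper() == "L":
            positions.append(window_end)          # last frame of the current window
            window_end += window_size - overlap   # move on to the next window
        else:
            positions.append(int(token))          # explicit 1-based frame position
    return positions

print(expand_positions("1 L L L"))  # [1, 81, 157, 233]
print(expand_positions("L L L"))    # continuing a video: [81, 157, 233]
```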
+If you *Continue a Video* , you just need *"L L L"* since the first frame is the last frame of the *Source Video*. In any case remember that numeral frames positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab/ Control Video, Injected Frames aligment*. -- **Qwen Inpainting** exist now in two versions: the original version of the previous release and a Lora based version. Each version has its pros and cons. For instance the Lora version supports also **Outpainting** ! However it tends to change slightly the original image even outside the outpainted area. +- **Qwen Edit Inpainting** exists now in two versions: the original version of the previous release and a Lora based version. Each version has its pros and cons. For instance the Lora version supports also **Outpainting** ! However it tends to change slightly the original image even outside the outpainted area. - **Better Lipsync with all the Audio to Video models**: you probably noticed that *Multitalk*, *InfiniteTalk* or *Hunyuan Avatar* had so so lipsync when the audio provided contained some background music. The problem should be solved now thanks to an automated background music removal all done by IA. Don't worry you will still hear the music as it is added back in the generated Video. diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py index da84e0d0e..cec0b8e05 100644 --- a/models/qwen/qwen_main.py +++ b/models/qwen/qwen_main.py @@ -217,6 +217,7 @@ def generate( def get_loras_transformer(self, get_model_recursive_prop, model_type, model_mode, **kwargs): if model_mode == 0: return [], [] preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") + if len(preloadURLs) == 0: return [], [] return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1] diff --git a/models/wan/any2video.py b/models/wan/any2video.py index dde1a6535..3e85f7963 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -146,7 +146,7 @@ def __init__( from mmgp.safetensors2 import torch_load_file else: if self.transformer_switch: - if 0 in submodel_no_list[2:] and 1 in submodel_no_list: + if 0 in submodel_no_list[2:] and 1 in submodel_no_list[2:]: raise Exception("Shared and non shared modules at the same time across multipe models is not supported") if 0 in submodel_no_list[2:]: diff --git a/wgp.py b/wgp.py index de8e2f66f..ee26376b4 100644 --- a/wgp.py +++ b/wgp.py @@ -62,7 +62,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.6" +WanGP_version = "8.61" settings_version = 2.35 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None From fc615ffb3cb063bd4579ee98b2f81e7ca6633971 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 16 Sep 2025 23:01:54 +0200 Subject: [PATCH 028/155] fixed simple annoying bugs --- wgp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgp.py b/wgp.py index ee26376b4..ee6b5f62e 100644 --- a/wgp.py +++ b/wgp.py @@ -356,7 +356,7 @@ def ret(): outpainting_dims = get_outpainting_dims(video_guide_outpainting) - if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None: + if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"): gr.Info("Output Resolution Cropping will be not used for this Generation 
as it is not compatible with Video Outpainting") if len(loras_multipliers) > 0: From 84010bd861545fd2d8365aaece6b4017d40da9aa Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Mon, 22 Sep 2025 17:11:25 +0200 Subject: [PATCH 029/155] commit in case there is an unrecoverable code hemorragy --- configs/animate.json | 15 + configs/lucy_edit.json | 14 + defaults/animate.json | 13 + defaults/lucy_edit.json | 18 ++ defaults/ti2v_2_2_fastwan.json | 1 + models/qwen/pipeline_qwenimage.py | 146 +++++---- models/wan/any2video.py | 193 +++++------ models/wan/modules/model.py | 53 ++- models/wan/wan_handler.py | 81 +++-- shared/convert/convert_diffusers_to_flux.py | 342 ++++++++++++++++++++ shared/inpainting/__init__.py | 0 shared/inpainting/lanpaint.py | 240 ++++++++++++++ shared/inpainting/utils.py | 301 +++++++++++++++++ shared/utils/audio_video.py | 3 + shared/utils/download.py | 110 +++++++ shared/utils/utils.py | 246 ++++++-------- wgp.py | 230 +++++++++---- 17 files changed, 1623 insertions(+), 383 deletions(-) create mode 100644 configs/animate.json create mode 100644 configs/lucy_edit.json create mode 100644 defaults/animate.json create mode 100644 defaults/lucy_edit.json create mode 100644 shared/convert/convert_diffusers_to_flux.py create mode 100644 shared/inpainting/__init__.py create mode 100644 shared/inpainting/lanpaint.py create mode 100644 shared/inpainting/utils.py create mode 100644 shared/utils/download.py diff --git a/configs/animate.json b/configs/animate.json new file mode 100644 index 000000000..7e98ca95e --- /dev/null +++ b/configs/animate.json @@ -0,0 +1,15 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 36, + "model_type": "i2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "motion_encoder_dim": 512 +} \ No newline at end of file diff --git a/configs/lucy_edit.json b/configs/lucy_edit.json new file mode 100644 index 000000000..4983cedd1 --- /dev/null +++ b/configs/lucy_edit.json @@ -0,0 +1,14 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 3072, + "eps": 1e-06, + "ffn_dim": 14336, + "freq_dim": 256, + "in_dim": 96, + "model_type": "ti2v2_2", + "num_heads": 24, + "num_layers": 30, + "out_dim": 48, + "text_len": 512 +} diff --git a/defaults/animate.json b/defaults/animate.json new file mode 100644 index 000000000..fee26d76a --- /dev/null +++ b/defaults/animate.json @@ -0,0 +1,13 @@ +{ + "model": { + "name": "Wan2.2 Animate", + "architecture": "animate", + "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'animation' or 'replacement' mode.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors" + ], + "group": "wan2_2" + } +} \ No newline at end of file diff --git a/defaults/lucy_edit.json b/defaults/lucy_edit.json new file mode 100644 index 000000000..6344dfffd --- /dev/null +++ b/defaults/lucy_edit.json @@ -0,0 +1,18 @@ +{ + "model": { + "name": "Wan2.2 Lucy Edit 5B", + "architecture": "lucy_edit", + "description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory 
changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors" + ], + "group": "wan2_2" + }, + "video_length": 81, + "guidance_scale": 5, + "flow_shift": 5, + "num_inference_steps": 30, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/defaults/ti2v_2_2_fastwan.json b/defaults/ti2v_2_2_fastwan.json index 064c2b4ba..fa69f82ec 100644 --- a/defaults/ti2v_2_2_fastwan.json +++ b/defaults/ti2v_2_2_fastwan.json @@ -7,6 +7,7 @@ "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"], "group": "wan2_2" }, + "prompt" : "Put the person into a clown outfit.", "video_length": 121, "guidance_scale": 1, "flow_shift": 3, diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py index be982aabc..14728862d 100644 --- a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from mmgp import offload import inspect from typing import Any, Callable, Dict, List, Optional, Union @@ -387,7 +386,8 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): return latent_image_ids.to(device=device, dtype=dtype) @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): + def _pack_latents(latents): + batch_size, num_channels_latents, _, height, width = latents.shape latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) @@ -479,7 +479,7 @@ def prepare_latents( height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) - shape = (batch_size, 1, num_channels_latents, height, width) + shape = (batch_size, num_channels_latents, 1, height, width) image_latents = None if image is not None: @@ -499,10 +499,7 @@ def prepare_latents( else: image_latents = torch.cat([image_latents], dim=0) - image_latent_height, image_latent_width = image_latents.shape[3:] - image_latents = self._pack_latents( - image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width - ) + image_latents = self._pack_latents(image_latents) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -511,7 +508,7 @@ def prepare_latents( ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents) else: latents = latents.to(device=device, dtype=dtype) @@ -713,11 +710,12 @@ def __call__( image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of) # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) 
height, width = image_height, image_width - image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 16, height // 16), resample=Image.Resampling.LANCZOS)) + image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 8, height // 8), resample=Image.Resampling.LANCZOS)) image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1] - image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0) + image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0) # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") - image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device) + image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1) + image_mask_latents = self._pack_latents(image_mask_latents) prompt_image = image if image.size != (image_width, image_height): @@ -822,6 +820,7 @@ def __call__( negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) morph, first_step = False, 0 + lanpaint_proc = None if image_mask_latents is not None: randn = torch.randn_like(original_image_latents) if denoising_strength < 1.: @@ -833,7 +832,8 @@ def __call__( timesteps = timesteps[first_step:] self.scheduler.timesteps = timesteps self.scheduler.sigmas= self.scheduler.sigmas[first_step:] - + # from shared.inpainting.lanpaint import LanPaint + # lanpaint_proc = LanPaint() # 6. Denoising loop self.scheduler.set_begin_index(0) updated_num_steps= len(timesteps) @@ -847,48 +847,52 @@ def __call__( offload.set_step_no_for_lora(self.transformer, first_step + i) if self.interrupt: continue + self._current_timestep = t + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) if image_mask_latents is not None and denoising_strength <1. 
and i == first_step and morph: latent_noise_factor = t/1000 latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor - self._current_timestep = t - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timestep = t.expand(latents.shape[0]).to(latents.dtype) - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) - - if do_true_cfg and joint_pass: - noise_pred, neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], - encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], - img_shapes=img_shapes, - txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens], - attention_kwargs=self.attention_kwargs, - **kwargs - ) - if noise_pred == None: return None - noise_pred = noise_pred[:, : latents.size(1)] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - else: - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask_list=[prompt_embeds_mask], - encoder_hidden_states_list=[prompt_embeds], - img_shapes=img_shapes, - txt_seq_lens_list=[txt_seq_lens], - attention_kwargs=self.attention_kwargs, - **kwargs - )[0] - if noise_pred == None: return None - noise_pred = noise_pred[:, : latents.size(1)] + latents_dtype = latents.dtype + + # latent_model_input = latents + def denoise(latent_model_input, true_cfg_scale): + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + do_true_cfg = true_cfg_scale > 1 + if do_true_cfg and joint_pass: + noise_pred, neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, #!!!! 
+ encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + ) + if noise_pred == None: return None, None + noise_pred = noise_pred[:, : latents.size(1)] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + else: + neg_noise_pred = None + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask_list=[prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + )[0] + if noise_pred == None: return None, None + noise_pred = noise_pred[:, : latents.size(1)] if do_true_cfg: neg_noise_pred = self.transformer( @@ -902,27 +906,43 @@ def __call__( attention_kwargs=self.attention_kwargs, **kwargs )[0] - if neg_noise_pred == None: return None + if neg_noise_pred == None: return None, None neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + return noise_pred, neg_noise_pred + def cfg_predictions( noise_pred, neg_noise_pred, guidance, t): + if do_true_cfg: + comb_pred = neg_noise_pred + guidance * (noise_pred - neg_noise_pred) + if comb_pred == None: return None + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) - if do_true_cfg: - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - if comb_pred == None: return None + return noise_pred - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - neg_noise_pred = None + + if lanpaint_proc is not None and i<=3: + latents = lanpaint_proc(denoise, cfg_predictions, true_cfg_scale, 1., latents, original_image_latents, randn, t/1000, image_mask_latents, height=height , width= width, vae_scale_factor= 8) + if latents is None: return None + + noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale) + if noise_pred == None: return None + noise_pred = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t) + neg_noise_pred = None # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + noise_pred = None + if image_mask_latents is not None: - next_t = timesteps[i+1] if i0: + msk[:, :nb_frames_unchanged] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1,2)[0] + return msk + def generate(self, input_prompt, input_frames= None, input_masks = None, - input_ref_images = None, + input_ref_images = None, + input_ref_masks = None, + input_faces = None, input_video = None, image_start = None, image_end = None, @@ -541,14 +506,18 @@ def generate(self, infinitetalk = model_type in ["infinitetalk"] standin = model_type in ["standin", "vace_standin_14B"] recam = model_type in ["recam_1.3B"] - ti2v = model_type in ["ti2v_2_2"] + ti2v = model_type in ["ti2v_2_2", "lucy_edit"] + lucy_edit= model_type in ["lucy_edit"] + animate= model_type in ["animate"] start_step_no = 0 ref_images_count = 0 trim_frames = 0 - extended_overlapped_latents = 
None + extended_overlapped_latents = clip_image_start = clip_image_end = None no_noise_latents_injection = infinitetalk timestep_injection = False lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 + extended_input_dim = 0 + ref_images_before = False # image2video if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]: any_end_frame = False @@ -598,17 +567,7 @@ def generate(self, if image_end is not None: img_end_frame = image_end.unsqueeze(1).to(self.device) - - if hasattr(self, "clip"): - clip_image_size = self.clip.model.image_size - image_start = resize_lanczos(image_start, clip_image_size, clip_image_size) - image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) if image_end is not None else image_start - if model_type == "flf2v_720p": - clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]]) - else: - clip_context = self.clip.visual([image_start[:, None, :, :]]) - else: - clip_context = None + clip_image_start, clip_image_end = image_start, image_end if any_end_frame: enc= torch.concat([ @@ -647,21 +606,62 @@ def generate(self, if infinitetalk: lat_y = self.vae.encode([input_video], VAE_tile_size)[0] extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0) - # if control_pre_frames_count != pre_frames_count: lat_y = input_video = None kwargs.update({ 'y': y}) - if not clip_context is None: - kwargs.update({'clip_fea': clip_context}) - # Recam Master - if recam: - target_camera = model_mode - height,width = input_frames.shape[-2:] - input_frames = input_frames.to(dtype=self.dtype , device=self.device) - source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) + # Animate + if animate: + pose_pixels = input_frames * input_masks + input_masks = 1. 
- input_masks + pose_pixels -= input_masks + save_video(pose_pixels, "pose.mp4") + pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0) + input_frames = input_frames * input_masks + if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames + if prefix_frames_count > 0: + input_frames[:, :prefix_frames_count] = input_video + input_masks[:, :prefix_frames_count] = 1 + save_video(input_frames, "input_frames.mp4") + save_video(input_masks, "input_masks.mp4", value_range=(0,1)) + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device) + msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device) + msk = torch.concat([msk_ref, msk_control], dim=1) + clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device) + lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1) + y = torch.concat([msk, lat_y]) + kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)}) + lat_y = msk = msk_control = msk_ref = pose_pixels = None + ref_images_before = True + ref_images_count = 1 + lat_frames = int((input_frames.shape[1] - 1) // self.vae_stride[0]) + 1 + + # Clip image + if hasattr(self, "clip") and clip_image_start is not None: + clip_image_size = self.clip.model.image_size + clip_image_start = resize_lanczos(clip_image_start, clip_image_size, clip_image_size) + clip_image_end = resize_lanczos(clip_image_end, clip_image_size, clip_image_size) if clip_image_end is not None else clip_image_start + if model_type == "flf2v_720p": + clip_context = self.clip.visual([clip_image_start[:, None, :, :], clip_image_end[:, None, :, :] if clip_image_end is not None else clip_image_start[:, None, :, :]]) + else: + clip_context = self.clip.visual([clip_image_start[:, None, :, :]]) + clip_image_start = clip_image_end = None + kwargs.update({'clip_fea': clip_context}) + + # Recam Master & Lucy Edit + if recam or lucy_edit: + frame_num, height,width = input_frames.shape[-3:] + lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 + frame_num = (lat_frames -1) * self.vae_stride[0] + 1 + input_frames = input_frames[:, :frame_num].to(dtype=self.dtype , device=self.device) + extended_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) + extended_input_dim = 2 if recam else 1 del input_frames + + if recam: # Process target camera (recammaster) + target_camera = model_mode from shared.utils.cammmaster_tools import get_camera_embedding cam_emb = get_camera_embedding(target_camera) cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) @@ -715,6 +715,8 @@ def generate(self, height, width = input_video.shape[-2:] source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0) timestep_injection = True + if extended_input_dim > 0: + extended_latents[:, :, :source_latents.shape[2]] = source_latents # Vace if vace : @@ -722,6 +724,7 @@ def generate(self, input_frames = [u.to(self.device) for u in input_frames] input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] input_masks = [u.to(self.device) for u in input_masks] + ref_images_before = True if self.background_mask != None: self.background_mask = 
[m.to(self.device) for m in self.background_mask] z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents ) m0 = self.vace_encode_masks(input_masks, input_ref_images) @@ -771,9 +774,9 @@ def generate(self, expand_shape = [batch_size] + [-1] * len(target_shape) # Ropes - if target_camera != None: + if extended_input_dim>=2: shape = list(target_shape[1:]) - shape[0] *= 2 + shape[extended_input_dim-2] *= 2 freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) else: freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx) @@ -901,8 +904,8 @@ def clear(): for zz in z: zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor - if target_camera != None: - latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2) + if extended_input_dim > 0: + latent_model_input = torch.cat([latents, extended_latents.expand(*expand_shape)], dim=extended_input_dim) else: latent_model_input = latents @@ -1030,7 +1033,7 @@ def clear(): if callback is not None: latents_preview = latents - if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] + if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames] if image_outputs: latents_preview= latents_preview[:, :,:1] if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2) @@ -1041,7 +1044,7 @@ def clear(): if timestep_injection: latents[:, :, :source_latents.shape[2]] = source_latents - if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:] + if ref_images_before and ref_images_count > 0: latents = latents[:, :, ref_images_count:] if trim_frames > 0: latents= latents[:, :,:-trim_frames] if return_latent_slice != None: latent_slice = latents[:, :, return_latent_slice].clone() @@ -1078,4 +1081,12 @@ def adapt_vace_model(self, model): delattr(model, "vace_blocks") + def adapt_animate_model(self, model): + modules_dict= { k: m for k, m in model.named_modules()} + for animate_layer in range(8): + module = modules_dict[f"face_adapter.fuser_blocks.{animate_layer}"] + model_layer = animate_layer * 5 + target = modules_dict[f"blocks.{model_layer}"] + setattr(target, "face_adapter_fuser_blocks", module ) + delattr(model, "face_adapter") diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index d3dc783e3..c98087b5d 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -16,6 +16,9 @@ from shared.attention import pay_attention from torch.backends.cuda import sdp_kernel from ..multitalk.multitalk_utils import get_attn_map_with_target +from ..animate.motion_encoder import Generator +from ..animate.face_blocks import FaceAdapter, FaceEncoder +from ..animate.model_animate import after_patch_embedding __all__ = ['WanModel'] @@ -499,6 +502,7 @@ def forward( multitalk_masks=None, ref_images_count=0, standin_phase=-1, + motion_vec = None, ): r""" Args: @@ -616,6 +620,10 @@ def forward( x.add_(hint) else: x.add_(hint, alpha= scale) + + if motion_vec is not None and self.block_no % 5 == 0: + x += self.face_adapter_fuser_blocks(x.to(self.face_adapter_fuser_blocks.linear1_kv.weight.dtype), motion_vec, None, False) + return x 
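As a side note on the Animate wiring added above, a small sketch (not part of the patch) of why the `self.block_no % 5 == 0` test lines up with the adapter layout: `face_adapter` is built with `num_layers // 5` fuser blocks and `adapt_animate_model` attaches fuser `k` to transformer block `k * 5`, so both enumerate the same block indices (`num_layers = 40` is taken from configs/animate.json):

```python
# Illustration only: num_layers=40 and the range(8) loop mirror configs/animate.json
# and adapt_animate_model in this patch.
num_layers = 40
adapter_layers = num_layers // 5                        # 8 fuser blocks in face_adapter
fuser_targets = [k * 5 for k in range(adapter_layers)]  # blocks that get a fuser attached
fires_in_forward = [b for b in range(num_layers) if b % 5 == 0]
assert fuser_targets == fires_in_forward == [0, 5, 10, 15, 20, 25, 30, 35]
```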
class AudioProjModel(ModelMixin, ConfigMixin): @@ -898,6 +906,7 @@ def __init__(self, norm_input_visual=True, norm_output_audio=True, standin= False, + motion_encoder_dim=0, ): super().__init__() @@ -922,14 +931,15 @@ def __init__(self, self.flag_causal_attention = False self.block_mask = None self.inject_sample_info = inject_sample_info - + self.motion_encoder_dim = motion_encoder_dim self.norm_output_audio = norm_output_audio self.audio_window = audio_window self.intermediate_dim = intermediate_dim self.vae_scale = vae_scale multitalk = multitalk_output_dim > 0 - self.multitalk = multitalk + self.multitalk = multitalk + animate = motion_encoder_dim > 0 # embeddings self.patch_embedding = nn.Conv3d( @@ -1027,6 +1037,25 @@ def __init__(self, block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128) block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128) + if animate: + self.pose_patch_embedding = nn.Conv3d( + 16, dim, kernel_size=patch_size, stride=patch_size + ) + + self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20) + self.face_adapter = FaceAdapter( + heads_num=self.num_heads, + hidden_dim=self.dim, + num_adapter_layers=self.num_layers // 5, + ) + + self.face_encoder = FaceEncoder( + in_dim=motion_encoder_dim, + hidden_dim=self.dim, + num_heads=4, + ) + + def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32): layer_list = [self.head, self.head.head, self.patch_embedding] target_dype= dtype @@ -1208,6 +1237,9 @@ def forward( ref_images_count = 0, standin_freqs = None, standin_ref = None, + pose_latents=None, + face_pixel_values=None, + ): # patch_dtype = self.patch_embedding.weight.dtype modulation_dtype = self.time_projection[1].weight.dtype @@ -1240,9 +1272,18 @@ def forward( if bz > 1: y = y.expand(bz, -1, -1, -1, -1) x = torch.cat([x, y], dim=1) # embeddings - # x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype) x = self.patch_embedding(x).to(modulation_dtype) grid_sizes = x.shape[2:] + x_list[i] = x + y = None + + motion_vec_list = [] + for i, x in enumerate(x_list): + # animate embeddings + motion_vec = None + if pose_latents is not None: + x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values) + motion_vec_list.append(motion_vec) if chipmunk: x = x.unsqueeze(-1) x_og_shape = x.shape @@ -1250,7 +1291,7 @@ def forward( else: x = x.flatten(2).transpose(1, 2) x_list[i] = x - x, y = None, None + x = None block_mask = None @@ -1450,9 +1491,9 @@ def forward( continue x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs) else: - for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc)): + for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list)): if should_calc: - x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, **kwargs) + x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec,**kwargs) del x context = hints = audio_embedding = None diff --git a/models/wan/wan_handler.py 
b/models/wan/wan_handler.py index 6586176bd..ff1570f3a 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -3,10 +3,10 @@ import gradio as gr def test_class_i2v(base_model_type): - return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ] + return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "animate" ] def text_oneframe_overlap(base_model_type): - return test_class_i2v(base_model_type) and not test_multitalk(base_model_type) + return test_class_i2v(base_model_type) and not (test_multitalk(base_model_type) or base_model_type in ["animate"]) or test_wan_5B(base_model_type) def test_class_1_3B(base_model_type): return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"] @@ -17,6 +17,8 @@ def test_multitalk(base_model_type): def test_standin(base_model_type): return base_model_type in ["standin", "vace_standin_14B"] +def test_wan_5B(base_model_type): + return base_model_type in ["ti2v_2_2", "lucy_edit"] class family_handler(): @staticmethod @@ -36,7 +38,7 @@ def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_st def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181] elif base_model_type in ["i2v_2_2"]: def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902] - elif base_model_type in ["ti2v_2_2"]: + elif test_wan_5B(base_model_type): if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 
0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015] else: # i2v @@ -83,11 +85,13 @@ def query_model_def(base_model_type, model_def): vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"] extra_model_def["vace_class"] = vace_class - if test_multitalk(base_model_type): + if base_model_type in ["animate"]: + fps = 30 + elif test_multitalk(base_model_type): fps = 25 elif base_model_type in ["fantasy"]: fps = 23 - elif base_model_type in ["ti2v_2_2"]: + elif test_wan_5B(base_model_type): fps = 24 else: fps = 16 @@ -100,14 +104,14 @@ def query_model_def(base_model_type, model_def): extra_model_def.update({ "frames_minimum" : frames_minimum, "frames_steps" : frames_steps, - "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2", + "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class, #"ti2v_2_2", "multiple_submodels" : multiple_submodels, "guidance_max_phases" : 3, "skip_layer_guidance" : True, "cfg_zero" : True, "cfg_star" : True, "adaptive_projected_guidance" : True, - "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels), + "tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels), "mag_cache" : True, "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"], "convert_image_guide_to_video" : True, @@ -146,6 +150,34 @@ def query_model_def(base_model_type, model_def): } # extra_model_def["at_least_one_image_ref_needed"] = True + if base_model_type in ["lucy_edit"]: + extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["guide_preprocessing"] = { + "selection": ["UV"], + "labels" : { "UV": "Control Video"}, + "visible": False, + } + + if base_model_type in ["animate"]: + extra_model_def["guide_custom_choices"] = { + "choices":[ + ("Animate Person in Reference Image using Motion of Person in Control Video", "PVBXAKI"), + ("Replace Person in Control Video Person in Reference Image", "PVBAI"), + ], + "default": "KI", + "letters_filter": "PVBXAKI", + "label": "Type of Process", + "show_label" : False, + } + extra_model_def["video_guide_outpainting"] = [0,1] + extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["extract_guide_from_window_start"] = True + extra_model_def["forced_guide_mask_inputs"] = True + extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)" + extra_model_def["background_ref_outpainted"] = False + + + if vace_class: extra_model_def["guide_preprocessing"] = { "selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"], @@ -157,16 +189,19 @@ def query_model_def(base_model_type, model_def): extra_model_def["image_ref_choices"] = { "choices": [("None", ""), - ("Inject only People / Objects", "I"), - ("Inject Landscape and then People / Objects", "KI"), - ("Inject Frames and then People / Objects", "FI"), + ("People / Objects", "I"), + ("Landscape followed by People / Objects (if any)", "KI"), + ("Positioned Frames followed by People / Objects (if any)", "FI"), ], 
"letters_filter": "KFI", } extra_model_def["lock_image_refs_ratios"] = True - extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames" + extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames" extra_model_def["video_guide_outpainting"] = [0,1] + extra_model_def["pad_guide_video"] = True + extra_model_def["guide_inpaint_color"] = 127.5 + extra_model_def["forced_guide_mask_inputs"] = True if base_model_type in ["standin"]: extra_model_def["lock_image_refs_ratios"] = True @@ -209,10 +244,12 @@ def query_model_def(base_model_type, model_def): "visible" : False, } - if vace_class or base_model_type in ["infinitetalk"]: + if vace_class or base_model_type in ["infinitetalk", "animate"]: image_prompt_types_allowed = "TVL" elif base_model_type in ["ti2v_2_2"]: image_prompt_types_allowed = "TSVL" + elif base_model_type in ["lucy_edit"]: + image_prompt_types_allowed = "TVL" elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]: image_prompt_types_allowed = "SVL" elif i2v: @@ -234,8 +271,8 @@ def query_model_def(base_model_type, model_def): def query_supported_types(): return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B", "t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", - "recam_1.3B", - "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] + "recam_1.3B", "animate", + "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] @staticmethod @@ -265,11 +302,12 @@ def query_family_infos(): @staticmethod def get_vae_block_size(base_model_type): - return 32 if base_model_type == "ti2v_2_2" else 16 + return 32 if test_wan_5B(base_model_type) else 16 @staticmethod def get_rgb_factors(base_model_type ): from shared.RGB_factors import get_rgb_factors + if test_wan_5B(base_model_type): base_model_type = "ti2v_2_2" latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type) return latent_rgb_factors, latent_rgb_factors_bias @@ -283,7 +321,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ] }] - if base_model_type == "ti2v_2_2": + if test_wan_5B(base_model_type): download_def += [ { "repoId" : "DeepBeepMeep/Wan2.2", "sourceFolderList" : [""], @@ -377,8 +415,8 @@ def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ "sample_solver": "unipc", }) - if test_class_i2v(base_model_type): - ui_defaults["image_prompt_type"] = "S" + if test_class_i2v(base_model_type) and "S" in model_def["image_prompt_types_allowed"]: + ui_defaults["image_prompt_type"] = "S" if base_model_type in ["fantasy"]: ui_defaults.update({ @@ -434,10 +472,15 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "image_prompt_type": "T", }) - if base_model_type in ["recam_1.3B"]: + if base_model_type in ["recam_1.3B", "lucy_edit"]: ui_defaults.update({ "video_prompt_type": "UV", }) + elif base_model_type in 
["animate"]: + ui_defaults.update({ + "video_prompt_type": "PVBXAKI", + "mask_expand": 20, + }) if text_oneframe_overlap(base_model_type): ui_defaults["sliding_window_overlap"] = 1 diff --git a/shared/convert/convert_diffusers_to_flux.py b/shared/convert/convert_diffusers_to_flux.py new file mode 100644 index 000000000..608b1765e --- /dev/null +++ b/shared/convert/convert_diffusers_to_flux.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +""" +Convert a Flux model from Diffusers (folder or single-file) into the original +single-file Flux transformer checkpoint used by Black Forest Labs / ComfyUI. + +Input : /path/to/diffusers (root or .../transformer) OR /path/to/*.safetensors (single file) +Output : /path/to/flux1-your-model.safetensors (transformer only) + +Usage: + python diffusers_to_flux_transformer.py /path/to/diffusers /out/flux1-dev.safetensors + python diffusers_to_flux_transformer.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors + # optional quantization: + # --fp8 (float8_e4m3fn, simple) + # --fp8-scaled (scaled float8 for 2D weights; adds .scale_weight tensors) +""" + +import argparse +import json +from pathlib import Path +from collections import OrderedDict + +import torch +from safetensors import safe_open +import safetensors.torch +from tqdm import tqdm + + +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("diffusers_path", type=str, + help="Path to Diffusers checkpoint folder OR a single .safetensors file.") + ap.add_argument("output_path", type=str, + help="Output .safetensors path for the Flux transformer.") + ap.add_argument("--fp8", action="store_true", + help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).") + ap.add_argument("--fp8-scaled", action="store_true", + help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.") + return ap.parse_args() + + +# Mapping from original Flux keys -> list of Diffusers keys (per block where applicable). 
+DIFFUSERS_MAP = { + # global embeds + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + + # dual-stream (image/text) blocks + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": ["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + + # single-stream blocks + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + 
"single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + + # final + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + # these two are built from norm_out.linear.{weight,bias} by swapping [shift,scale] -> [scale,shift] + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +class DiffusersSource: + """ + Uniform interface over: + 1) Folder with index JSON + shards + 2) Folder with exactly one .safetensors (no index) + 3) Single .safetensors file + Provides .has(key), .get(key)->Tensor, .base_keys (keys with 'model.' stripped for scanning) + """ + + POSSIBLE_PREFIXES = ["", "model."] # try in this order + + def __init__(self, path: Path): + p = Path(path) + if p.is_dir(): + # use 'transformer' subfolder if present + if (p / "transformer").is_dir(): + p = p / "transformer" + self._init_from_dir(p) + elif p.is_file() and p.suffix == ".safetensors": + self._init_from_single_file(p) + else: + raise FileNotFoundError(f"Invalid path: {p}") + + # ---------- common helpers ---------- + + @staticmethod + def _strip_prefix(k: str) -> str: + return k[6:] if k.startswith("model.") else k + + def _resolve(self, want: str): + """ + Return the actual stored key matching `want` by trying known prefixes. + """ + for pref in self.POSSIBLE_PREFIXES: + k = pref + want + if k in self._all_keys: + return k + return None + + def has(self, want: str) -> bool: + return self._resolve(want) is not None + + def get(self, want: str) -> torch.Tensor: + real_key = self._resolve(want) + if real_key is None: + raise KeyError(f"Missing key: {want}") + return self._get_by_real_key(real_key).to("cpu") + + @property + def base_keys(self): + # keys without 'model.' prefix for scanning + return [self._strip_prefix(k) for k in self._all_keys] + + # ---------- modes ---------- + + def _init_from_single_file(self, file_path: Path): + self._mode = "single" + self._file = file_path + self._handle = safe_open(file_path, framework="pt", device="cpu") + self._all_keys = list(self._handle.keys()) + + def _get_by_real_key(real_key: str): + return self._handle.get_tensor(real_key) + + self._get_by_real_key = _get_by_real_key + + def _init_from_dir(self, dpath: Path): + index_json = dpath / "diffusion_pytorch_model.safetensors.index.json" + if index_json.exists(): + with open(index_json, "r", encoding="utf-8") as f: + index = json.load(f) + weight_map = index["weight_map"] # full mapping + self._mode = "sharded" + self._dpath = dpath + self._weight_map = {k: dpath / v for k, v in weight_map.items()} + self._all_keys = list(self._weight_map.keys()) + self._open_handles = {} + + def _get_by_real_key(real_key: str): + fpath = self._weight_map[real_key] + h = self._open_handles.get(fpath) + if h is None: + h = safe_open(fpath, framework="pt", device="cpu") + self._open_handles[fpath] = h + return h.get_tensor(real_key) + + self._get_by_real_key = _get_by_real_key + return + + # no index: try exactly one safetensors in folder + files = sorted(dpath.glob("*.safetensors")) + if len(files) != 1: + raise FileNotFoundError( + f"No index found and {dpath} does not contain exactly one .safetensors file." + ) + self._init_from_single_file(files[0]) + + +def main(): + args = parse_args() + src = DiffusersSource(Path(args.diffusers_path)) + + # Count blocks by scanning base keys (with any 'model.' 
prefix removed) + num_dual = 0 + num_single = 0 + for k in src.base_keys: + if k.startswith("transformer_blocks."): + try: + i = int(k.split(".")[1]) + num_dual = max(num_dual, i + 1) + except Exception: + pass + elif k.startswith("single_transformer_blocks."): + try: + i = int(k.split(".")[1]) + num_single = max(num_single, i + 1) + except Exception: + pass + print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks") + + # Swap [shift, scale] -> [scale, shift] (weights are concatenated along dim=0) + def swap_scale_shift(vec: torch.Tensor) -> torch.Tensor: + shift, scale = vec.chunk(2, dim=0) + return torch.cat([scale, shift], dim=0) + + orig = {} + + # Per-block (dual) + for b in range(num_dual): + prefix = f"transformer_blocks.{b}." + for okey, dvals in DIFFUSERS_MAP.items(): + if not okey.startswith("double_blocks."): + continue + dkeys = [prefix + v for v in dvals] + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey.replace("()", str(b))] = src.get(dkeys[0]) + else: + orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Per-block (single) + for b in range(num_single): + prefix = f"single_transformer_blocks.{b}." + for okey, dvals in DIFFUSERS_MAP.items(): + if not okey.startswith("single_blocks."): + continue + dkeys = [prefix + v for v in dvals] + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey.replace("()", str(b))] = src.get(dkeys[0]) + else: + orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Globals (non-block) + for okey, dvals in DIFFUSERS_MAP.items(): + if okey.startswith(("double_blocks.", "single_blocks.")): + continue + dkeys = dvals + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey] = src.get(dkeys[0]) + else: + orig[okey] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Fix final_layer.adaLN_modulation.1.{weight,bias} by swapping scale/shift halves + if "final_layer.adaLN_modulation.1.weight" in orig: + orig["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift( + orig["final_layer.adaLN_modulation.1.weight"] + ) + if "final_layer.adaLN_modulation.1.bias" in orig: + orig["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift( + orig["final_layer.adaLN_modulation.1.bias"] + ) + + # Optional FP8 variants (experimental; not required for ComfyUI/BFL) + if args.fp8 or args.fp8_scaled: + dtype = torch.float8_e4m3fn # noqa + minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max + + def stochastic_round_to(t): + t = t.float().clamp(minv, maxv) + lower = torch.floor(t * 256) / 256 + upper = torch.ceil(t * 256) / 256 + prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t)) + rnd = torch.rand_like(t) + out = torch.where(rnd < prob, upper, lower) + return out.to(dtype) + + def scale_to_8bit(weight, target_max=416.0): + absmax = weight.abs().max() + scale = absmax / target_max if absmax > 0 else torch.tensor(1.0) + scaled = (weight / scale).clamp(minv, maxv).to(dtype) + return scaled, scale + + scales = {} + for k in tqdm(list(orig.keys()), desc="Quantizing to fp8"): + t = orig[k] + if args.fp8: + orig[k] = stochastic_round_to(t) + else: + if k.endswith(".weight") and t.dim() == 2: + qt, s = scale_to_8bit(t) + orig[k] = qt + scales[k[:-len(".weight")] + ".scale_weight"] = s + else: + orig[k] = t.clamp(minv, maxv).to(dtype) + if args.fp8_scaled: + orig.update(scales) + orig["scaled_fp8"] = torch.tensor([], dtype=dtype) + else: + # Default: 
save in bfloat16 + for k in list(orig.keys()): + orig[k] = orig[k].to(torch.bfloat16).cpu() + + out_path = Path(args.output_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + meta = OrderedDict() + meta["format"] = "pt" + meta["modelspec.date"] = __import__("datetime").date.today().strftime("%Y-%m-%d") + print(f"Saving transformer to: {out_path}") + safetensors.torch.save_file(orig, str(out_path), metadata=meta) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/shared/inpainting/__init__.py b/shared/inpainting/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/shared/inpainting/lanpaint.py b/shared/inpainting/lanpaint.py new file mode 100644 index 000000000..3165e7b01 --- /dev/null +++ b/shared/inpainting/lanpaint.py @@ -0,0 +1,240 @@ +import torch +from .utils import * +from functools import partial + +# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/) + +def _pack_latents(latents): + batch_size, num_channels_latents, _, height, width = latents.shape + + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + +def _unpack_latents(latents, height, width, vae_scale_factor=8): + batch_size, num_patches, channels = latents.shape + + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + +class LanPaint(): + def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False): + self.n_steps = NSteps + self.chara_lamb = Lambda + self.IS_FLUX = IS_FLUX + self.IS_FLOW = IS_FLOW + self.step_size = StepSize + self.friction = Friction + self.chara_beta = Beta + self.img_dim_size = None + def add_none_dims(self, array): + # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times + index = (slice(None),) + (None,) * (self.img_dim_size-1) + return array[index] + def remove_none_dims(self, array): + # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times + index = (slice(None),) + (0,) * (self.img_dim_size-1) + return array[index] + def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8): + latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor) + noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor) + x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor) + latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor) + self.height = height + self.width = width + self.vae_scale_factor = vae_scale_factor + self.img_dim_size = len(x.shape) + self.latent_image = latent_image + self.noise = noise + if n_steps is None: + n_steps = self.n_steps + out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW) + out = _pack_latents(out) + return out + def LanPaint(self, denoise, 
cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW): + if IS_FLUX: + cfg_BIG = 1.0 + + def double_denoise(latents, t): + latents = _pack_latents(latents) + noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale) + if noise_pred == None: return None, None + predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t) + predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor) + if true_cfg_scale == cfg_BIG: + predict_big = predict_std + else: + predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t) + predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor) + return predict_std, predict_big + + if len(sigma.shape) == 0: + sigma = torch.tensor([sigma.item()]) + latent_mask = 1 - latent_mask + if IS_FLUX or IS_FLOW: + Flow_t = sigma + abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 ) + VE_Sigma = Flow_t / (1 - Flow_t) + #print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item()) + else: + VE_Sigma = sigma + abt = 1/( 1+VE_Sigma**2 ) + Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5 ) + # VE_Sigma, abt, Flow_t = current_times + current_times = (VE_Sigma, abt, Flow_t) + + step_size = self.step_size * (1 - abt) + step_size = self.add_none_dims(step_size) + # self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values + # This is the replace step + # x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask + + noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma + x = x * (1 - latent_mask) + noisy_image * latent_mask + + if IS_FLUX or IS_FLOW: + x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 ) + else: + x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values + + ############ LanPaint Iterations Start ############### + # after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region. + args = None + for i in range(n_steps): + score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise ) + if score_func is None: return None + x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args) + if IS_FLUX or IS_FLOW: + x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 ) + else: + x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values + ############ LanPaint Iterations End ############### + # out is x_0 + # out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed) + # out = out * (1-latent_mask) + self.latent_image * latent_mask + # return out + return x + + def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func): + + lamb = self.chara_lamb + if self.IS_FLUX or self.IS_FLOW: + # compute t for flow model, with a small epsilon compensating for numerical error. 
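+            # (x_t is carried in variance-preserving form through the Langevin loop;
+            # dividing by sqrt(abt) + sqrt(1-abt) undoes that scaling and recovers the
+            # flow-matching input expected by the denoiser called just below.)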
+ x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching + x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow)) + if x_0 is None: return None + else: + x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding + x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma)) + if x_0 is None: return None + + score_x = -(x_t - x_0) + score_y = - (1 + lamb) * ( x_t - y ) + lamb * (x_t - x_0_BIG) + return score_x * (1 - mask) + score_y * mask + def sigma_x(self, abt): + # the time scale for the x_t update + return abt**0 + def sigma_y(self, abt): + beta = self.chara_beta * abt ** 0 + return beta + + def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None): + # prepare the step size and time parameters + with torch.autocast(device_type=x_t.device.type, dtype=torch.float32): + step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y) + sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes + # print('mask',mask.device) + if torch.mean(dtx) <= 0.: + return x_t, args + # ------------------------------------------------------------------------- + # Compute the Langevin dynamics update in variance perserving notation + # ------------------------------------------------------------------------- + #x0 = self.x0_evalutation(x_t, score, sigma, args) + #C = abt**0.5 * x0 / (1-abt) + A = A_x * (1-mask) + A_y * mask + D = D_x * (1-mask) + D_y * mask + dt = dtx * (1-mask) + dty * mask + Gamma = Gamma_x * (1-mask) + Gamma_y * mask + + + def Coef_C(x_t): + x0 = self.x0_evalutation(x_t, score, sigma, args) + C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t + return C + def advance_time(x_t, v, dt, Gamma, A, C, D): + dtype = x_t.dtype + with torch.autocast(device_type=x_t.device.type, dtype=torch.float32): + osc = StochasticHarmonicOscillator(Gamma, A, C, D ) + x_t, v = osc.dynamics(x_t, v, dt ) + x_t = x_t.to(dtype) + v = v.to(dtype) + return x_t, v + if args is None: + #v = torch.zeros_like(x_t) + v = None + C = Coef_C(x_t) + #print(torch.squeeze(dtx), torch.squeeze(dty)) + x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D) + else: + v, C = args + + x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D) + + C_new = Coef_C(x_t) + v = v + Gamma**0.5 * ( C_new - C) *dt + + x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D) + + C = C_new + + return x_t, (v, C) + + def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y): + # ------------------------------------------------------------------------- + # Unpack current times parameters (sigma and abt) + sigma, abt, flow_t = current_times + sigma = self.add_none_dims(sigma) + abt = self.add_none_dims(abt) + # Compute time step (dtx, dty) for x and y branches. + dtx = 2 * step_size * sigma_x + dty = 2 * step_size * sigma_y + + # ------------------------------------------------------------------------- + # Define friction parameter Gamma_hat for each branch. + # Using dtx**0 provides a tensor of the proper device/dtype. + + Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0 + Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0 + #print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item()) + # adjust dt to match denoise-addnoise steps sizes + Gamma_hat_x /= 2. + Gamma_hat_y /= 2. 
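+        # Per-branch harmonic strengths: the masked (known-region) y branch is scaled by
+        # (1 + lambda), mirroring the (1 + lamb) factor applied to score_y in score_model.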
+ A_t_x = (1) / ( 1 - abt ) * dtx / 2 + A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2 + + + A_x = A_t_x / (dtx/2) + A_y = A_t_y / (dty/2) + Gamma_x = Gamma_hat_x / (dtx/2) + Gamma_y = Gamma_hat_y / (dty/2) + + #D_x = (2 * (1 + sigma**2) )**0.5 + #D_y = (2 * (1 + sigma**2) )**0.5 + D_x = (2 * abt**0 )**0.5 + D_y = (2 * abt**0 )**0.5 + return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y + + + + def x0_evalutation(self, x_t, score, sigma, args): + x0 = x_t + score(x_t) + return x0 \ No newline at end of file diff --git a/shared/inpainting/utils.py b/shared/inpainting/utils.py new file mode 100644 index 000000000..c017ab0a8 --- /dev/null +++ b/shared/inpainting/utils.py @@ -0,0 +1,301 @@ +import torch +def epxm1_x(x): + # Compute the (exp(x) - 1) / x term with a small value to avoid division by zero. + result = torch.special.expm1(x) / x + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x) < 1e-2 + result = torch.where(mask, 1 + x/2. + x**2 / 6., result) + return result +def epxm1mx_x2(x): + # Compute the (exp(x) - 1 - x) / x**2 term with a small value to avoid division by zero. + result = (torch.special.expm1(x) - x) / x**2 + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x**2) < 1e-2 + result = torch.where(mask, 1/2. + x/6 + x**2 / 24 + x**3 / 120, result) + return result + +def expm1mxmhx2_x3(x): + # Compute the (exp(x) - 1 - x - x**2 / 2) / x**3 term with a small value to avoid division by zero. + result = (torch.special.expm1(x) - x - x**2 / 2) / x**3 + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x**3) < 1e-2 + result = torch.where(mask, 1/6 + x/24 + x**2 / 120 + x**3 / 720 + x**4 / 5040, result) + return result + +def exp_1mcosh_GD(gamma_t, delta): + """ + Compute e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + # Main computation + is_positive = delta > 0 + sqrt_abs_delta = torch.sqrt(torch.abs(delta)) + gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta + numerator_pos = torch.exp(-gamma_t) - (torch.exp(gamma_t * (sqrt_abs_delta - 1)) + torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2 + numerator_neg = torch.exp(-gamma_t) * ( 1 - torch.cos(gamma_t * sqrt_abs_delta ) ) + numerator = torch.where(is_positive, numerator_pos, numerator_neg) + result = numerator / (delta * gamma_t**2 ) + # Handle NaN/inf cases + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + # Handle numerical instability for small delta + mask = torch.abs(gamma_t_sqrt_delta**2) < 5e-2 + taylor = ( -0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2 ) * torch.exp(-gamma_t) + result = torch.where(mask, taylor, result) + return result + +def exp_sinh_GsqrtD(gamma_t, delta): + """ + Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + # Main computation + is_positive = delta > 0 + sqrt_abs_delta = torch.sqrt(torch.abs(delta)) + gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta + numerator_pos = (torch.exp(gamma_t * (sqrt_abs_delta - 1)) - 
torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2 + denominator_pos = gamma_t_sqrt_delta + result_pos = numerator_pos / gamma_t_sqrt_delta + result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos)) + + # Taylor expansion for small gamma_t_sqrt_delta + mask = torch.abs(gamma_t_sqrt_delta) < 1e-2 + taylor = ( 1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2 ) * torch.exp(-gamma_t) + result_pos = torch.where(mask, taylor, result_pos) + + # Handle negative delta + result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta/torch.pi) + result = torch.where(is_positive, result_pos, result_neg) + return result + +def exp_cosh(gamma_t, delta): + """ + Compute e^(-Γt) * cosh(Γt√Δ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + exp_1mcosh_GD_result = exp_1mcosh_GD(gamma_t, delta) # e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + result = torch.exp(-gamma_t) - gamma_t**2 * delta * exp_1mcosh_GD_result + return result +def exp_sinh_sqrtD(gamma_t, delta): + """ + Compute e^(-Γt) * sinh(Γt√Δ) / √Δ + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + Returns: + Result of the computation with numerical stability handling + """ + exp_sinh_GsqrtD_result = exp_sinh_GsqrtD(gamma_t, delta) # e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ) + result = gamma_t * exp_sinh_GsqrtD_result + return result + + + +def zeta1(gamma_t, delta): + # Compute hyperbolic terms and exponential + half_gamma_t = gamma_t / 2 + exp_cosh_term = exp_cosh(half_gamma_t, delta) + exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta) + + + # Main computation + numerator = 1 - (exp_cosh_term + exp_sinh_term) + denominator = gamma_t * (1 - delta) / 4 + result = 1 - numerator / denominator + + # Handle numerical instability + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + + # Taylor expansion for small x (similar to your epxm1Dx approach) + mask = torch.abs(denominator) < 5e-3 + term1 = epxm1_x(-gamma_t) + term2 = epxm1mx_x2(-gamma_t) + term3 = expm1mxmhx2_x3(-gamma_t) + taylor = term1 + (1/2.+ term1-3*term2)*denominator + (-1/6. 
+ term1/2 - 4 * term2 + 10 * term3) * denominator**2 + result = torch.where(mask, taylor, result) + + return result + +def exp_cosh_minus_terms(gamma_t, delta): + """ + Compute E^(-tΓ) * (Cosh[tΓ] - 1 - (Cosh[tΓ√Δ] - 1)/Δ) / (tΓ(1 - Δ)) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + exp_term = torch.exp(-gamma_t) + # Compute individual terms + exp_cosh_term = exp_cosh(gamma_t, gamma_t**0) - exp_term # E^(-tΓ) (Cosh[tΓ] - 1) term + exp_cosh_delta_term = - gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ term + + #exp_1mcosh_GD e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + # Main computation + numerator = exp_cosh_term - exp_cosh_delta_term + denominator = gamma_t * (1 - delta) + + result = numerator / denominator + + # Handle numerical instability + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + + # Taylor expansion for small gamma_t and delta near 1 + mask = (torch.abs(denominator) < 1e-1) + exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, delta**0) + taylor = ( + gamma_t*exp_1mcosh_GD_term + 0.5 * gamma_t * exp_sinh_GsqrtD(gamma_t, delta**0) + - denominator / 4 * ( 0.5 * exp_cosh(gamma_t, delta**0) - 4 * exp_1mcosh_GD_term - 5 /2 * exp_sinh_GsqrtD(gamma_t, delta**0) ) + ) + result = torch.where(mask, taylor, result) + + return result + + +def zeta2(gamma_t, delta): + half_gamma_t = gamma_t / 2 + return exp_sinh_GsqrtD(half_gamma_t, delta) + +def sig11(gamma_t, delta): + return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta) + + +def Zcoefs(gamma_t, delta): + Zeta1 = zeta1(gamma_t, delta) + Zeta2 = zeta2(gamma_t, delta) + + sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8 + amplitude = torch.sqrt(sq_total) + Zcoef1 = ( gamma_t**0.5 * Zeta2 / 2 **0.5 ) / amplitude + Zcoef2 = Zcoef1 * gamma_t *( - 2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta) ) ** 0.5 + #cterm = exp_cosh_minus_terms(gamma_t, delta) + #sterm = exp_sinh_sqrtD(gamma_t, delta**0) + exp_sinh_sqrtD(gamma_t, delta) + #Zcoef3 = 2 * torch.sqrt( cterm / ( gamma_t * (1 - delta) * cterm + sterm ) ) + Zcoef3 = torch.sqrt( torch.maximum(1 - Zcoef1**2 - Zcoef2**2, sq_total.new_zeros(sq_total.shape)) ) + + return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude + +def Zcoefs_asymp(gamma_t, delta): + A_t = (gamma_t * (1 - delta) )/4 + return epxm1_x(- 2 * A_t) + +class StochasticHarmonicOscillator: + """ + Simulates a stochastic harmonic oscillator governed by the equations: + dy(t) = q(t) dt + dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt + + Also define v(t) = q(t) / √Γ, which is numerically more stable. + + Where: + y(t) - Position variable + q(t) - Velocity variable + Γ - Damping coefficient + A - Harmonic potential strength + C - Constant force term + D - Noise amplitude + dw(t) - Wiener process (Brownian motion) + """ + def __init__(self, Gamma, A, C, D): + self.Gamma = Gamma + self.A = A + self.C = C + self.D = D + self.Delta = 1 - 4 * A / Gamma + def sig11(self, gamma_t, delta): + return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta) + def sig22(self, gamma_t, delta): + return 1- zeta1(2*gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta) + def dynamics(self, y0, v0, t): + """ + Calculates the position and velocity variables at time t. 
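+
+        Samples (y(t), v(t)) from the closed-form Gaussian transition of the damped
+        oscillator: the means are built from the zeta1/zeta2 coefficients and the
+        2x2 covariance (cov_yy, cov_yv, cov_vv) is sampled via a manual Cholesky factor.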
+ + Parameters: + y0 (float): Initial position + v0 (float): Initial velocity v(0) = q(0) / √Γ + t (float): Time at which to evaluate the dynamics + Returns: + tuple: (y(t), v(t)) + """ + + dummyzero = y0.new_zeros(1) # convert scalar to tensor with same device and dtype as y0 + Delta = self.Delta + dummyzero + Gamma_hat = self.Gamma * t + dummyzero + A = self.A + dummyzero + C = self.C + dummyzero + D = self.D + dummyzero + Gamma = self.Gamma + dummyzero + zeta_1 = zeta1( Gamma_hat, Delta) + zeta_2 = zeta2( Gamma_hat, Delta) + EE = 1 - Gamma_hat * zeta_2 + + if v0 is None: + v0 = torch.randn_like(y0) * D / 2 ** 0.5 + #v0 = (C - A * y0)/Gamma**0.5 + + # Calculate mean position and velocity + term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t + y_mean = term1 + y0 + v_mean = (1 - EE)*(C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0 + + cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta) + cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2 + cov_yv = (zeta2(Gamma_hat, Delta) * Gamma_hat * D ) **2 / 2 / (Gamma ** 0.5) + + # sample new position and velocity with multivariate normal distribution + + batch_shape = y0.shape + cov_matrix = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype) + cov_matrix[..., 0, 0] = cov_yy + cov_matrix[..., 0, 1] = cov_yv + cov_matrix[..., 1, 0] = cov_yv # symmetric + cov_matrix[..., 1, 1] = cov_vv + + + + # Compute the Cholesky decomposition to get scale_tril + #scale_tril = torch.linalg.cholesky(cov_matrix) + scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype) + tol = 1e-8 + cov_yy = torch.clamp( cov_yy, min = tol ) + sd_yy = torch.sqrt( cov_yy ) + inv_sd_yy = 1/(sd_yy) + + scale_tril[..., 0, 0] = sd_yy + scale_tril[..., 0, 1] = 0. + scale_tril[..., 1, 0] = cov_yv * inv_sd_yy + scale_tril[..., 1, 1] = torch.clamp( cov_vv - cov_yv**2 / cov_yy, min = tol ) ** 0.5 + # check if it matches torch.linalg. 
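+        # (closed-form lower-triangular factor of the 2x2 covariance; clamping cov_yy and
+        #  the Schur complement cov_vv - cov_yv**2 / cov_yy with `tol` keeps the factor
+        #  well defined when the covariance is near-singular)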
+ #assert torch.allclose(torch.linalg.cholesky(cov_matrix), scale_tril, atol = 1e-4, rtol = 1e-4 ) + # Sample correlated noise from multivariate normal + mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype) + mean[..., 0] = y_mean + mean[..., 1] = v_mean + new_yv = torch.distributions.MultivariateNormal( + loc=mean, + scale_tril=scale_tril + ).sample() + + return new_yv[...,0], new_yv[...,1] \ No newline at end of file diff --git a/shared/utils/audio_video.py b/shared/utils/audio_video.py index b24530d5c..224cf3412 100644 --- a/shared/utils/audio_video.py +++ b/shared/utils/audio_video.py @@ -232,6 +232,9 @@ def save_video(tensor, retry=5): """Save tensor as video with configurable codec and container options.""" + if torch.is_tensor(tensor) and len(tensor.shape) == 4: + tensor = tensor.unsqueeze(0) + suffix = f'.{container}' cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file if not cache_file.endswith(suffix): diff --git a/shared/utils/download.py b/shared/utils/download.py new file mode 100644 index 000000000..ed035c0b7 --- /dev/null +++ b/shared/utils/download.py @@ -0,0 +1,110 @@ +import sys, time + +# Global variables to track download progress +_start_time = None +_last_time = None +_last_downloaded = 0 +_speed_history = [] +_update_interval = 0.5 # Update speed every 0.5 seconds + +def progress_hook(block_num, block_size, total_size, filename=None): + """ + Simple progress bar hook for urlretrieve + + Args: + block_num: Number of blocks downloaded so far + block_size: Size of each block in bytes + total_size: Total size of the file in bytes + filename: Name of the file being downloaded (optional) + """ + global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval + + current_time = time.time() + downloaded = block_num * block_size + + # Initialize timing on first call + if _start_time is None or block_num == 0: + _start_time = current_time + _last_time = current_time + _last_downloaded = 0 + _speed_history = [] + + # Calculate download speed only at specified intervals + speed = 0 + if current_time - _last_time >= _update_interval: + if _last_time > 0: + current_speed = (downloaded - _last_downloaded) / (current_time - _last_time) + _speed_history.append(current_speed) + # Keep only last 5 speed measurements for smoothing + if len(_speed_history) > 5: + _speed_history.pop(0) + # Average the recent speeds for smoother display + speed = sum(_speed_history) / len(_speed_history) + + _last_time = current_time + _last_downloaded = downloaded + elif _speed_history: + # Use the last calculated average speed + speed = sum(_speed_history) / len(_speed_history) + # Format file sizes and speed + def format_bytes(bytes_val): + for unit in ['B', 'KB', 'MB', 'GB']: + if bytes_val < 1024: + return f"{bytes_val:.1f}{unit}" + bytes_val /= 1024 + return f"{bytes_val:.1f}TB" + + file_display = filename if filename else "Unknown file" + + if total_size <= 0: + # If total size is unknown, show downloaded bytes + speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" + line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}" + # Clear any trailing characters by padding with spaces + sys.stdout.write(line.ljust(80)) + sys.stdout.flush() + return + + downloaded = block_num * block_size + percent = min(100, (downloaded / total_size) * 100) + + # Create progress bar (40 characters wide to leave room for other info) + bar_length = 40 + filled = int(bar_length * percent / 100) + bar = '█' * filled + '░' * 
(bar_length - filled) + + # Format file sizes and speed + def format_bytes(bytes_val): + for unit in ['B', 'KB', 'MB', 'GB']: + if bytes_val < 1024: + return f"{bytes_val:.1f}{unit}" + bytes_val /= 1024 + return f"{bytes_val:.1f}TB" + + speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" + + # Display progress with filename first + line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}" + # Clear any trailing characters by padding with spaces + sys.stdout.write(line.ljust(100)) + sys.stdout.flush() + + # Print newline when complete + if percent >= 100: + print() + +# Wrapper function to include filename in progress hook +def create_progress_hook(filename): + """Creates a progress hook with the filename included""" + global _start_time, _last_time, _last_downloaded, _speed_history + # Reset timing variables for new download + _start_time = None + _last_time = None + _last_downloaded = 0 + _speed_history = [] + + def hook(block_num, block_size, total_size): + return progress_hook(block_num, block_size, total_size, filename) + return hook + + diff --git a/shared/utils/utils.py b/shared/utils/utils.py index 17b9dde41..6e8a98bf2 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. import argparse import os import os.path as osp @@ -176,8 +175,9 @@ def remove_background(img, session=None): def convert_image_to_tensor(image): return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) -def convert_tensor_to_image(t, frame_no = -1): - t = t[:, frame_no] if frame_no >= 0 else t +def convert_tensor_to_image(t, frame_no = 0): + if len(t.shape) == 4: + t = t[:, frame_no] return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) def save_image(tensor_image, name, frame_no = -1): @@ -257,16 +257,18 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) return image, new_height, new_width -def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None ): +def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ): if rm_background: session = new_session() output_list =[] + output_mask_list =[] for i, img in enumerate(img_list): width, height = img.size + resized_mask = None if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2: - if outpainting_dims is not None: - resized_image =img + if outpainting_dims is not None and background_ref_outpainted: + resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True) elif img.size != (budget_width, budget_height): resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS) else: @@ -290,145 +292,103 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, 
alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200, - return output_list + output_mask_list.append(resized_mask) + return output_list, output_mask_list +def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False): + from shared.utils.utils import save_image + inpaint_color = canvas_tf_bg / 127.5 - 1 - - -def str2bool(v): - """ - Convert a string to a boolean. - - Supported true values: 'yes', 'true', 't', 'y', '1' - Supported false values: 'no', 'false', 'f', 'n', '0' - - Args: - v (str): String to convert. - - Returns: - bool: Converted boolean value. - - Raises: - argparse.ArgumentTypeError: If the value cannot be converted to boolean. - """ - if isinstance(v, bool): - return v - v_lower = v.lower() - if v_lower in ('yes', 'true', 't', 'y', '1'): - return True - elif v_lower in ('no', 'false', 'f', 'n', '0'): - return False + ref_width, ref_height = ref_img.size + if (ref_height, ref_width) == image_size and outpainting_dims == None: + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + canvas = torch.zeros_like(ref_img) if return_mask else None else: - raise argparse.ArgumentTypeError('Boolean value expected (True/False)') - - -import sys, time - -# Global variables to track download progress -_start_time = None -_last_time = None -_last_downloaded = 0 -_speed_history = [] -_update_interval = 0.5 # Update speed every 0.5 seconds - -def progress_hook(block_num, block_size, total_size, filename=None): - """ - Simple progress bar hook for urlretrieve - - Args: - block_num: Number of blocks downloaded so far - block_size: Size of each block in bytes - total_size: Total size of the file in bytes - filename: Name of the file being downloaded (optional) - """ - global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval - - current_time = time.time() - downloaded = block_num * block_size - - # Initialize timing on first call - if _start_time is None or block_num == 0: - _start_time = current_time - _last_time = current_time - _last_downloaded = 0 - _speed_history = [] - - # Calculate download speed only at specified intervals - speed = 0 - if current_time - _last_time >= _update_interval: - if _last_time > 0: - current_speed = (downloaded - _last_downloaded) / (current_time - _last_time) - _speed_history.append(current_speed) - # Keep only last 5 speed measurements for smoothing - if len(_speed_history) > 5: - _speed_history.pop(0) - # Average the recent speeds for smoother display - speed = sum(_speed_history) / len(_speed_history) - - _last_time = current_time - _last_downloaded = downloaded - elif _speed_history: - # Use the last calculated average speed - speed = sum(_speed_history) / len(_speed_history) - # Format file sizes and speed - def format_bytes(bytes_val): - for unit in ['B', 'KB', 'MB', 'GB']: - if bytes_val < 1024: - return f"{bytes_val:.1f}{unit}" - bytes_val /= 1024 - return f"{bytes_val:.1f}TB" - - file_display = filename if filename else "Unknown file" - - if total_size <= 0: - # If total size is unknown, show downloaded bytes - speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" - line = 
f"\r{file_display}: {format_bytes(downloaded)}{speed_str}" - # Clear any trailing characters by padding with spaces - sys.stdout.write(line.ljust(80)) - sys.stdout.flush() - return - - downloaded = block_num * block_size - percent = min(100, (downloaded / total_size) * 100) - - # Create progress bar (40 characters wide to leave room for other info) - bar_length = 40 - filled = int(bar_length * percent / 100) - bar = '█' * filled + '░' * (bar_length - filled) - - # Format file sizes and speed - def format_bytes(bytes_val): - for unit in ['B', 'KB', 'MB', 'GB']: - if bytes_val < 1024: - return f"{bytes_val:.1f}{unit}" - bytes_val /= 1024 - return f"{bytes_val:.1f}TB" - - speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" - - # Display progress with filename first - line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}" - # Clear any trailing characters by padding with spaces - sys.stdout.write(line.ljust(100)) - sys.stdout.flush() - - # Print newline when complete - if percent >= 100: - print() - -# Wrapper function to include filename in progress hook -def create_progress_hook(filename): - """Creates a progress hook with the filename included""" - global _start_time, _last_time, _last_downloaded, _speed_history - # Reset timing variables for new download - _start_time = None - _last_time = None - _last_downloaded = 0 - _speed_history = [] - - def hook(block_num, block_size, total_size): - return progress_hook(block_num, block_size, total_size, filename) - return hook + if outpainting_dims != None: + final_height, final_width = image_size + canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) + else: + canvas_height, canvas_width = image_size + if full_frame: + new_height = canvas_height + new_width = canvas_width + top = left = 0 + else: + # if fill_max and (canvas_height - new_height) < 16: + # new_height = canvas_height + # if fill_max and (canvas_width - new_width) < 16: + # new_width = canvas_width + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if outpainting_dims != None: + canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img + else: + canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, top:top + new_height, left:left + new_width] = ref_img + ref_img = canvas + canvas = None + if return_mask: + if outpainting_dims != None: + canvas = torch.ones((1, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0 + else: + canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, top:top + new_height, left:left + new_width] = 0 + canvas = canvas.to(device) + if return_image: + return convert_tensor_to_image(ref_img), canvas + + return ref_img.to(device), 
canvas + +def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ): + src_videos, src_masks = [], [] + inpaint_color = guide_inpaint_color/127.5 - 1 + prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0 + for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)): + src_video = src_mask = None + if cur_video_guide is not None: + src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w + if cur_video_mask is not None and any_mask: + src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w + if pre_video_guide is not None and not extract_guide_from_window_start: + src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1) + if any_mask: + src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1) + if src_video is None: + if any_guide_padding: + src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu") + if any_mask: + src_mask = torch.zeros_like(src_video[0:1]) + elif src_video.shape[1] < current_video_length: + if any_guide_padding: + src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1) + if cur_video_mask is not None and any_mask: + src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1) + else: + new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1 + src_video = src_video[:, :new_num_frames] + if any_mask: + src_mask = src_mask[:, :new_num_frames] + + for k, keep in enumerate(keep_video_guide_frames): + if not keep: + pos = prepend_count + k + src_video[:, pos:pos+1] = inpaint_color + src_mask[:, pos:pos+1] = 1 + + for k, frame in enumerate(inject_frames): + if frame != None: + pos = prepend_count + k + src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True) + + src_videos.append(src_video) + src_masks.append(src_mask) + return src_videos, src_masks diff --git a/wgp.py b/wgp.py index ee6b5f62e..0e6a0b20e 100644 --- a/wgp.py +++ b/wgp.py @@ -394,7 +394,7 @@ def ret(): gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long") if "F" in video_prompt_type: if len(frames_positions.strip()) > 0: - positions = frames_positions.split(" ") + positions = frames_positions.replace(","," ").split(" ") for pos_str in positions: if not pos_str in ["L", "l"] and len(pos_str)>0: if not is_integer(pos_str): @@ -2528,7 +2528,7 @@ def computeList(filename): from urllib.request import urlretrieve - from shared.utils.utils import create_progress_hook + from shared.utils.download import create_progress_hook shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", @@ -3726,6 +3726,60 @@ def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canva input_mask = 
convert_tensor_to_image(full_frame) return input_image, input_mask +def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512): + if not input_video_path or max_frames <= 0: + return None, None + pad_frames = 0 + if start_frame < 0: + pad_frames= -start_frame + max_frames += start_frame + start_frame = 0 + + any_mask = input_mask_path != None + video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps) + if len(video) == 0: return None + if any_mask: + mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps) + frame_height, frame_width, _ = video[0].shape + + num_frames = min(len(video), len(mask_video)) + if num_frames == 0: return None + video, mask_video = video[:num_frames], mask_video[:num_frames] + + from preprocessing.face_preprocessor import FaceProcessor + face_processor = FaceProcessor() + + face_list = [] + for frame_idx in range(num_frames): + frame = video[frame_idx].cpu().numpy() + # video[frame_idx] = None + if any_mask: + mask = Image.fromarray(mask_video[frame_idx].cpu().numpy()) + # mask_video[frame_idx] = None + if (frame_width, frame_height) != mask.size: + mask = mask.resize((frame_width, frame_height), resample=Image.Resampling.LANCZOS) + mask = np.array(mask) + alpha_mask = np.zeros((frame_height, frame_width, 3), dtype=np.uint8) + alpha_mask[mask > 127] = 1 + frame = frame * alpha_mask + frame = Image.fromarray(frame) + face = face_processor.process(frame, resize_to=size, face_crop_scale = 1) + face_list.append(face) + + face_processor = None + gc.collect() + torch.cuda.empty_cache() + + face_tensor= torch.tensor(np.stack(face_list, dtype= np.float32) / 127.5 - 1).permute(-1, 0, 1, 2 ) # t h w c -> c t h w + if pad_frames > 0: + face_tensor= torch.cat([face_tensor[:, -1:].expand(-1, pad_frames, -1, -1), face_tensor ], dim=2) + + if args.save_masks: + from preprocessing.dwpose.pose import save_one_video + saved_faces_frames = [np.array(face) for face in face_list ] + save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None) + return face_tensor + def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): @@ -3742,7 +3796,13 @@ def mask_to_xyxy_box(mask): box = [xmin, ymin, xmax, ymax] box = [int(x) for x in box] return box - + inpaint_color = int(inpaint_color) + pad_frames = 0 + if start_frame < 0: + pad_frames= -start_frame + max_frames += start_frame + start_frame = 0 + if not input_video_path or max_frames <= 0: return None, None any_mask = input_mask_path != None @@ -3909,6 +3969,9 @@ def prep_prephase(frame_idx): preproc_outside = None gc.collect() torch.cuda.empty_cache() + if pad_frames > 0: + masked_frames = masked_frames[0] * pad_frames + masked_frames + if any_mask: masked_frames = masks[0] * pad_frames + masks return torch.stack(masked_frames), torch.stack(masks) if any_mask else None @@ -4646,7 +4709,8 @@ def remove_temp_filenames(temp_filenames_list): current_video_length = video_length # VAE Tiling device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 - + guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5) + 
extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False) i2v = test_class_i2v(model_type) diffusion_forcing = "diffusion_forcing" in model_filename t2v = base_model_type in ["t2v"] @@ -4662,6 +4726,7 @@ def remove_temp_filenames(temp_filenames_list): multitalk = model_def.get("multitalk_class", False) standin = model_def.get("standin_class", False) infinitetalk = base_model_type in ["infinitetalk"] + animate = base_model_type in ["animate"] if "B" in audio_prompt_type or "X" in audio_prompt_type: from models.wan.multitalk.multitalk import parse_speakers_locations @@ -4822,7 +4887,6 @@ def remove_temp_filenames(temp_filenames_list): repeat_no = 0 extra_generation = 0 initial_total_windows = 0 - discard_last_frames = sliding_window_discard_last_frames default_requested_frames_to_generate = current_video_length if sliding_window: @@ -4843,7 +4907,7 @@ def remove_temp_filenames(temp_filenames_list): if repeat_no >= total_generation: break repeat_no +=1 gen["repeat_no"] = repeat_no - src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = None + src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None prefix_video = pre_video_frame = None source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before ) @@ -4899,7 +4963,7 @@ def remove_temp_filenames(temp_filenames_list): return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {} - src_ref_images = image_refs + src_ref_images, src_ref_masks = image_refs, None image_start_tensor = image_end_tensor = None if window_no == 1 and (video_source is not None or image_start is not None): if image_start is not None: @@ -4943,16 +5007,52 @@ def remove_temp_filenames(temp_filenames_list): from models.wan.multitalk.multitalk import get_window_audio_embeddings # special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding) audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length) - if vace: - video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None + + if repeat_no == 1 and window_no == 1 and image_refs is not None and len(image_refs) > 0: + frames_positions_list = [] + if frames_positions is not None and len(frames_positions)> 0: + positions = frames_positions.replace(","," ").split(" ") + cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) + last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count + joker_used = False + project_window_no = 1 + for pos in positions : + if len(pos) > 0: + if pos in ["L", "l"]: + cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length + if cur_end_pos >= last_frame_no and not joker_used: + joker_used = True + cur_end_pos = last_frame_no -1 + project_window_no += 1 + 
frames_positions_list.append(cur_end_pos) + cur_end_pos -= sliding_window_discard_last_frames + reuse_frames + else: + frames_positions_list.append(int(pos)-1 + alignment_shift) + frames_positions_list = frames_positions_list[:len(image_refs)] + nb_frames_positions = len(frames_positions_list) + if nb_frames_positions > 0: + frames_to_inject = [None] * (max(frames_positions_list) + 1) + for i, pos in enumerate(frames_positions_list): + frames_to_inject[pos] = image_refs[i] + if video_guide is not None: - keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) + keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") - keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ] - - if vace: + guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame + keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else [] + keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ] + guide_frames_extract_count = len(keep_frames_parsed) + + if "B" in video_prompt_type: + send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")]) + src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps) + if src_faces is not None and src_faces.shape[1] < current_video_length: + src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1) + + if vace or animate: + video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None context_scale = [ control_net_weight] if "V" in video_prompt_type: process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) @@ -4971,10 +5071,10 @@ def remove_temp_filenames(temp_filenames_list): if preprocess_type2 is not None: context_scale = [ control_net_weight /2, control_net_weight2 /2] send_cmd("progress", [0, get_latest_status(state, status_info)]) - inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask=="inpaint" else 127 - video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color ) + inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color + video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = 
preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color ) if preprocess_type2 != None: - video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) + video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) if video_guide_processed != None: if sample_fit_canvas != None: @@ -4985,7 +5085,37 @@ def remove_temp_filenames(temp_filenames_list): refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())] if video_mask_processed != None: refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy()) - elif ltxv: + + frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] + + if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)): + any_mask = True + any_guide_padding = model_def.get("pad_guide_video", False) + from shared.utils.utils import prepare_video_guide_and_mask + src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2], + [video_mask_processed, video_mask_processed2], + pre_video_guide, image_size, current_video_length, latent_size, + any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start, + keep_frames_parsed, frames_to_inject_parsed , outpainting_dims) + + src_video, src_video2 = src_videos + src_mask, src_mask2 = src_masks + if src_video is None: + abort = True + break + if src_faces is not None: + if src_faces.shape[1] < src_video.shape[1]: + src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1) + else: + src_faces = src_faces[:, :src_video.shape[1]] + if args.save_masks: + save_video( src_video, "masked_frames.mp4", fps) + if src_video2 is not None: + save_video( src_video2, "masked_frames2.mp4", fps) + if any_mask: + save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) + + elif ltxv: preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") status_info = "Extracting " + processes_names[preprocess_type] send_cmd("progress", [0, get_latest_status(state, status_info)]) @@ -5023,7 +5153,7 @@ def remove_temp_filenames(temp_filenames_list): sample_fit_canvas = None else: # video to video - video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), 
start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps) + video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size) if video_guide_processed is None: src_video = pre_video_guide else: @@ -5043,29 +5173,6 @@ def remove_temp_filenames(temp_filenames_list): refresh_preview["image_mask"] = new_image_mask if window_no == 1 and image_refs is not None and len(image_refs) > 0: - if repeat_no == 1: - frames_positions_list = [] - if frames_positions is not None and len(frames_positions)> 0: - positions = frames_positions.split(" ") - cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) #if reset_control_aligment else 0 - last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count - joker_used = False - project_window_no = 1 - for pos in positions : - if len(pos) > 0: - if pos in ["L", "l"]: - cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length - if cur_end_pos >= last_frame_no and not joker_used: - joker_used = True - cur_end_pos = last_frame_no -1 - project_window_no += 1 - frames_positions_list.append(cur_end_pos) - cur_end_pos -= sliding_window_discard_last_frames + reuse_frames - else: - frames_positions_list.append(int(pos)-1 + alignment_shift) - frames_positions_list = frames_positions_list[:len(image_refs)] - nb_frames_positions = len(frames_positions_list) - if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) : from shared.utils.utils import get_outpainting_full_area_dimensions w, h = image_refs[0].size @@ -5089,20 +5196,16 @@ def remove_temp_filenames(temp_filenames_list): if remove_background_images_ref > 0: send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested - image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0], + image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0], remove_background_images_ref > 0, any_background_ref, fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1, block_size=block_size, - outpainting_dims =outpainting_dims ) + outpainting_dims =outpainting_dims, + background_ref_outpainted = model_def.get("background_ref_outpainted", True) ) refresh_preview["image_refs"] = image_refs - if nb_frames_positions > 0: - frames_to_inject = [None] * (max(frames_positions_list) + 1) - for i, pos in enumerate(frames_positions_list): - frames_to_inject[pos] = image_refs[i] if vace : - frames_to_inject_parsed = frames_to_inject[guide_start_frame: guide_end_frame] image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2], @@ -5116,7 +5219,7 @@ def 
remove_temp_filenames(temp_filenames_list): any_background_ref = any_background_ref ) if len(frames_to_inject_parsed) or any_background_ref: - new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + aligned_guide_start_frame - aligned_window_start_frame) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + 0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] if any_background_ref: new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:] else: @@ -5165,10 +5268,14 @@ def set_header_text(txt): input_prompt = prompt, image_start = image_start_tensor, image_end = image_end_tensor, - input_frames = src_video, + input_frames = src_video, + input_frames2 = src_video2, input_ref_images= src_ref_images, + input_ref_masks = src_ref_masks, input_masks = src_mask, + input_masks2 = src_mask2, input_video= pre_video_guide, + input_faces = src_faces, denoising_strength=denoising_strength, prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames, frame_num= (current_video_length // latent_size)* latent_size + 1, @@ -5302,6 +5409,7 @@ def set_header_text(txt): send_cmd("output") else: sample = samples.cpu() + abort = not is_image and sample.shape[1] < current_video_length # if True: # for testing # torch.save(sample, "output.pt") # else: @@ -6980,7 +7088,7 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): old_video_prompt_type = video_prompt_type - video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUV") + video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUVB") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) visible = "V" in video_prompt_type model_type = state["model_type"] @@ -7437,7 +7545,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False) image_prompt_type_choices = [] if "T" in image_prompt_types_allowed: - image_prompt_type_choices += [("Text Prompt Only", "")] + image_prompt_type_choices += [("Text Prompt Only" if "S" in image_prompt_types_allowed else "New Video", "")] if "S" in image_prompt_types_allowed: image_prompt_type_choices += [("Start Video with Image", "S")] any_start_image = True @@ -7516,7 +7624,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image") video_prompt_type_video_guide = gr.Dropdown( guide_preprocessing_choices, - value=filter_letters(video_prompt_type_value, "PDESLCMUV", guide_preprocessing.get("default", "") ), + value=filter_letters(video_prompt_type_value, "PDESLCMUVB", guide_preprocessing.get("default", "") ), label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True, ) any_control_video = True @@ -7560,13 +7668,13 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl } mask_preprocessing_choices = [] - mask_preprocessing_labels = 
guide_preprocessing.get("labels", {}) + mask_preprocessing_labels = mask_preprocessing.get("labels", {}) for process_type in mask_preprocessing["selection"]: process_label = mask_preprocessing_labels.get(process_type, None) process_label = mask_preprocessing_labels_all.get(process_type, process_type) if process_label is None else process_label mask_preprocessing_choices.append( (process_label, process_type) ) - video_prompt_type_video_mask_label = guide_preprocessing.get("label", "Area Processed") + video_prompt_type_video_mask_label = mask_preprocessing.get("label", "Area Processed") video_prompt_type_video_mask = gr.Dropdown( mask_preprocessing_choices, value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")), @@ -7591,7 +7699,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl choices= image_ref_choices["choices"], value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]), visible = image_ref_choices.get("visible", True), - label=image_ref_choices.get("label", "Ref. Images Type"), show_label= True, scale = 2 + label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 2 ) image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None)) @@ -7634,7 +7742,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#") video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False) with gr.Group(): - video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) + video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Positioned Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row: video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")] @@ -7649,14 +7757,14 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value ) image_refs_single_image_mode = model_def.get("one_image_ref_needed", False) - image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will start a new Clip)" if infinitetalk else "") + image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else 
"Reference Images") + (" (each Image will be associated to a Sliding Window)" if infinitetalk else "") image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode) frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames (1=first, L=last of a window) no position for other Image Refs)" ) image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs) no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None - background_removal_label = model_def.get("background_removal_label", "Remove Backgrounds behind People / Objects") + background_removal_label = model_def.get("background_removal_label", "Remove Background behind People / Objects") remove_background_images_ref = gr.Dropdown( choices=[ @@ -7664,7 +7772,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl (background_removal_label, 1), ], value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1), - label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal + label="Automatic Removal of Background behind People or Objects in Reference Images", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal ) any_audio_voices_support = any_audio_track(base_model_type) @@ -8084,7 +8192,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai ("Aligned to the beginning of the First Window of the new Video Sample", "T"), ], value=filter_letters(video_prompt_type_value, "T"), - label="Control Video / Injected Frames / Control Audio temporal alignment when any Video to continue", + label="Control Video / Control Audio / Positioned Frames Temporal Alignment when any Video to continue", visible = vace or ltxv or t2v or infinitetalk ) From e28c95ae912addee1abdb1b798bd0f554cd658cb Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 23 Sep 2025 23:04:44 +0200 Subject: [PATCH 030/155] Vace Contenders are in Town --- README.md | 8 + defaults/lucy_edit.json | 3 +- defaults/lucy_edit_fastwan.json | 16 ++ models/flux/flux_handler.py | 2 +- models/flux/flux_main.py | 14 +- models/hyvideo/hunyuan.py | 14 +- models/ltx_video/ltxv_handler.py | 2 + models/qwen/qwen_handler.py | 2 +- models/qwen/qwen_main.py | 12 +- models/wan/animate/animate_utils.py | 143 ++++++++++ models/wan/animate/face_blocks.py | 382 +++++++++++++++++++++++++ models/wan/animate/model_animate.py | 31 ++ models/wan/animate/motion_encoder.py | 308 ++++++++++++++++++++ models/wan/any2video.py | 153 +++------- models/wan/df_handler.py | 1 + models/wan/wan_handler.py | 8 +- shared/utils/utils.py | 102 ++++--- wgp.py | 406 ++++++++++----------------- 18 files changed, 1160 insertions(+), 447 deletions(-) create mode 100644 defaults/lucy_edit_fastwan.json create mode 100644 models/wan/animate/animate_utils.py create mode 100644 models/wan/animate/face_blocks.py create mode 100644 models/wan/animate/model_animate.py create mode 100644 models/wan/animate/motion_encoder.py 
diff --git a/README.md b/README.md
index 8c5e4367a..6c338ac7c 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,14 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models
 **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
 
 ## 🔥 Latest Updates :
+### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena!
+
+So in today's release you will find two Vace wannabes; each covers only a subset of Vace features but offers some interesting advantages:
+- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model either to *Replace* a person in a Video or to *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone*?). By default it will keep the original soundtrack. Under the hood, *Wan 2.2 Animate* seems to be a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusive, you will find support for *Outpainting*.
+
+- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in clothes changing) and voila! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package.
+
+
 ### September 15 2025: WanGP v8.6 - Attack of the Clones
 
 - The long awaited **Vace for Wan 2.2** is at last here or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fan Cocktail**)
diff --git a/defaults/lucy_edit.json b/defaults/lucy_edit.json
index 6344dfffd..a8f67ad64 100644
--- a/defaults/lucy_edit.json
+++ b/defaults/lucy_edit.json
@@ -2,7 +2,7 @@
     "model": {
         "name": "Wan2.2 Lucy Edit 5B",
         "architecture": "lucy_edit",
-        "description": "Lucy Edit Dev is a video editing model that performs instruction-guided edits on videos using free-text prompts \u2014 it supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
+        "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.",
         "URLs": [
             "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors",
             "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors",
@@ -10,6 +10,7 @@
         ],
         "group": "wan2_2"
     },
+    "prompt": "change the clothes to red",
     "video_length": 81,
     "guidance_scale": 5,
     "flow_shift": 5,
diff --git a/defaults/lucy_edit_fastwan.json b/defaults/lucy_edit_fastwan.json
new file mode 100644
index 000000000..c67c795d7
--- /dev/null
+++ b/defaults/lucy_edit_fastwan.json
@@ -0,0 +1,16 @@
+{
+    "model": {
+        "name": "Wan2.2 FastWan Lucy Edit 5B",
+        "architecture": "lucy_edit",
+        "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts.
It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.", + "URLs": "lucy_edit", + "group": "wan2_2", + "loras": "ti2v_2_2" + }, + "prompt": "change the clothes to red", + "video_length": 81, + "guidance_scale": 1, + "flow_shift": 3, + "num_inference_steps": 5, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py index d16888122..83de7c3fc 100644 --- a/models/flux/flux_handler.py +++ b/models/flux/flux_handler.py @@ -56,7 +56,7 @@ def query_model_def(base_model_type, model_def): } - extra_model_def["lock_image_refs_ratios"] = True + extra_model_def["fit_into_canvas_image_refs"] = 0 return extra_model_def diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py index 686371129..6746a2375 100644 --- a/models/flux/flux_main.py +++ b/models/flux/flux_main.py @@ -142,8 +142,8 @@ def generate( n_prompt: str = None, sampling_steps: int = 20, input_ref_images = None, - image_guide= None, - image_mask= None, + input_frames= None, + input_masks= None, width= 832, height=480, embedded_guidance_scale: float = 2.5, @@ -197,10 +197,12 @@ def generate( for new_img in input_ref_images[1:]: stiched = stitch_images(stiched, new_img) input_ref_images = [stiched] - elif image_guide is not None: - input_ref_images = [image_guide] + elif input_frames is not None: + input_ref_images = [convert_tensor_to_image(input_frames) ] else: input_ref_images = None + image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True) + if self.name in ['flux-dev-uso', 'flux-dev-umo'] : inp, height, width = prepare_multi_ip( @@ -253,8 +255,8 @@ def unpack_latent(x): if image_mask is not None: from shared.utils.utils import convert_image_to_tensor img_msk_rebuilt = inp["img_msk_rebuilt"] - img= convert_image_to_tensor(image_guide) - x = img.squeeze(2) * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt + img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide) + x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt x = x.clamp(-1, 1) x = x.transpose(0, 1) diff --git a/models/hyvideo/hunyuan.py b/models/hyvideo/hunyuan.py index 181a9a774..aa6c3b3e6 100644 --- a/models/hyvideo/hunyuan.py +++ b/models/hyvideo/hunyuan.py @@ -865,7 +865,7 @@ def generate( freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict) else: if input_frames != None: - target_height, target_width = input_frames.shape[-3:-1] + target_height, target_width = input_frames.shape[-2:] elif input_video != None: target_height, target_width = input_video.shape[-2:] @@ -894,9 +894,10 @@ def generate( pixel_value_bg = input_video.unsqueeze(0) pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0) if input_frames != None: - pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float() - pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float() - pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.) + pixel_value_video_bg = input_frames.unsqueeze(0) #.permute(-1,0,1,2).unsqueeze(0).float() + # pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.) 
+ # pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float() + pixel_value_video_mask = input_masks.repeat(3,1,1,1).unsqueeze(0) if input_video != None: pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2) pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2) @@ -908,10 +909,11 @@ def generate( if pixel_value_bg.shape[2] < frame_num: padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:]) pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2) - pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) + # pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) + pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 1, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample() - pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.) + pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.) # unmasked pixels is -1 (no 0 as usual) and masked is 1 mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample() bg_latents = torch.cat([bg_latents, mask_latents], dim=1) bg_latents.mul_(self.vae.config.scaling_factor) diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py index 2845fdb0b..8c322e153 100644 --- a/models/ltx_video/ltxv_handler.py +++ b/models/ltx_video/ltxv_handler.py @@ -35,6 +35,8 @@ def query_model_def(base_model_type, model_def): "selection": ["", "A", "NA", "XA", "XNA"], } + extra_model_def["extra_control_frames"] = 1 + extra_model_def["dont_cat_preguide"]= True return extra_model_def @staticmethod diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py index 010298ead..cc6a76484 100644 --- a/models/qwen/qwen_handler.py +++ b/models/qwen/qwen_handler.py @@ -17,7 +17,7 @@ def query_model_def(base_model_type, model_def): ("Default", "default"), ("Lightning", "lightning")], "guidance_max_phases" : 1, - "lock_image_refs_ratios": True, + "fit_into_canvas_image_refs": 0, } if base_model_type in ["qwen_image_edit_20B"]: diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py index cec0b8e05..abd5c5cfe 100644 --- a/models/qwen/qwen_main.py +++ b/models/qwen/qwen_main.py @@ -17,7 +17,7 @@ from diffusers import FlowMatchEulerDiscreteScheduler from .pipeline_qwenimage import QwenImagePipeline from PIL import Image -from shared.utils.utils import calculate_new_dimensions +from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image def stitch_images(img1, img2): # Resize img2 to match img1's height @@ -103,8 +103,8 @@ def generate( n_prompt = None, sampling_steps: int = 20, input_ref_images = None, - image_guide= None, - image_mask= None, + input_frames= None, + input_masks= None, width= 832, height=480, guide_scale: float = 4, @@ -179,8 +179,10 @@ def generate( if n_prompt is None or len(n_prompt) == 0: n_prompt= "text, watermark, copyright, blurry, low resolution" - if image_guide is not None: - input_ref_images = [image_guide] + + image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True) + if input_frames is not None: + 
input_ref_images = [convert_tensor_to_image(input_frames) ] elif input_ref_images is not None: # image stiching method stiched = input_ref_images[0] diff --git a/models/wan/animate/animate_utils.py b/models/wan/animate/animate_utils.py new file mode 100644 index 000000000..9474dce3e --- /dev/null +++ b/models/wan/animate/animate_utils.py @@ -0,0 +1,143 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import torch +import numbers +from peft import LoraConfig + + +def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"): + target_modules = [] + for name, module in transformer.named_modules(): + if "blocks" in name and "face" not in name and "modulation" not in name and isinstance(module, torch.nn.Linear): + target_modules.append(name) + + transformer_lora_config = LoraConfig( + r=rank, + lora_alpha=alpha, + init_lora_weights=init_lora_weights, + target_modules=target_modules, + ) + return transformer_lora_config + + + +class TensorList(object): + + def __init__(self, tensors): + """ + tensors: a list of torch.Tensor objects. No need to have uniform shape. + """ + assert isinstance(tensors, (list, tuple)) + assert all(isinstance(u, torch.Tensor) for u in tensors) + assert len(set([u.ndim for u in tensors])) == 1 + assert len(set([u.dtype for u in tensors])) == 1 + assert len(set([u.device for u in tensors])) == 1 + self.tensors = tensors + + def to(self, *args, **kwargs): + return TensorList([u.to(*args, **kwargs) for u in self.tensors]) + + def size(self, dim): + assert dim == 0, 'only support get the 0th size' + return len(self.tensors) + + def pow(self, *args, **kwargs): + return TensorList([u.pow(*args, **kwargs) for u in self.tensors]) + + def squeeze(self, dim): + assert dim != 0 + if dim > 0: + dim -= 1 + return TensorList([u.squeeze(dim) for u in self.tensors]) + + def type(self, *args, **kwargs): + return TensorList([u.type(*args, **kwargs) for u in self.tensors]) + + def type_as(self, other): + assert isinstance(other, (torch.Tensor, TensorList)) + if isinstance(other, torch.Tensor): + return TensorList([u.type_as(other) for u in self.tensors]) + else: + return TensorList([u.type(other.dtype) for u in self.tensors]) + + @property + def dtype(self): + return self.tensors[0].dtype + + @property + def device(self): + return self.tensors[0].device + + @property + def ndim(self): + return 1 + self.tensors[0].ndim + + def __getitem__(self, index): + return self.tensors[index] + + def __len__(self): + return len(self.tensors) + + def __add__(self, other): + return self._apply(other, lambda u, v: u + v) + + def __radd__(self, other): + return self._apply(other, lambda u, v: v + u) + + def __sub__(self, other): + return self._apply(other, lambda u, v: u - v) + + def __rsub__(self, other): + return self._apply(other, lambda u, v: v - u) + + def __mul__(self, other): + return self._apply(other, lambda u, v: u * v) + + def __rmul__(self, other): + return self._apply(other, lambda u, v: v * u) + + def __floordiv__(self, other): + return self._apply(other, lambda u, v: u // v) + + def __truediv__(self, other): + return self._apply(other, lambda u, v: u / v) + + def __rfloordiv__(self, other): + return self._apply(other, lambda u, v: v // u) + + def __rtruediv__(self, other): + return self._apply(other, lambda u, v: v / u) + + def __pow__(self, other): + return self._apply(other, lambda u, v: u ** v) + + def __rpow__(self, other): + return self._apply(other, lambda u, v: v ** u) + + def __neg__(self): + return TensorList([-u for u in 
self.tensors]) + + def __iter__(self): + for tensor in self.tensors: + yield tensor + + def __repr__(self): + return 'TensorList: \n' + repr(self.tensors) + + def _apply(self, other, op): + if isinstance(other, (list, tuple, TensorList)) or ( + isinstance(other, torch.Tensor) and ( + other.numel() > 1 or other.ndim > 1 + ) + ): + assert len(other) == len(self.tensors) + return TensorList([op(u, v) for u, v in zip(self.tensors, other)]) + elif isinstance(other, numbers.Number) or ( + isinstance(other, torch.Tensor) and ( + other.numel() == 1 and other.ndim <= 1 + ) + ): + return TensorList([op(u, other) for u in self.tensors]) + else: + raise TypeError( + f'unsupported operand for *: "TensorList" and "{type(other)}"' + ) \ No newline at end of file diff --git a/models/wan/animate/face_blocks.py b/models/wan/animate/face_blocks.py new file mode 100644 index 000000000..8ddb829fa --- /dev/null +++ b/models/wan/animate/face_blocks.py @@ -0,0 +1,382 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from torch import nn +import torch +from typing import Tuple, Optional +from einops import rearrange +import torch.nn.functional as F +import math +from shared.attention import pay_attention + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="torch", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into kv. + max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. 
+ + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + + elif mode == "flash": + x = flash_attn_func( + q, + k, + v, + ) + x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] + elif mode == "vanilla": + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert attn_mask is None, "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = 
torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, + device=None, + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + 
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) N H D") + v = rearrange(v, "B L N H D -> (B L) N H D") + + if use_context_parallel: + q = gather_forward(q, dim=1) + + q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp) + # Compute attention. + # Size([batches, tokens, heads, head_features]) + qkv_list = [q, k, v] + del q,k,v + attn = pay_attention(qkv_list) + # attn = attention( + # q, + # k, + # v, + # max_seqlen_q=q.shape[1], + # batch_size=q.shape[0], + # ) + + attn = attn.reshape(*attn.shape[:2], -1) + attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp) + # if use_context_parallel: + # attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()] + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output \ No newline at end of file diff --git a/models/wan/animate/model_animate.py b/models/wan/animate/model_animate.py new file mode 100644 index 000000000..d07f762ee --- /dev/null +++ b/models/wan/animate/model_animate.py @@ -0,0 +1,31 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +import types +from copy import deepcopy +from einops import rearrange +from typing import List +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values): + pose_latents = self.pose_patch_embedding(pose_latents) + x[:, :, 1:] += pose_latents + + b,c,T,h,w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + return x, motion_vec diff --git a/models/wan/animate/motion_encoder.py b/models/wan/animate/motion_encoder.py new file mode 100644 index 000000000..02b00408c --- /dev/null +++ b/models/wan/animate/motion_encoder.py @@ -0,0 +1,308 @@ +# Modified from ``https://github.com/wyhsirius/LIA`` +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
+import torch +import torch.nn as nn +from torch.nn import functional as F +import math + +def custom_qr(input_tensor): + original_dtype = input_tensor.dtype + if original_dtype in [torch.bfloat16, torch.float16]: + q, r = torch.linalg.qr(input_tensor.to(torch.float32)) + return q.to(original_dtype), r.to(original_dtype) + return torch.linalg.qr(input_tensor) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + + return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + 
super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})') + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, + bias=bias and not activate)) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class EncoderApp(nn.Module): + def __init__(self, size, w_dim=512): + super(EncoderApp, self).__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256, + 128: 128, + 256: 64, + 512: 32, + 1024: 16 + } + + self.w_dim = w_dim + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False)) + + def forward(self, x): + + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:] + + +class Encoder(nn.Module): + def __init__(self, size, dim=512, dim_motion=20): + super(Encoder, self).__init__() + + # appearance netmork + self.net_app = EncoderApp(size, dim) + + # motion network + fc = [EqualLinear(dim, dim)] + for i in range(3): + fc.append(EqualLinear(dim, dim)) + + fc.append(EqualLinear(dim, dim_motion)) + self.fc = nn.Sequential(*fc) + + def enc_app(self, x): + h_source = self.net_app(x) + return h_source + + def enc_motion(self, x): + h, _ = self.net_app(x) + h_motion = self.fc(h) + return h_motion + + +class Direction(nn.Module): + def __init__(self, motion_dim): + super(Direction, self).__init__() + self.weight = nn.Parameter(torch.randn(512, motion_dim)) + + def forward(self, input): + + weight = self.weight + 1e-8 + Q, R = custom_qr(weight) + if input is 
None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class Synthesis(nn.Module): + def __init__(self, motion_dim): + super(Synthesis, self).__init__() + self.direction = Direction(motion_dim) + + +class Generator(nn.Module): + def __init__(self, size, style_dim=512, motion_dim=20): + super().__init__() + + self.enc = Encoder(size, style_dim, motion_dim) + self.dec = Synthesis(motion_dim) + + def get_motion(self, img): + #motion_feat = self.enc.enc_motion(img) + # motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) + with torch.cuda.amp.autocast(dtype=torch.float32): + motion_feat = self.enc.enc_motion(img) + motion = self.dec.direction(motion_feat) + return motion \ No newline at end of file diff --git a/models/wan/any2video.py b/models/wan/any2video.py index c03752b58..41d6d6316 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -203,10 +203,7 @@ def __init__( self.use_timestep_transform = True def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None): - if ref_images is None: - ref_images = [None] * len(frames) - else: - assert len(frames) == len(ref_images) + ref_images = [ref_images] * len(frames) if masks is None: latents = self.vae.encode(frames, tile_size = tile_size) @@ -238,11 +235,7 @@ def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, over return cat_latents def vace_encode_masks(self, masks, ref_images=None): - if ref_images is None: - ref_images = [None] * len(masks) - else: - assert len(masks) == len(ref_images) - + ref_images = [ref_images] * len(masks) result_masks = [] for mask, refs in zip(masks, ref_images): c, depth, height, width = mask.shape @@ -270,79 +263,6 @@ def vace_encode_masks(self, masks, ref_images=None): result_masks.append(mask) return result_masks - def vace_latent(self, z, m): - return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - - - def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False): - image_sizes = [] - trim_video_guide = len(keep_video_guide_frames) - def conv_tensor(t, device): - return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device) - - for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)): - prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1] - num_frames = total_frames - prepend_count - num_frames = min(num_frames, trim_video_guide) if trim_video_guide > 0 and sub_src_video != None else num_frames - if sub_src_mask is not None and sub_src_video is not None: - src_video[i] = conv_tensor(sub_src_video[:num_frames], device) - src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device) - # src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255]) - # src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255]) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], 
src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.clamp((src_mask[i][:, :, :, :] + 1) / 2, min=0, max=1) - image_sizes.append(src_video[i].shape[2:]) - elif sub_src_video is None: - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1) - else: - src_video[i] = torch.zeros((3, total_frames, image_size[0], image_size[1]), device=device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - image_sizes.append(image_size) - else: - src_video[i] = conv_tensor(sub_src_video[:num_frames], device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - image_sizes.append(src_video[i].shape[2:]) - for k, keep in enumerate(keep_video_guide_frames): - if not keep: - pos = prepend_count + k - src_video[i][:, pos:pos+1] = 0 - src_mask[i][:, pos:pos+1] = 1 - - for k, frame in enumerate(inject_frames): - if frame != None: - pos = prepend_count + k - src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True) - - - self.background_mask = None - for i, ref_images in enumerate(src_ref_images): - if ref_images is not None: - image_size = image_sizes[i] - for j, ref_img in enumerate(ref_images): - if ref_img is not None and not torch.is_tensor(ref_img): - if j==0 and any_background_ref: - if self.background_mask == None: self.background_mask = [None] * len(src_ref_images) - src_ref_images[i][j], self.background_mask[i] = fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True) - else: - src_ref_images[i][j], _ = fit_image_into_canvas(ref_img, image_size, 1, device) - if self.background_mask != None: - self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref - return src_video, src_mask, src_ref_images def get_vae_latents(self, ref_images, device, tile_size= 0): ref_vae_latents = [] @@ -369,7 +289,9 @@ def get_i2v_mask(self, lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=No def generate(self, input_prompt, input_frames= None, + input_frames2= None, input_masks = None, + input_masks2 = None, input_ref_images = None, input_ref_masks = None, input_faces = None, @@ -615,21 +537,22 @@ def generate(self, pose_pixels = input_frames * input_masks input_masks = 1. 
- input_masks pose_pixels -= input_masks - save_video(pose_pixels, "pose.mp4") pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0) input_frames = input_frames * input_masks if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames if prefix_frames_count > 0: input_frames[:, :prefix_frames_count] = input_video input_masks[:, :prefix_frames_count] = 1 - save_video(input_frames, "input_frames.mp4") - save_video(input_masks, "input_masks.mp4", value_range=(0,1)) + # save_video(pose_pixels, "pose.mp4") + # save_video(input_frames, "input_frames.mp4") + # save_video(input_masks, "input_masks.mp4", value_range=(0,1)) lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device) msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device) msk = torch.concat([msk_ref, msk_control], dim=1) - clip_image_start = image_ref = convert_image_to_tensor(input_ref_images[0]).to(self.device) - lat_y = torch.concat(self.vae.encode([image_ref.unsqueeze(1).to(self.device), input_frames.to(self.device)], VAE_tile_size), dim=1) + image_ref = input_ref_images[0].to(self.device) + clip_image_start = image_ref.squeeze(1) + lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1) y = torch.concat([msk, lat_y]) kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)}) lat_y = msk = msk_control = msk_ref = pose_pixels = None @@ -701,12 +624,11 @@ def generate(self, # Phantom if phantom: - input_ref_images_neg = None - if input_ref_images != None: # Phantom Ref images - input_ref_images = self.get_vae_latents(input_ref_images, self.device) - input_ref_images_neg = torch.zeros_like(input_ref_images) - ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0 - trim_frames = input_ref_images.shape[1] + lat_input_ref_images_neg = None + if input_ref_images is not None: # Phantom Ref images + lat_input_ref_images = self.get_vae_latents(input_ref_images, self.device) + lat_input_ref_images_neg = torch.zeros_like(lat_input_ref_images) + ref_images_count = trim_frames = lat_input_ref_images.shape[1] if ti2v: if input_video is None: @@ -721,25 +643,23 @@ def generate(self, # Vace if vace : # vace context encode - input_frames = [u.to(self.device) for u in input_frames] - input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] - input_masks = [u.to(self.device) for u in input_masks] + input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)]) + input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)]) + input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images] + input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks] ref_images_before = True - if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask] z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents ) m0 = self.vace_encode_masks(input_masks, input_ref_images) - if self.background_mask != None: - color_reference_frame = 
input_ref_images[0][0].clone() - zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size ) - mbg = self.vace_encode_masks(self.background_mask, None) + if input_ref_masks is not None and len(input_ref_masks) > 0 and input_ref_masks[0] is not None: + color_reference_frame = input_ref_images[0].clone() + zbg = self.vace_encode_frames( input_ref_images[:1] * len(input_frames), None, masks=input_ref_masks[0], tile_size = VAE_tile_size ) + mbg = self.vace_encode_masks(input_ref_masks[:1] * len(input_frames), None) for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg): zz0[:, 0:1] = zzbg mm0[:, 0:1] = mmbg - - self.background_mask = zz0 = mm0 = zzbg = mmbg = None - z = self.vace_latent(z0, m0) - - ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0 + zz0 = mm0 = zzbg = mmbg = None + z = [torch.cat([zz, mm], dim=0) for zz, mm in zip(z0, m0)] + ref_images_count = len(input_ref_images) if input_ref_images is not None and input_ref_images is not None else 0 context_scale = context_scale if context_scale != None else [1.0] * len(z) kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count }) if overlapped_latents != None : @@ -747,15 +667,8 @@ def generate(self, extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0) if prefix_frames_count > 0: color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone() - - target_shape = list(z0[0].shape) - target_shape[0] = int(target_shape[0] / 2) - lat_h, lat_w = target_shape[-2:] - height = self.vae_stride[1] * lat_h - width = self.vae_stride[2] * lat_w - - else: - target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2]) + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, lat_h, lat_w) if multitalk: if audio_proj is None: @@ -860,7 +773,9 @@ def clear(): apg_norm_threshold = 55 text_momentumbuffer = MomentumBuffer(apg_momentum) audio_momentumbuffer = MomentumBuffer(apg_momentum) - + input_frames = input_frames2 = input_masks =input_masks2 = input_video = input_ref_images = input_ref_masks = pre_video_frame = None + gc.collect() + torch.cuda.empty_cache() # denoising trans = self.model @@ -878,7 +793,7 @@ def clear(): kwargs.update({"t": timestep, "current_step": start_step_no + i}) kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None - if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step: + if denoising_strength < 1 and i <= injection_denoising_step: sigma = t / 1000 noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) if inject_from_start: @@ -912,8 +827,8 @@ def clear(): any_guidance = guide_scale != 1 if phantom: gen_args = { - "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + - [ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]), + "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + + [ torch.cat([latent_model_input[:,:, :-ref_images_count], 
lat_input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]), "context": [context, context_null, context_null] , } elif fantasy: diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py index 7a5d3ea2d..31d5da6c8 100644 --- a/models/wan/df_handler.py +++ b/models/wan/df_handler.py @@ -21,6 +21,7 @@ def query_model_def(base_model_type, model_def): extra_model_def["fps"] =fps extra_model_def["frames_minimum"] = 17 extra_model_def["frames_steps"] = 20 + extra_model_def["latent_size"] = 4 extra_model_def["sliding_window"] = True extra_model_def["skip_layer_guidance"] = True extra_model_def["tea_cache"] = True diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index ff1570f3a..f2a8fb459 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -114,7 +114,6 @@ def query_model_def(base_model_type, model_def): "tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels), "mag_cache" : True, "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"], - "convert_image_guide_to_video" : True, "sample_solvers":[ ("unipc", "unipc"), ("euler", "euler"), @@ -175,6 +174,8 @@ def query_model_def(base_model_type, model_def): extra_model_def["forced_guide_mask_inputs"] = True extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)" extra_model_def["background_ref_outpainted"] = False + extra_model_def["return_image_refs_tensor"] = True + extra_model_def["guide_inpaint_color"] = 0 @@ -196,15 +197,15 @@ def query_model_def(base_model_type, model_def): "letters_filter": "KFI", } - extra_model_def["lock_image_refs_ratios"] = True extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames" extra_model_def["video_guide_outpainting"] = [0,1] extra_model_def["pad_guide_video"] = True extra_model_def["guide_inpaint_color"] = 127.5 extra_model_def["forced_guide_mask_inputs"] = True + extra_model_def["return_image_refs_tensor"] = True if base_model_type in ["standin"]: - extra_model_def["lock_image_refs_ratios"] = True + extra_model_def["fit_into_canvas_image_refs"] = 0 extra_model_def["image_ref_choices"] = { "choices": [ ("No Reference Image", ""), @@ -480,6 +481,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ "video_prompt_type": "PVBXAKI", "mask_expand": 20, + "audio_prompt_type_value": "R", }) if text_oneframe_overlap(base_model_type): diff --git a/shared/utils/utils.py b/shared/utils/utils.py index 6e8a98bf2..bb2d5ffb6 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -32,6 +32,14 @@ def seed_everything(seed: int): if torch.backends.mps.is_available(): torch.mps.manual_seed(seed) +def has_video_file_extension(filename): + extension = os.path.splitext(filename)[-1].lower() + return extension in [".mp4"] + +def has_image_file_extension(filename): + extension = os.path.splitext(filename)[-1].lower() + return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"] + def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ): import math @@ -94,7 +102,7 @@ def get_video_info(video_path): return fps, width, height, frame_count -def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor: +def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, 
target_fps = None, return_PIL = True) -> torch.Tensor: """Extract nth frame from video as PyTorch tensor normalized to [-1, 1].""" cap = cv2.VideoCapture(file_name) @@ -102,7 +110,10 @@ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool raise ValueError(f"Cannot open video: {file_name}") total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - + fps = round(cap.get(cv2.CAP_PROP_FPS)) + if target_fps is not None: + frame_no = round(target_fps * frame_no /fps) + # Handle out of bounds if frame_no >= total_frames or frame_no < 0: if return_last_if_missing: @@ -175,10 +186,15 @@ def remove_background(img, session=None): def convert_image_to_tensor(image): return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) -def convert_tensor_to_image(t, frame_no = 0): +def convert_tensor_to_image(t, frame_no = 0, mask_levels = False): if len(t.shape) == 4: t = t[:, frame_no] - return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) + if t.shape[0]== 1: + t = t.expand(3,-1,-1) + if mask_levels: + return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy()) + else: + return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) def save_image(tensor_image, name, frame_no = -1): convert_tensor_to_image(tensor_image, frame_no).save(name) @@ -257,7 +273,7 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) return image, new_height, new_width -def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5 ): +def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ): if rm_background: session = new_session() @@ -266,7 +282,7 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg for i, img in enumerate(img_list): width, height = img.size resized_mask = None - if fit_into_canvas == None or any_background_ref == 1 and i==0 or any_background_ref == 2: + if any_background_ref == 1 and i==0 or any_background_ref == 2: if outpainting_dims is not None and background_ref_outpainted: resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True) elif img.size != (budget_width, budget_height): @@ -291,7 +307,10 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) : # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') - output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200, + if return_tensor: + 
output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1)) + else: + output_list.append(resized_image) output_mask_list.append(resized_mask) return output_list, output_mask_list @@ -346,47 +365,46 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu return ref_img.to(device), canvas -def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, extract_guide_from_window_start = False, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None ): +def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"): src_videos, src_masks = [], [] - inpaint_color = guide_inpaint_color/127.5 - 1 - prepend_count = pre_video_guide.shape[1] if not extract_guide_from_window_start and pre_video_guide is not None else 0 + inpaint_color_compressed = guide_inpaint_color/127.5 - 1 + prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0 for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)): - src_video = src_mask = None - if cur_video_guide is not None: - src_video = cur_video_guide.permute(3, 0, 1, 2).float().div_(127.5).sub_(1.) # c, f, h, w - if cur_video_mask is not None and any_mask: - src_mask = cur_video_mask.permute(3, 0, 1, 2).float().div_(255)[0:1] # c, f, h, w - if pre_video_guide is not None and not extract_guide_from_window_start: + src_video, src_mask = cur_video_guide, cur_video_mask + if pre_video_guide is not None: src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1) if any_mask: src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1) - if src_video is None: - if any_guide_padding: - src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color, dtype = torch.float, device= "cpu") - if any_mask: - src_mask = torch.zeros_like(src_video[0:1]) - elif src_video.shape[1] < current_video_length: - if any_guide_padding: - src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color, dtype = src_video.dtype, device= src_video.device) ], dim=1) - if cur_video_mask is not None and any_mask: - src_mask = torch.cat([src_mask, torch.full( (1, current_video_length - src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1) - else: - new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1 - src_video = src_video[:, :new_num_frames] - if any_mask: - src_mask = src_mask[:, :new_num_frames] - - for k, keep in enumerate(keep_video_guide_frames): - if not keep: - pos = prepend_count + k - src_video[:, pos:pos+1] = inpaint_color - src_mask[:, pos:pos+1] = 1 - - for k, frame in enumerate(inject_frames): - if frame != None: - pos = prepend_count + k - src_video[:, pos:pos+1], src_mask[:, pos:pos+1] = fit_image_into_canvas(frame, image_size, inpaint_color, device, True, outpainting_dims, return_mask= True) + if any_guide_padding: + if src_video is None: + src_video = torch.full( (3, current_video_length, *image_size ), 
inpaint_color_compressed, dtype = torch.float, device= device) + elif src_video.shape[1] < current_video_length: + src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1) + elif src_video is not None: + new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1 + src_video = src_video[:, :new_num_frames] + + if any_mask and src_video is not None: + if src_mask is None: + src_mask = torch.ones_like(src_video[:1]) + elif src_mask.shape[1] < src_video.shape[1]: + src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1) + else: + src_mask = src_mask[:, :src_video.shape[1]] + + if src_video is not None : + for k, keep in enumerate(keep_video_guide_frames): + if not keep: + pos = prepend_count + k + src_video[:, pos:pos+1] = inpaint_color_compressed + if any_mask: src_mask[:, pos:pos+1] = 1 + + for k, frame in enumerate(inject_frames): + if frame != None: + pos = prepend_count + k + src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask) + if any_mask: src_mask[:, pos:pos+1] = msk src_videos.append(src_video) src_masks.append(src_mask) return src_videos, src_masks diff --git a/wgp.py b/wgp.py index 0e6a0b20e..6a7bca122 100644 --- a/wgp.py +++ b/wgp.py @@ -24,6 +24,7 @@ from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions +from shared.utils.utils import has_video_file_extension, has_image_file_extension from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image from shared.utils.audio_video import save_image_metadata, read_image_metadata from shared.match_archi import match_nvidia_architecture @@ -62,7 +63,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.61" +WanGP_version = "8.7" settings_version = 2.35 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -1942,7 +1943,8 @@ def get_model_min_frames_and_step(model_type): mode_def = get_model_def(model_type) frames_minimum = mode_def.get("frames_minimum", 5) frames_steps = mode_def.get("frames_steps", 4) - return frames_minimum, frames_steps + latent_size = mode_def.get("latent_size", frames_steps) + return frames_minimum, frames_steps, latent_size def get_model_fps(model_type): mode_def = get_model_def(model_type) @@ -3459,7 +3461,7 @@ def check(src, cond): if len(video_other_prompts) >0 : values += [video_other_prompts] labels += ["Other Prompts"] - if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"): + if len(video_outpainting) >0: values += [video_outpainting] labels += ["Outpainting"] 
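As an editorial aside on the reworked `prepare_video_guide_and_mask` helper shown above: here is a minimal usage sketch with dummy tensors, using only the keyword names from the new signature in `shared/utils/utils.py`. The values, shapes and surrounding variables are illustrative assumptions, not code taken from the patch.

```python
import torch
from shared.utils.utils import prepare_video_guide_and_mask  # helper reworked in this patch

# Dummy stand-ins: a 3-channel, 17-frame, 256x256 control clip in [-1, 1]
# and a matching single-channel mask in [0, 1].
guide = torch.zeros(3, 17, 256, 256)
mask = torch.ones(1, 17, 256, 256)

src_videos, src_masks = prepare_video_guide_and_mask(
    [guide], [mask],                     # one entry per control video
    pre_video_guide=None,                # no frames carried over from a previous sliding window
    image_size=(256, 256),
    current_video_length=81,
    latent_size=4,
    any_mask=True,
    any_guide_padding=True,              # pad the guide up to current_video_length with the inpaint color
    guide_inpaint_color=127.5,
    keep_video_guide_frames=[True] * 17,
    inject_frames=[],
    outpainting_dims=None,
)
src_video, src_mask = src_videos[0], src_masks[0]
# With padding enabled: src_video is (3, 81, 256, 256) and src_mask is (1, 81, 256, 256),
# the 64 padded mask frames being filled with 1.
```

The notable design change is that the helper now returns one (possibly padded or truncated) entry per control video and keeps the mask list aligned with it, which matches how the updated wgp.py code further down unpacks `src_videos` and `src_masks` directly.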
video_sample_solver = configs.get("sample_solver", "") @@ -3532,6 +3534,11 @@ def convert_image(image): return cast(Image, ImageOps.exif_transpose(image)) def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='torch'): + if isinstance(video_in, str) and has_image_file_extension(video_in): + video_in = Image.open(video_in) + if isinstance(video_in, Image.Image): + return torch.from_numpy(np.array(video_in).astype(np.uint8)).unsqueeze(0) + from shared.utils.utils import resample import decord @@ -3653,19 +3660,22 @@ def get_preprocessor(process_type, inpaint_color): def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2) : if not items: return [] - max_workers = 11 + import concurrent.futures start_time = time.time() # print(f"Preprocessus:{process_type} started") if process_type in ["prephase", "upsample"]: if wrap_in_list : items = [ [img] for img in items] - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)} - results = [None] * len(items) - for future in concurrent.futures.as_completed(futures): - idx = futures[future] - results[idx] = future.result() + if max_workers == 1: + results = [image_processor(img) for img in items] + else: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)} + results = [None] * len(items) + for future in concurrent.futures.as_completed(futures): + idx = futures[future] + results[idx] = future.result() if wrap_in_list: results = [ img[0] for img in results] @@ -3677,55 +3687,6 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis return results -def preprocess_image_with_mask(input_image, input_mask, height, width, fit_canvas = False, fit_crop = False, block_size= 16, expand_scale = 2, outpainting_dims = None, inpaint_color = 127): - frame_width, frame_height = input_image.size - - if fit_crop: - input_image = rescale_and_crop(input_image, width, height) - if input_mask is not None: - input_mask = rescale_and_crop(input_mask, width, height) - return input_image, input_mask - - if outpainting_dims != None: - if fit_canvas != None: - frame_height, frame_width = get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_dims) - else: - frame_height, frame_width = height, width - - if fit_canvas != None: - height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas = fit_canvas, block_size = block_size) - - if outpainting_dims != None: - final_height, final_width = height, width - height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) - - if fit_canvas != None or outpainting_dims != None: - input_image = input_image.resize((width, height), resample=Image.Resampling.LANCZOS) - if input_mask is not None: - input_mask = input_mask.resize((width, height), resample=Image.Resampling.LANCZOS) - - if expand_scale != 0 and input_mask is not None: - kernel_size = abs(expand_scale) - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) - op_expand = cv2.dilate if expand_scale > 0 else cv2.erode - input_mask = np.array(input_mask) - input_mask = op_expand(input_mask, kernel, iterations=3) - input_mask = Image.fromarray(input_mask) - - if outpainting_dims != None: - 
inpaint_color = inpaint_color / 127.5-1 - image = convert_image_to_tensor(input_image) - full_frame= torch.full( (image.shape[0], final_height, final_width), inpaint_color, dtype= torch.float, device= image.device) - full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = image - input_image = convert_tensor_to_image(full_frame) - - if input_mask is not None: - mask = convert_image_to_tensor(input_mask) - full_frame= torch.full( (mask.shape[0], final_height, final_width), 1, dtype= torch.float, device= mask.device) - full_frame[:, margin_top:margin_top+height, margin_left:margin_left+width] = mask - input_mask = convert_tensor_to_image(full_frame) - - return input_image, input_mask def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512): if not input_video_path or max_frames <= 0: return None, None @@ -3780,6 +3741,8 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None) return face_tensor +def get_default_workers(): + return os.cpu_count()/ 2 def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): @@ -3906,8 +3869,8 @@ def prep_prephase(frame_idx): return (target_frame, frame, mask) else: return (target_frame, None, None) - - proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False) + max_workers = get_default_workers() + proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False, max_workers=max_workers) proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists) for frame_idx, frame_group in enumerate(proc_lists): proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group @@ -3916,11 +3879,11 @@ def prep_prephase(frame_idx): mask_video = None if preproc2 != None: - proc_list2 = process_images_multithread(preproc2, proc_list, process_type2) + proc_list2 = process_images_multithread(preproc2, proc_list, process_type2, max_workers=max_workers) #### to be finished ...or not - proc_list = process_images_multithread(preproc, proc_list, process_type) + proc_list = process_images_multithread(preproc, proc_list, process_type, max_workers=max_workers) if any_mask: - proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask) + proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask, max_workers=max_workers) else: proc_list_outside = proc_mask = len(proc_list) * [None] @@ -3938,7 +3901,7 @@ def prep_prephase(frame_idx): full_frame= torch.full( (final_height, final_width, mask.shape[-1]), 255, dtype= torch.uint8, device= mask.device) full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask mask = full_frame - masks.append(mask) + masks.append(mask[:, :, 0:1].clone()) else: masked_frame = processed_img @@ -3958,13 +3921,13 @@ def prep_prephase(frame_idx): proc_list[frame_no] = proc_list_outside[frame_no] = 
proc_mask[frame_no] = None - if args.save_masks: - from preprocessing.dwpose.pose import save_one_video - saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ] - save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) - if any_mask: - saved_masks = [mask.cpu().numpy() for mask in masks ] - save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None) + # if args.save_masks: + # from preprocessing.dwpose.pose import save_one_video + # saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ] + # save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) + # if any_mask: + # saved_masks = [mask.cpu().numpy() for mask in masks ] + # save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None) preproc = None preproc_outside = None gc.collect() @@ -3972,8 +3935,10 @@ def prep_prephase(frame_idx): if pad_frames > 0: masked_frames = masked_frames[0] * pad_frames + masked_frames if any_mask: masked_frames = masks[0] * pad_frames + masks + masked_frames = torch.stack(masked_frames).permute(-1,0,1,2).float().div_(127.5).sub_(1.) + masks = torch.stack(masks).permute(-1,0,1,2).float().div_(255) if any_mask else None - return torch.stack(masked_frames), torch.stack(masks) if any_mask else None + return masked_frames, masks def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16): @@ -4102,7 +4067,7 @@ def perform_spatial_upsampling(sample, spatial_upsampling): frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ] def upsample_frames(frame): return resize_lanczos(frame, h, w).unsqueeze(1) - sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1) + sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False, max_workers=get_default_workers()), dim=1) frames_to_upsample = None return sample @@ -4609,17 +4574,13 @@ def remove_temp_filenames(temp_filenames_list): batch_size = 1 temp_filenames_list = [] - convert_image_guide_to_video = model_def.get("convert_image_guide_to_video", False) - if convert_image_guide_to_video: - if image_guide is not None and isinstance(image_guide, Image.Image): - video_guide = convert_image_to_video(image_guide) - temp_filenames_list.append(video_guide) - image_guide = None + if image_guide is not None and isinstance(image_guide, Image.Image): + video_guide = image_guide + image_guide = None - if image_mask is not None and isinstance(image_mask, Image.Image): - video_mask = convert_image_to_video(image_mask) - temp_filenames_list.append(video_mask) - image_mask = None + if image_mask is not None and isinstance(image_mask, Image.Image): + video_mask = image_mask + image_mask = None if model_def.get("no_background_removal", False): remove_background_images_ref = 0 @@ -4711,22 +4672,12 @@ def remove_temp_filenames(temp_filenames_list): device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5) extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False) - i2v = test_class_i2v(model_type) - diffusion_forcing = "diffusion_forcing" in model_filename - t2v = base_model_type in ["t2v"] - 
ltxv = "ltxv" in model_filename - vace = test_vace_module(base_model_type) - hunyuan_t2v = "hunyuan_video_720" in model_filename - hunyuan_i2v = "hunyuan_video_i2v" in model_filename hunyuan_custom = "hunyuan_video_custom" in model_filename hunyuan_custom_audio = hunyuan_custom and "audio" in model_filename hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename hunyuan_avatar = "hunyuan_video_avatar" in model_filename fantasy = base_model_type in ["fantasy"] multitalk = model_def.get("multitalk_class", False) - standin = model_def.get("standin_class", False) - infinitetalk = base_model_type in ["infinitetalk"] - animate = base_model_type in ["animate"] if "B" in audio_prompt_type or "X" in audio_prompt_type: from models.wan.multitalk.multitalk import parse_speakers_locations @@ -4763,9 +4714,9 @@ def remove_temp_filenames(temp_filenames_list): sliding_window_size = current_video_length reuse_frames = 0 - _, latent_size = get_model_min_frames_and_step(model_type) - if diffusion_forcing: latent_size = 4 + _, _, latent_size = get_model_min_frames_and_step(model_type) original_image_refs = image_refs + image_refs = None if image_refs is None else [] + image_refs # work on a copy as it is going to be modified # image_refs = None # nb_frames_positions= 0 # Output Video Ratio Priorities: @@ -4889,6 +4840,7 @@ def remove_temp_filenames(temp_filenames_list): initial_total_windows = 0 discard_last_frames = sliding_window_discard_last_frames default_requested_frames_to_generate = current_video_length + nb_frames_positions = 0 if sliding_window: initial_total_windows= compute_sliding_window_no(default_requested_frames_to_generate, sliding_window_size, discard_last_frames, reuse_frames) current_video_length = sliding_window_size @@ -4907,7 +4859,7 @@ def remove_temp_filenames(temp_filenames_list): if repeat_no >= total_generation: break repeat_no +=1 gen["repeat_no"] = repeat_no - src_video = src_mask = src_ref_images = new_image_guide = new_image_mask = src_faces = None + src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None prefix_video = pre_video_frame = None source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before ) @@ -4963,7 +4915,6 @@ def remove_temp_filenames(temp_filenames_list): return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {} - src_ref_images, src_ref_masks = image_refs, None image_start_tensor = image_end_tensor = None if window_no == 1 and (video_source is not None or image_start is not None): if image_start is not None: @@ -5020,7 +4971,7 @@ def remove_temp_filenames(temp_filenames_list): if len(pos) > 0: if pos in ["L", "l"]: cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length - if cur_end_pos >= last_frame_no and not joker_used: + if cur_end_pos >= last_frame_no-1 and not joker_used: joker_used = True cur_end_pos = last_frame_no -1 project_window_no += 1 @@ -5036,141 +4987,53 @@ def remove_temp_filenames(temp_filenames_list): frames_to_inject[pos] = image_refs[i] + video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None if video_guide is not None: 
keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame + extra_control_frames = model_def.get("extra_control_frames", 0) + if extra_control_frames > 0 and aligned_guide_start_frame >= extra_control_frames: guide_frames_extract_start -= extra_control_frames + keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else [] keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ] guide_frames_extract_count = len(keep_frames_parsed) + # Extract Faces to video if "B" in video_prompt_type: send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")]) src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps) if src_faces is not None and src_faces.shape[1] < current_video_length: src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1) - if vace or animate: - video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None - context_scale = [ control_net_weight] - if "V" in video_prompt_type: - process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) - preprocess_type, preprocess_type2 = "raw", None - for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")): - if process_num == 0: - preprocess_type = process_map_video_guide.get(process_letter, "raw") - else: - preprocess_type2 = process_map_video_guide.get(process_letter, None) - status_info = "Extracting " + processes_names[preprocess_type] - extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) - if len(extra_process_list) == 1: - status_info += " and " + processes_names[extra_process_list[0]] - elif len(extra_process_list) == 2: - status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] - if preprocess_type2 is not None: - context_scale = [ control_net_weight /2, control_net_weight2 /2] - send_cmd("progress", [0, get_latest_status(state, status_info)]) - inpaint_color = 0 if preprocess_type=="pose" else guide_inpaint_color - video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color ) - if preprocess_type2 != None: - video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = 
guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) - - if video_guide_processed != None: - if sample_fit_canvas != None: - image_size = video_guide_processed.shape[-3: -1] - sample_fit_canvas = None - refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy()) - if video_guide_processed2 != None: - refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())] - if video_mask_processed != None: - refresh_preview["video_mask"] = Image.fromarray(video_mask_processed[0].cpu().numpy()) - - frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] - - if not vace and (any_letters(video_prompt_type ,"FV") or model_def.get("forced_guide_mask_inputs", False)): - any_mask = True - any_guide_padding = model_def.get("pad_guide_video", False) - from shared.utils.utils import prepare_video_guide_and_mask - src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed, video_guide_processed2], - [video_mask_processed, video_mask_processed2], - pre_video_guide, image_size, current_video_length, latent_size, - any_mask, any_guide_padding, guide_inpaint_color, extract_guide_from_window_start, - keep_frames_parsed, frames_to_inject_parsed , outpainting_dims) - - src_video, src_video2 = src_videos - src_mask, src_mask2 = src_masks - if src_video is None: - abort = True - break - if src_faces is not None: - if src_faces.shape[1] < src_video.shape[1]: - src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1) - else: - src_faces = src_faces[:, :src_video.shape[1]] - if args.save_masks: - save_video( src_video, "masked_frames.mp4", fps) - if src_video2 is not None: - save_video( src_video2, "masked_frames2.mp4", fps) - if any_mask: - save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) - - elif ltxv: - preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") - status_info = "Extracting " + processes_names[preprocess_type] - send_cmd("progress", [0, get_latest_status(state, status_info)]) - # start one frame ealier to facilitate latents merging later - src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size ) - if src_video != None: - src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ] - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) - refresh_preview["video_mask"] = None - src_video = src_video.permute(3, 0, 1, 2) - src_video = src_video.float().div_(127.5).sub_(1.) 
# c, f, h, w - if sample_fit_canvas != None: - image_size = src_video.shape[-2:] - sample_fit_canvas = None - - elif hunyuan_custom_edit: - if "P" in video_prompt_type: - progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")] - else: - progress_args = [0, get_latest_status(state,"Extracting Video and Mask")] - - send_cmd("progress", progress_args) - src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0) - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) - if src_mask != None: - refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy()) - - elif "R" in video_prompt_type: # sparse video to video - src_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, return_PIL = True) - src_image, _, _ = calculate_dimensions_and_resize_image(src_image, image_size[0], image_size[1 ], sample_fit_canvas, fit_crop, block_size = block_size) - refresh_preview["video_guide"] = src_image - src_video = convert_image_to_tensor(src_image).unsqueeze(1) - if sample_fit_canvas != None: - image_size = src_video.shape[-2:] - sample_fit_canvas = None - - else: # video to video - video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size= block_size) - if video_guide_processed is None: - src_video = pre_video_guide + # Sparse Video to Video + sparse_video_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, target_fps = fps, return_PIL = True) if "R" in video_prompt_type else None + + # Generic Video Preprocessing + process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) + preprocess_type, preprocess_type2 = "raw", None + for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")): + if process_num == 0: + preprocess_type = process_map_video_guide.get(process_letter, "raw") else: - if sample_fit_canvas != None: - image_size = video_guide_processed.shape[-3: -1] - sample_fit_canvas = None - src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2) - if pre_video_guide != None: - src_video = torch.cat( [pre_video_guide, src_video], dim=1) - elif image_guide is not None: - new_image_guide, new_image_mask = preprocess_image_with_mask(image_guide, image_mask, image_size[0], image_size[1], fit_canvas = sample_fit_canvas, fit_crop= fit_crop, block_size= block_size, expand_scale = mask_expand, outpainting_dims=outpainting_dims) - if sample_fit_canvas is not None: - image_size = (new_image_guide.size[1], new_image_guide.size[0]) + preprocess_type2 = process_map_video_guide.get(process_letter, None) + status_info = "Extracting " + processes_names[preprocess_type] + extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) + if len(extra_process_list) == 1: + status_info += " 
and " + processes_names[extra_process_list[0]] + elif len(extra_process_list) == 2: + status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] + context_scale = [control_net_weight /2, control_net_weight2 /2] if preprocess_type2 is not None else [control_net_weight] + if not (preprocess_type == "identity" and preprocess_type2 is None and video_mask is None):send_cmd("progress", [0, get_latest_status(state, status_info)]) + inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask == "inpaint" else guide_inpaint_color + video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size ) + if preprocess_type2 != None: + video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size ) + + if video_guide_processed is not None and sample_fit_canvas is not None: + image_size = video_guide_processed.shape[-2:] sample_fit_canvas = None - refresh_preview["image_guide"] = new_image_guide - if new_image_mask is not None: - refresh_preview["image_mask"] = new_image_mask if window_no == 1 and image_refs is not None and len(image_refs) > 0: if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) : @@ -5192,45 +5055,68 @@ def remove_temp_filenames(temp_filenames_list): image_refs[i] = rescale_and_crop(img, default_image_size[1], default_image_size[0]) refresh_preview["image_refs"] = image_refs - if len(image_refs) > nb_frames_positions: + if len(image_refs) > nb_frames_positions: + src_ref_images = image_refs[nb_frames_positions:] if remove_background_images_ref > 0: send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) - # keep image ratios if there is a background image ref (we will let the model preprocessor decide what to do) but remove bg if requested - image_refs[nb_frames_positions:], src_ref_masks = resize_and_remove_background(image_refs[nb_frames_positions:] , image_size[1], image_size[0], + + src_ref_images, src_ref_masks = resize_and_remove_background(src_ref_images , image_size[1], image_size[0], remove_background_images_ref > 0, any_background_ref, - fit_into_canvas= 0 if (any_background_ref > 0 or model_def.get("lock_image_refs_ratios", False)) else 1, + fit_into_canvas= model_def.get("fit_into_canvas_image_refs", 1), block_size=block_size, outpainting_dims =outpainting_dims, - background_ref_outpainted = model_def.get("background_ref_outpainted", True) ) - refresh_preview["image_refs"] = image_refs - + background_ref_outpainted = 
model_def.get("background_ref_outpainted", True), + return_tensor= model_def.get("return_image_refs_tensor", False) ) + - if vace : - image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications - - src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2], - [video_mask_processed] if video_guide_processed2 == None else [video_mask_processed, video_mask_processed2], - [image_refs_copy] if video_guide_processed2 == None else [image_refs_copy, image_refs_copy], - current_video_length, image_size = image_size, device ="cpu", - keep_video_guide_frames=keep_frames_parsed, - pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide], - inject_frames= frames_to_inject_parsed, - outpainting_dims = outpainting_dims, - any_background_ref = any_background_ref - ) - if len(frames_to_inject_parsed) or any_background_ref: - new_image_refs = [convert_tensor_to_image(src_video[0], frame_no + 0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] - if any_background_ref: - new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:] + frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] + if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False): + any_mask = video_mask is not None or model_def.get("forced_guide_mask_inputs", False) + any_guide_padding = model_def.get("pad_guide_video", False) + from shared.utils.utils import prepare_video_guide_and_mask + src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]), + [video_mask_processed] + ([] if video_mask_processed2 is None else [video_mask_processed2]), + None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide, + image_size, current_video_length, latent_size, + any_mask, any_guide_padding, guide_inpaint_color, + keep_frames_parsed, frames_to_inject_parsed , outpainting_dims) + video_guide_processed = video_guide_processed2 = video_mask_processed = video_mask_processed2 = None + if len(src_videos) == 1: + src_video, src_video2, src_mask, src_mask2 = src_videos[0], None, src_masks[0], None + else: + src_video, src_video2 = src_videos + src_mask, src_mask2 = src_masks + src_videos = src_masks = None + if src_video is None: + abort = True + break + if src_faces is not None: + if src_faces.shape[1] < src_video.shape[1]: + src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1) else: - new_image_refs += image_refs[nb_frames_positions:] - refresh_preview["image_refs"] = new_image_refs - new_image_refs = None - - if sample_fit_canvas != None: - image_size = src_video[0].shape[-2:] - sample_fit_canvas = None - + src_faces = src_faces[:, :src_video.shape[1]] + if video_guide is not None or len(frames_to_inject_parsed) > 0: + if args.save_masks: + if src_video is not None: save_video( src_video, "masked_frames.mp4", fps) + if src_video2 is not 
None: save_video( src_video2, "masked_frames2.mp4", fps) + if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) + if video_guide is not None: + preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame) + refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no) + if src_video2 is not None: + refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)] + if src_mask is not None and video_mask is not None: + refresh_preview["video_mask"] = convert_tensor_to_image(src_mask, preview_frame_no, mask_levels = True) + + if src_ref_images is not None or nb_frames_positions: + if len(frames_to_inject_parsed): + new_image_refs = [convert_tensor_to_image(src_video, frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame)) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] + else: + new_image_refs = [] + if src_ref_images is not None: + new_image_refs += [convert_tensor_to_image(img) if torch.is_tensor(img) else img for img in src_ref_images ] + refresh_preview["image_refs"] = new_image_refs + new_image_refs = None if len(refresh_preview) > 0: new_inputs= locals() @@ -5339,8 +5225,6 @@ def set_header_text(txt): pre_video_frame = pre_video_frame, original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [], image_refs_relative_size = image_refs_relative_size, - image_guide= new_image_guide, - image_mask= new_image_mask, outpainting_dims = outpainting_dims, ) except Exception as e: @@ -6320,8 +6204,11 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None pop += ["image_refs_relative_size"] if not vace: - pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"] + pop += ["frames_positions", "control_net_weight", "control_net_weight2"] + if model_def.get("video_guide_outpainting", None) is None: + pop += ["video_guide_outpainting"] + if not (vace or t2v): pop += ["min_frames_if_references"] @@ -6506,13 +6393,6 @@ def eject_video_from_gallery(state, input_file_list, choice): choice = min(choice, len(file_list)) return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0) -def has_video_file_extension(filename): - extension = os.path.splitext(filename)[-1].lower() - return extension in [".mp4"] - -def has_image_file_extension(filename): - extension = os.path.splitext(filename)[-1].lower() - return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"] def add_videos_to_gallery(state, input_file_list, choice, files_to_load): gen = get_gen_info(state) if files_to_load == None: @@ -7881,7 +7761,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl elif recammaster: video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", get_max_frames(81)), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True) else: - min_frames, frames_step = get_model_min_frames_and_step(base_model_type) + min_frames, frames_step, _ = get_model_min_frames_and_step(base_model_type) current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97) @@ 
-8059,7 +7939,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai MMAudio_neg_prompt = gr.Text(ui_defaults.get("MMAudio_neg_prompt", ""), label="Negative Prompt (1 or 2 keywords)") - with gr.Column(visible = (t2v or vace) and not fantasy) as audio_prompt_type_remux_row: + with gr.Column(visible = any_control_video) as audio_prompt_type_remux_row: gr.Markdown("You may transfer the existing audio tracks of a Control Video") audio_prompt_type_remux = gr.Dropdown( choices=[ @@ -8284,16 +8164,16 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info()) with gr.Row(**default_visibility) as video_buttons_row: video_info_extract_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm") - video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video ) video_info_to_video_source_btn = gr.Button("To Video Source", min_width= 1, size ="sm", visible = any_video_source) + video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video ) video_info_eject_video_btn = gr.Button("Eject Video", min_width= 1, size ="sm") with gr.Row(**default_visibility) as image_buttons_row: video_info_extract_image_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm") video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", min_width= 1, visible = any_start_image ) video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", min_width= 1, visible = any_end_image) - video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image ) video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask and False) video_info_to_reference_image_btn = gr.Button("To Reference Image", min_width= 1, size ="sm", visible = any_reference_image) + video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image ) video_info_eject_image_btn = gr.Button("Eject Image", min_width= 1, size ="sm") with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab: with gr.Group(elem_classes= "postprocess"): From 356e10ce71dc969d11eec42f05a0dfdfada20ad4 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 23 Sep 2025 23:28:43 +0200 Subject: [PATCH 031/155] fixed default sound --- models/wan/wan_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index f2a8fb459..12ddfed45 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -481,7 +481,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ "video_prompt_type": "PVBXAKI", "mask_expand": 20, - "audio_prompt_type_value": "R", + "audio_prompt_type": "R", }) if text_oneframe_overlap(base_model_type): From 625b50aefd6cb48f09da5bb102fbe900acb9be84 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Wed, 24 Sep 2025 00:12:22 +0200 Subject: [PATCH 032/155] fixed lucy edit fast wan, lora missing --- README.md | 2 ++ defaults/lucy_edit_fastwan.json | 2 +- wgp.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6c338ac7c..8cd8e9c6a 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,8 @@ WanGP supports the Wan (and derived models), Hunyuan Video and 
LTV Video models So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages: - **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion tranfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*. +In order to use Wan 2.2 Animate you will need first to stop by the *Mat Anyone* embedded tool, to extract the Video Mask of the person from which you want to extract the motion. + - **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and asks it to change it (it is specialized in clothes changing) and voila ! The nice thing about it is that is it based on the *Wan 2.2 5B* model and therefore is very fast especially if you the *FastWan* finetune that is also part of the package. diff --git a/defaults/lucy_edit_fastwan.json b/defaults/lucy_edit_fastwan.json index c67c795d7..d5d47c814 100644 --- a/defaults/lucy_edit_fastwan.json +++ b/defaults/lucy_edit_fastwan.json @@ -5,7 +5,7 @@ "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.", "URLs": "lucy_edit", "group": "wan2_2", - "loras": "ti2v_2_2" + "loras": "ti2v_2_2_fastwan" }, "prompt": "change the clothes to red", "video_length": 81, diff --git a/wgp.py b/wgp.py index 6a7bca122..43f5e8006 100644 --- a/wgp.py +++ b/wgp.py @@ -63,7 +63,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.7" +WanGP_version = "8.71" settings_version = 2.35 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None From 2cbcb9523eab5af691f814be399657856742ebe3 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Wed, 24 Sep 2025 08:16:24 +0200 Subject: [PATCH 033/155] fix sparse video error --- README.md | 7 ++++--- docs/INSTALLATION.md | 2 +- requirements.txt | 2 +- wgp.py | 2 +- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 8cd8e9c6a..8cee8346b 100644 --- a/README.md +++ b/README.md @@ -23,12 +23,13 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models ### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena ! So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages: -- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion tranfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. 
*Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*. +- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can either *Replace* a person in a Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX i2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*. -In order to use Wan 2.2 Animate you will need first to stop by the *Mat Anyone* embedded tool, to extract the Video Mask of the person from which you want to extract the motion. +In order to use Wan 2.2 Animate you will need first to stop by the *Mat Anyone* embedded tool, to extract the *Video Mask* of the person from which you want to extract the motion. -- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and asks it to change it (it is specialized in clothes changing) and voila ! The nice thing about it is that is it based on the *Wan 2.2 5B* model and therefore is very fast especially if you the *FastWan* finetune that is also part of the package. +- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and asks it to change it (it is specialized in clothes changing) and voila ! The nice thing about it is that is it based on the *Wan 2.2 5B* model and therefore is very fast especially if you the *FastWan* finetune that is also part of the package. 
+*Update 8.71*: fixed Fast Lucy Edit that didnt contain the lora ### September 15 2025: WanGP v8.6 - Attack of the Clones diff --git a/docs/INSTALLATION.md b/docs/INSTALLATION.md index 9f66422a1..361f26684 100644 --- a/docs/INSTALLATION.md +++ b/docs/INSTALLATION.md @@ -27,7 +27,7 @@ conda activate wan2gp ### Step 2: Install PyTorch ```shell -# Install PyTorch 2.7.0 with CUDA 12.4 +# Install PyTorch 2.7.0 with CUDA 12.8 pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 ``` diff --git a/requirements.txt b/requirements.txt index e6c85fda7..d6f75f74b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,7 +52,7 @@ matplotlib # Utilities ftfy piexif -pynvml +nvidia-ml-py misaki # Optional / commented out diff --git a/wgp.py b/wgp.py index 43f5e8006..d48c7fa29 100644 --- a/wgp.py +++ b/wgp.py @@ -4859,7 +4859,7 @@ def remove_temp_filenames(temp_filenames_list): if repeat_no >= total_generation: break repeat_no +=1 gen["repeat_no"] = repeat_no - src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None + src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = sparse_video_image = None prefix_video = pre_video_frame = None source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before ) From 0a58ef6981e3343cbca6b3e12380148de281d91f Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Wed, 24 Sep 2025 16:28:05 +0200 Subject: [PATCH 034/155] fixed sparse crash --- README.md | 2 +- wgp.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 6c338ac7c..6b005d8dc 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models ### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena ! So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages: -- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion tranfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*. +- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*. - **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and asks it to change it (it is specialized in clothes changing) and voila ! 
The nice thing about it is that is it based on the *Wan 2.2 5B* model and therefore is very fast especially if you the *FastWan* finetune that is also part of the package. diff --git a/wgp.py b/wgp.py index 6a7bca122..11da9679d 100644 --- a/wgp.py +++ b/wgp.py @@ -4859,7 +4859,7 @@ def remove_temp_filenames(temp_filenames_list): if repeat_no >= total_generation: break repeat_no +=1 gen["repeat_no"] = repeat_no - src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = None + src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = sparse_video_image = None prefix_video = pre_video_frame = None source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before ) From 76f86c43ad12914626cf6d3ad623ba1f6f551640 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Wed, 24 Sep 2025 23:38:27 +0200 Subject: [PATCH 035/155] add qwen edit plus support --- README.md | 6 +- defaults/qwen_image_edit_plus_20B.json | 17 +++ models/flux/flux_handler.py | 2 +- models/qwen/pipeline_qwenimage.py | 163 ++++++++++++++----------- models/qwen/qwen_handler.py | 21 +++- models/qwen/qwen_main.py | 21 ++-- preprocessing/matanyone/app.py | 4 +- shared/utils/utils.py | 4 +- wgp.py | 35 ++++-- 9 files changed, 171 insertions(+), 102 deletions(-) create mode 100644 defaults/qwen_image_edit_plus_20B.json diff --git a/README.md b/README.md index 8cee8346b..b8c92b6e2 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : -### September 23 2025: WanGP v8.7 - Here Are Two New Contenders in the Vace Arena ! +### September 24 2025: WanGP v8.72 - Here Are ~~Two~~Three New Contenders in the Vace Arena ! So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages: - **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can either *Replace* a person in a Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX i2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*. @@ -29,7 +29,11 @@ In order to use Wan 2.2 Animate you will need first to stop by the *Mat Anyone* - **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and asks it to change it (it is specialized in clothes changing) and voila ! The nice thing about it is that is it based on the *Wan 2.2 5B* model and therefore is very fast especially if you the *FastWan* finetune that is also part of the package. +Also because I wanted to spoil you: +- **Qwen Edit Plus**: also known as the *Qwen Edit 25th September Update* which is specialized in combining multiple Objects / People. There is also a new support for *Pose transfer* & *Recolorisation*. All of this made easy to use in WanGP. You will find right now only the quantized version since HF crashes when uploading the unquantized version. 
+ *Update 8.71*: fixed Fast Lucy Edit that didnt contain the lora +*Update 8.72*: shadow drop of Qwen Edit Plus ### September 15 2025: WanGP v8.6 - Attack of the Clones diff --git a/defaults/qwen_image_edit_plus_20B.json b/defaults/qwen_image_edit_plus_20B.json new file mode 100644 index 000000000..e10deb24b --- /dev/null +++ b/defaults/qwen_image_edit_plus_20B.json @@ -0,0 +1,17 @@ +{ + "model": { + "name": "Qwen Image Edit Plus 20B", + "architecture": "qwen_image_edit_plus_20B", + "description": "Qwen Image Edit Plus is a generative model that can generate very high quality images with long texts in it. Best results will be at 720p. This model is optimized to combine multiple Subjects & Objects.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_plus_20B_quanto_bf16_int8.safetensors" + ], + "preload_URLs": "qwen_image_edit_20B", + "attention": { + "<89": "sdpa" + } + }, + "prompt": "add a hat", + "resolution": "1024x1024", + "batch_size": 1 +} \ No newline at end of file diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py index 83de7c3fc..471c339d4 100644 --- a/models/flux/flux_handler.py +++ b/models/flux/flux_handler.py @@ -28,7 +28,7 @@ def query_model_def(base_model_type, model_def): extra_model_def["any_image_refs_relative_size"] = True extra_model_def["no_background_removal"] = True extra_model_def["image_ref_choices"] = { - "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"), + "choices":[("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"), ("Up to two Images are Style Images", "KIJ")], "default": "KI", "letters_filter": "KIJ", diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py index 14728862d..85934b79f 100644 --- a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -200,7 +200,8 @@ def __init__( self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.tokenizer_max_length = 1024 if processor is not None: - self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + # self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" self.prompt_template_encode_start_idx = 64 else: self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" @@ -232,6 +233,21 @@ def _get_qwen_prompt_embeds( txt = [template.format(e) for e in prompt] if self.processor is not None and image is not None: + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_prompt_template.format(i + 1) + elif image is not None: + base_img_prompt = img_prompt_template.format(1) + else: + base_img_prompt = "" + + template = self.prompt_template_encode + + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + model_inputs = self.processor( text=txt, images=image, @@ -464,7 +480,7 @@ def disable_vae_tiling(self): def prepare_latents( self, - image, + images, batch_size, num_channels_latents, height, @@ -482,24 +498,30 @@ def prepare_latents( shape = (batch_size, num_channels_latents, 1, height, width) image_latents = None - if image is not None: - image = image.to(device=device, dtype=dtype) - if image.shape[1] != self.latent_channels: - image_latents = self._encode_vae_image(image=image, generator=generator) - else: - image_latents = image - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." - ) - else: - image_latents = torch.cat([image_latents], dim=0) + if images is not None and len(images ) > 0: + if not isinstance(images, list): + images = [images] + all_image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
+ ) + else: + image_latents = torch.cat([image_latents], dim=0) - image_latents = self._pack_latents(image_latents) + image_latents = self._pack_latents(image_latents) + all_image_latents.append(image_latents) + image_latents = torch.cat(all_image_latents, dim=1) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -568,6 +590,7 @@ def __call__( joint_pass= True, lora_inpaint = False, outpainting_dims = None, + qwen_edit_plus = False, ): r""" Function invoked when calling the pipeline for generation. @@ -683,61 +706,54 @@ def __call__( batch_size = prompt_embeds.shape[0] device = "cuda" - prompt_image = None + condition_images = [] + vae_image_sizes = [] + vae_images = [] image_mask_latents = None - if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): - image = image[0] if isinstance(image, list) else image - image_height, image_width = self.image_processor.get_default_height_width(image) - aspect_ratio = image_width / image_height - if False : - _, image_width, image_height = min( - (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS - ) - image_width = image_width // multiple_of * multiple_of - image_height = image_height // multiple_of * multiple_of - ref_height, ref_width = 1568, 672 - - if image_mask is None: - if height * width < ref_height * ref_width: ref_height , ref_width = height , width - if image_height * image_width > ref_height * ref_width: - image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) - if (image_width,image_height) != image.size: - image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS) - elif not lora_inpaint: - # _, image_width, image_height = min( - # (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS - # ) - image_height, image_width = calculate_new_dimensions(height, width, image_height, image_width, False, block_size=multiple_of) - # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) - height, width = image_height, image_width - image_mask_latents = convert_image_to_tensor(image_mask.resize((width // 8, height // 8), resample=Image.Resampling.LANCZOS)) - image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1] - image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0) - # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") - image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1) - image_mask_latents = self._pack_latents(image_mask_latents) - - prompt_image = image - if image.size != (image_width, image_height): - image = image.resize((image_width, image_height), resample=Image.Resampling.LANCZOS) - - image = convert_image_to_tensor(image) - if lora_inpaint: - image_mask_rebuilt = torch.where(convert_image_to_tensor(image_mask)>-0.5, 1., 0. 
)[0:1] - image_mask_latents = None - green = torch.tensor([-1.0, 1.0, -1.0]).to(image) - green_image = green[:, None, None] .expand_as(image) - image = torch.where(image_mask_rebuilt > 0, green_image, image) - prompt_image = convert_tensor_to_image(image) - image = image.unsqueeze(0).unsqueeze(2) - # image.save("nnn.png") + ref_size = 1024 + ref_text_encoder_size = 384 if qwen_edit_plus else 1024 + if image is not None: + if not isinstance(image, list): image = [image] + if height * width < ref_size * ref_size: ref_size = round(math.sqrt(height * width)) + for ref_no, img in enumerate(image): + image_width, image_height = img.size + any_mask = ref_no == 0 and image_mask is not None + if (image_height * image_width > ref_size * ref_size) and not any_mask: + vae_height, vae_width =calculate_new_dimensions(ref_size, ref_size, image_height, image_width, False, block_size=multiple_of) + else: + vae_height, vae_width = image_height, image_width + vae_width = vae_width // multiple_of * multiple_of + vae_height = vae_height // multiple_of * multiple_of + vae_image_sizes.append((vae_width, vae_height)) + condition_height, condition_width =calculate_new_dimensions(ref_text_encoder_size, ref_text_encoder_size, image_height, image_width, False, block_size=multiple_of) + condition_images.append(img.resize((condition_width, condition_height), resample=Image.Resampling.LANCZOS) ) + if img.size != (vae_width, vae_height): + img = img.resize((vae_width, vae_height), resample=Image.Resampling.LANCZOS) + if any_mask : + if lora_inpaint: + image_mask_rebuilt = torch.where(convert_image_to_tensor(image_mask)>-0.5, 1., 0. )[0:1] + img = convert_image_to_tensor(img) + green = torch.tensor([-1.0, 1.0, -1.0]).to(img) + green_image = green[:, None, None] .expand_as(img) + img = torch.where(image_mask_rebuilt > 0, green_image, img) + img = convert_tensor_to_image(img) + else: + image_mask_latents = convert_image_to_tensor(image_mask.resize((vae_width // 8, vae_height // 8), resample=Image.Resampling.LANCZOS)) + image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. )[0:1] + image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0) + # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") + image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1) + image_mask_latents = self._pack_latents(image_mask_latents) + # img.save("nnn.png") + vae_images.append( convert_image_to_tensor(img).unsqueeze(0).unsqueeze(2) ) + has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt prompt_embeds, prompt_embeds_mask = self.encode_prompt( - image=prompt_image, + image=condition_images, prompt=prompt, prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_embeds_mask, @@ -747,7 +763,7 @@ def __call__( ) if do_true_cfg: negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( - image=prompt_image, + image=condition_images, prompt=negative_prompt, prompt_embeds=negative_prompt_embeds, prompt_embeds_mask=negative_prompt_embeds_mask, @@ -763,7 +779,7 @@ def __call__( # 4. 
Prepare latent variables num_channels_latents = self.transformer.in_channels // 4 latents, image_latents = self.prepare_latents( - image, + vae_images, batch_size * num_images_per_prompt, num_channels_latents, height, @@ -779,7 +795,12 @@ def __call__( img_shapes = [ [ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), - (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2), + # (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2), + *[ + (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2) + for vae_width, vae_height in vae_image_sizes + ], + ] ] * batch_size else: @@ -971,7 +992,7 @@ def cfg_predictions( noise_pred, neg_noise_pred, guidance, t): latents = latents / latents_std + latents_mean output_image = self.vae.decode(latents, return_dict=False)[0][:, :, 0] if image_mask is not None and not lora_inpaint : #not (lora_inpaint and outpainting_dims is not None): - output_image = image.squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(image) * image_mask_rebuilt + output_image = vae_images[0].squeeze(2) * (1 - image_mask_rebuilt) + output_image.to(vae_images[0] ) * image_mask_rebuilt return output_image diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py index cc6a76484..4fcaa3ba8 100644 --- a/models/qwen/qwen_handler.py +++ b/models/qwen/qwen_handler.py @@ -20,7 +20,7 @@ def query_model_def(base_model_type, model_def): "fit_into_canvas_image_refs": 0, } - if base_model_type in ["qwen_image_edit_20B"]: + if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]: extra_model_def["inpaint_support"] = True extra_model_def["image_ref_choices"] = { "choices": [ @@ -42,11 +42,20 @@ def query_model_def(base_model_type, model_def): "image_modes" : [2], } + if base_model_type in ["qwen_image_edit_plus_20B"]: + extra_model_def["guide_preprocessing"] = { + "selection": ["", "PV", "SV", "CV"], + } + + extra_model_def["mask_preprocessing"] = { + "selection": ["", "A"], + "visible": False, + } return extra_model_def @staticmethod def query_supported_types(): - return ["qwen_image_20B", "qwen_image_edit_20B"] + return ["qwen_image_20B", "qwen_image_edit_20B", "qwen_image_edit_plus_20B"] @staticmethod def query_family_maps(): @@ -113,9 +122,15 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "denoising_strength" : 1., "model_mode" : 0, }) + elif base_model_type in ["qwen_image_edit_plus_20B"]: + ui_defaults.update({ + "video_prompt_type": "I", + "denoising_strength" : 1., + "model_mode" : 0, + }) def validate_generative_settings(base_model_type, model_def, inputs): - if base_model_type in ["qwen_image_edit_20B"]: + if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]: model_mode = inputs["model_mode"] denoising_strength= inputs["denoising_strength"] video_guide_outpainting= inputs["video_guide_outpainting"] diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py index abd5c5cfe..8c4514fa2 100644 --- a/models/qwen/qwen_main.py +++ b/models/qwen/qwen_main.py @@ -51,10 +51,10 @@ def __init__( transformer_filename = model_filename[0] processor = None tokenizer = None - if base_model_type == "qwen_image_edit_20B": + if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]: processor = Qwen2VLProcessor.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct")) tokenizer = 
AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct")) - + self.base_model_type = base_model_type base_config_file = "configs/qwen_image_20B.json" with open(base_config_file, 'r', encoding='utf-8') as f: @@ -173,7 +173,7 @@ def generate( self.vae.tile_latent_min_height = VAE_tile_size[1] self.vae.tile_latent_min_width = VAE_tile_size[1] - + qwen_edit_plus = self.base_model_type in ["qwen_image_edit_plus_20B"] self.vae.enable_slicing() # width, height = aspect_ratios["16:9"] @@ -182,17 +182,19 @@ def generate( image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True) if input_frames is not None: - input_ref_images = [convert_tensor_to_image(input_frames) ] - elif input_ref_images is not None: + input_ref_images = [convert_tensor_to_image(input_frames) ] + ([] if input_ref_images is None else input_ref_images ) + + if input_ref_images is not None: # image stiching method stiched = input_ref_images[0] if "K" in video_prompt_type : w, h = input_ref_images[0].size height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - for new_img in input_ref_images[1:]: - stiched = stitch_images(stiched, new_img) - input_ref_images = [stiched] + if not qwen_edit_plus: + for new_img in input_ref_images[1:]: + stiched = stitch_images(stiched, new_img) + input_ref_images = [stiched] image = self.pipeline( prompt=input_prompt, @@ -212,7 +214,8 @@ def generate( generator=torch.Generator(device="cuda").manual_seed(seed), lora_inpaint = image_mask is not None and model_mode == 1, outpainting_dims = outpainting_dims, - ) + qwen_edit_plus = qwen_edit_plus, + ) if image is None: return None return image.transpose(0, 1) diff --git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py index d40811d51..df71b40a9 100644 --- a/preprocessing/matanyone/app.py +++ b/preprocessing/matanyone/app.py @@ -21,6 +21,7 @@ from .matanyone.inference.inference_core import InferenceCore from .matanyone_wrapper import matanyone from shared.utils.audio_video import save_video, save_image +from mmgp import offload arg_device = "cuda" arg_sam_model_type="vit_h" @@ -539,7 +540,7 @@ def video_matting(video_state,video_input, end_slider, matting_type, interactive file_name = ".".join(file_name.split(".")[:-1]) from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files - source_audio_tracks, audio_metadata = extract_audio_tracks(video_input) + source_audio_tracks, audio_metadata = extract_audio_tracks(video_input, verbose= offload.default_verboseLevel ) output_fg_path = f"./mask_outputs/{file_name}_fg.mp4" output_fg_temp_path = f"./mask_outputs/{file_name}_fg_tmp.mp4" if len(source_audio_tracks) == 0: @@ -679,7 +680,6 @@ def load_unload_models(selected): } # os.path.join('.') - from mmgp import offload # sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".") sam_checkpoint = None diff --git a/shared/utils/utils.py b/shared/utils/utils.py index bb2d5ffb6..00178aad1 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -321,7 +321,7 @@ def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu ref_width, ref_height = ref_img.size if (ref_height, ref_width) == image_size and outpainting_dims == None: ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) - canvas = torch.zeros_like(ref_img) if return_mask else None + canvas = torch.zeros_like(ref_img[:1]) if return_mask else None else: if outpainting_dims != 
None: final_height, final_width = image_size @@ -374,7 +374,7 @@ def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, im if pre_video_guide is not None: src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1) if any_mask: - src_mask = torch.zeros_like(pre_video_guide[0:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[0:1]), src_mask], dim=1) + src_mask = torch.zeros_like(pre_video_guide[:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[:1]), src_mask], dim=1) if any_guide_padding: if src_video is None: diff --git a/wgp.py b/wgp.py index d48c7fa29..f92cf3334 100644 --- a/wgp.py +++ b/wgp.py @@ -63,7 +63,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.71" +WanGP_version = "8.72" settings_version = 2.35 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -4987,7 +4987,7 @@ def remove_temp_filenames(temp_filenames_list): frames_to_inject[pos] = image_refs[i] - video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None + video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = sparse_video_image = None if video_guide is not None: keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) if len(error) > 0: @@ -6566,11 +6566,11 @@ def switch_image_mode(state): inpaint_support = model_def.get("inpaint_support", False) if inpaint_support: if image_mode == 1: - video_prompt_type = del_in_sequence(video_prompt_type, "VAG") + video_prompt_type = del_in_sequence(video_prompt_type, "VAG" + all_guide_processes) video_prompt_type = add_to_sequence(video_prompt_type, "KI") elif image_mode == 2: + video_prompt_type = del_in_sequence(video_prompt_type, "KI" + all_guide_processes) video_prompt_type = add_to_sequence(video_prompt_type, "VAG") - video_prompt_type = del_in_sequence(video_prompt_type, "KI") ui_defaults["video_prompt_type"] = video_prompt_type return str(time.time()) @@ -6965,10 +6965,11 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) return video_prompt_type +all_guide_processes ="PDESLCMUVB" def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): old_video_prompt_type = video_prompt_type - video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMUVB") + video_prompt_type = del_in_sequence(video_prompt_type, all_guide_processes) video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) visible = "V" in video_prompt_type model_type = state["model_type"] @@ -6978,8 +6979,12 @@ def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt image_outputs = image_mode > 0 keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False) image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ) - - return 
video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible) + mask_preprocessing = model_def.get("mask_preprocessing", None) + if mask_preprocessing is not None: + mask_selector_visible = mask_preprocessing.get("visible", True) + else: + mask_selector_visible = True + return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible) def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode): @@ -7391,7 +7396,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non any_start_image = any_end_image = any_reference_image = any_image_mask = False v2i_switch_supported = (vace or t2v or standin) and not image_outputs ti2v_2_2 = base_model_type in ["ti2v_2_2"] - gallery_height = 350 + gallery_height = 550 def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ): with gr.Row(visible = visible) as gallery_row: gallery_amg = AdvancedMediaGallery(media_mode="image", height=gallery_height, columns=4, label=label, initial = value , single_image_mode = single_image_mode ) @@ -7459,8 +7464,12 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VL"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) any_control_video = any_control_image = False - guide_preprocessing = model_def.get("guide_preprocessing", None) - mask_preprocessing = model_def.get("mask_preprocessing", None) + if image_mode_value ==2: + guide_preprocessing = { "selection": ["V", "VG"]} + mask_preprocessing = { "selection": ["A"]} + else: + guide_preprocessing = model_def.get("guide_preprocessing", None) + mask_preprocessing = model_def.get("mask_preprocessing", None) guide_custom_choices = model_def.get("guide_custom_choices", None) image_ref_choices = model_def.get("image_ref_choices", None) @@ -7504,7 +7513,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image") video_prompt_type_video_guide = gr.Dropdown( guide_preprocessing_choices, - value=filter_letters(video_prompt_type_value, "PDESLCMUVB", guide_preprocessing.get("default", "") ), + value=filter_letters(video_prompt_type_value, all_guide_processes, guide_preprocessing.get("default", "") ), label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , 
show_label= True, ) any_control_video = True @@ -7590,8 +7599,8 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl if image_guide_value is None: image_mask_guide_value = None else: - image_mask_value = rgb_bw_to_rgba_mask(image_mask_value) - image_mask_guide_value = { "background" : image_guide_value, "composite" : None, "layers": [image_mask_value] } + image_mask_guide_value = { "background" : image_guide_value, "composite" : None} + image_mask_guide_value["layers"] = [] if image_mask_value is None else [rgb_bw_to_rgba_mask(image_mask_value)] image_mask_guide = gr.ImageEditor( label="Control Image to be Inpainted" if image_mode_value == 2 else "Control Image and Mask", From b44fcc21df1cb0157381433d66c7fb1adeb77735 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Wed, 24 Sep 2025 23:45:44 +0200 Subject: [PATCH 036/155] restored original gallery height --- wgp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgp.py b/wgp.py index f92cf3334..18ccd3441 100644 --- a/wgp.py +++ b/wgp.py @@ -7396,7 +7396,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non any_start_image = any_end_image = any_reference_image = any_image_mask = False v2i_switch_supported = (vace or t2v or standin) and not image_outputs ti2v_2_2 = base_model_type in ["ti2v_2_2"] - gallery_height = 550 + gallery_height = 350 def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ): with gr.Row(visible = visible) as gallery_row: gallery_amg = AdvancedMediaGallery(media_mode="image", height=gallery_height, columns=4, label=label, initial = value , single_image_mode = single_image_mode ) From 654bb0483bf7d1bf32ba0c56e1392ddd761a014f Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 25 Sep 2025 00:50:03 +0200 Subject: [PATCH 037/155] fixed missing audio input for Hunyuan Avatar and Custom Audio --- models/hyvideo/hunyuan_handler.py | 10 +++++++++- wgp.py | 2 +- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py index 435234102..ebe07d26d 100644 --- a/models/hyvideo/hunyuan_handler.py +++ b/models/hyvideo/hunyuan_handler.py @@ -173,8 +173,14 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): video_prompt_type = video_prompt_type.replace("M","") ui_defaults["video_prompt_type"] = video_prompt_type + if settings_version < 2.36: + if base_model_type in ["hunyuan_avatar", "hunyuan_custom_audio"]: + audio_prompt_type= ui_defaults["audio_prompt_type"] + if "A" not in audio_prompt_type: + audio_prompt_type += "A" + ui_defaults["audio_prompt_type"] = audio_prompt_type + - pass @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): @@ -197,6 +203,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "guidance_scale": 7.5, "flow_shift": 13, "video_prompt_type": "I", + "audio_prompt_type": "A", }) elif base_model_type in ["hunyuan_custom_edit"]: ui_defaults.update({ @@ -213,4 +220,5 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "skip_steps_start_step_perc": 25, "video_length": 129, "video_prompt_type": "KI", + "audio_prompt_type": "A", }) diff --git a/wgp.py b/wgp.py index 18ccd3441..e7264d295 100644 --- a/wgp.py +++ b/wgp.py @@ -64,7 +64,7 @@ target_mmgp_version = "3.6.0" WanGP_version = "8.72" -settings_version = 2.35 +settings_version = 2.36 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, 
prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None From ee0bb89ee94618eac574f1290dd45532b3b59220 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 25 Sep 2025 02:16:57 +0200 Subject: [PATCH 038/155] Added Qwen Preview mode --- README.md | 3 +- models/qwen/pipeline_qwenimage.py | 5 +-- models/qwen/qwen_handler.py | 7 ++++ models/wan/any2video.py | 54 ++++++++++++++----------------- models/wan/wan_handler.py | 4 ++- shared/RGB_factors.py | 4 +-- 6 files changed, 41 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 30a97e66b..9b8f70d45 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : -### September 24 2025: WanGP v8.72 - Here Are ~~Two~~Three New Contenders in the Vace Arena ! +### September 25 2025: WanGP v8.73 - Here Are ~~Two~~Three New Contenders in the Vace Arena ! So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages: - **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*. @@ -34,6 +34,7 @@ Also because I wanted to spoil you: *Update 8.71*: fixed Fast Lucy Edit that didnt contain the lora *Update 8.72*: shadow drop of Qwen Edit Plus +*Update 8.73*: Qwen Preview & InfiniteTalk Start image ### September 15 2025: WanGP v8.6 - Attack of the Clones diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py index 85934b79f..134cc51b5 100644 --- a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -971,8 +971,9 @@ def cfg_predictions( noise_pred, neg_noise_pred, guidance, t): latents = latents.to(latents_dtype) if callback is not None: - # preview = unpack_latent(img).transpose(0,1) - callback(i, None, False) + preview = self._unpack_latents(latents, height, width, self.vae_scale_factor) + preview = preview.squeeze(0) + callback(i, preview, False) self._current_timestep = None diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py index 4fcaa3ba8..99864a5a7 100644 --- a/models/qwen/qwen_handler.py +++ b/models/qwen/qwen_handler.py @@ -129,6 +129,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "model_mode" : 0, }) + @staticmethod def validate_generative_settings(base_model_type, model_def, inputs): if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]: model_mode = inputs["model_mode"] @@ -141,3 +142,9 @@ def validate_generative_settings(base_model_type, model_def, inputs): gr.Info("Denoising Strength will be ignored while using Lora Inpainting") if outpainting_dims is not None and model_mode == 0 : return "Outpainting is not supported with Masked Denoising " + + @staticmethod + def get_rgb_factors(base_model_type ): + from shared.RGB_factors import get_rgb_factors + latent_rgb_factors, latent_rgb_factors_bias = 
get_rgb_factors("qwen") + return latent_rgb_factors, latent_rgb_factors_bias diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 41d6d6316..6b4ae629d 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -443,38 +443,32 @@ def generate(self, # image2video if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]: any_end_frame = False - if image_start is None: - if infinitetalk: - new_shot = "Q" in video_prompt_type - if input_frames is not None: - image_ref = input_frames[:, 0] - else: - if input_ref_images is None: - if pre_video_frame is None: raise Exception("Missing Reference Image") - input_ref_images, new_shot = [pre_video_frame], False - new_shot = new_shot and window_no <= len(input_ref_images) - image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ]) - if new_shot or input_video is None: - input_video = image_ref.unsqueeze(1) - else: - color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot - _ , preframes_count, height, width = input_video.shape - input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype) - if infinitetalk: - image_start = image_ref.to(input_video) - control_pre_frames_count = 1 - control_video = image_start.unsqueeze(1) + if infinitetalk: + new_shot = "Q" in video_prompt_type + if input_frames is not None: + image_ref = input_frames[:, 0] else: - image_start = input_video[:, -1] - control_pre_frames_count = preframes_count - control_video = input_video - - color_reference_frame = image_start.unsqueeze(1).clone() + if input_ref_images is None: + if pre_video_frame is None: raise Exception("Missing Reference Image") + input_ref_images, new_shot = [pre_video_frame], False + new_shot = new_shot and window_no <= len(input_ref_images) + image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ]) + if new_shot or input_video is None: + input_video = image_ref.unsqueeze(1) + else: + color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot + _ , preframes_count, height, width = input_video.shape + input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype) + if infinitetalk: + image_start = image_ref.to(input_video) + control_pre_frames_count = 1 + control_video = image_start.unsqueeze(1) else: - preframes_count = control_pre_frames_count = 1 - height, width = image_start.shape[1:] - control_video = image_start.unsqueeze(1).to(self.device) - color_reference_frame = control_video.clone() + image_start = input_video[:, -1] + control_pre_frames_count = preframes_count + control_video = input_video + + color_reference_frame = image_start.unsqueeze(1).clone() any_end_frame = image_end is not None add_frames_for_end_image = any_end_frame and model_type == "i2v" diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 12ddfed45..574c990a0 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -245,8 +245,10 @@ def query_model_def(base_model_type, model_def): "visible" : False, } - if vace_class or base_model_type in ["infinitetalk", "animate"]: + if vace_class or base_model_type in ["animate"]: image_prompt_types_allowed = "TVL" + elif base_model_type in ["infinitetalk"]: + image_prompt_types_allowed = "TSVL" 
elif base_model_type in ["ti2v_2_2"]: image_prompt_types_allowed = "TSVL" elif base_model_type in ["lucy_edit"]: diff --git a/shared/RGB_factors.py b/shared/RGB_factors.py index 6e865fa7b..8a870b4b6 100644 --- a/shared/RGB_factors.py +++ b/shared/RGB_factors.py @@ -1,6 +1,6 @@ # thanks Comfyui for the rgb factors (https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py) def get_rgb_factors(model_family, model_type = None): - if model_family == "wan": + if model_family in ["wan", "qwen"]: if model_type =="ti2v_2_2": latent_channels = 48 latent_dimensions = 3 @@ -261,7 +261,7 @@ def get_rgb_factors(model_family, model_type = None): [ 0.0249, -0.0469, -0.1703] ] - latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] + latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] else: latent_rgb_factors_bias = latent_rgb_factors = None return latent_rgb_factors, latent_rgb_factors_bias \ No newline at end of file From 5cfedca74480099e11c8a5fc6a69c1e8abe701cc Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 25 Sep 2025 02:17:59 +0200 Subject: [PATCH 039/155] Added Qwen Preview mode --- wgp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgp.py b/wgp.py index e7264d295..5267ca35b 100644 --- a/wgp.py +++ b/wgp.py @@ -63,7 +63,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.72" +WanGP_version = "8.73" settings_version = 2.36 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None From 14d68bbc9152f35c00a55750cc15327680a03dd5 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 25 Sep 2025 09:40:51 +0200 Subject: [PATCH 040/155] fixed double Vace controlnets with no mask --- models/wan/diffusion_forcing copy.py | 479 ------------------ models/wan/text2video fuse attempt.py | 698 -------------------------- wgp.py | 11 +- 3 files changed, 7 insertions(+), 1181 deletions(-) delete mode 100644 models/wan/diffusion_forcing copy.py delete mode 100644 models/wan/text2video fuse attempt.py diff --git a/models/wan/diffusion_forcing copy.py b/models/wan/diffusion_forcing copy.py deleted file mode 100644 index 753fd4561..000000000 --- a/models/wan/diffusion_forcing copy.py +++ /dev/null @@ -1,479 +0,0 @@ -import math -import os -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union -import logging -import numpy as np -import torch -from diffusers.image_processor import PipelineImageInput -from diffusers.utils.torch_utils import randn_tensor -from diffusers.video_processor import VideoProcessor -from tqdm import tqdm -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from wan.modules.posemb_layers import get_rotary_pos_embed -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler - -class DTT2V: - - - def __init__( - self, - config, - checkpoint_dir, - rank=0, - model_filename = None, - text_encoder_filename = None, - quantizeTransformer = False, - dtype = torch.bfloat16, - ): - self.device = torch.device(f"cuda") - self.config = config - self.rank = rank - self.dtype = dtype - self.num_train_timesteps = config.num_train_timesteps - self.param_dtype = config.param_dtype - - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - 
device=torch.device('cpu'), - checkpoint_path=text_encoder_filename, - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn= None) - - self.vae_stride = config.vae_stride - self.patch_size = config.patch_size - - - self.vae = WanVAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) - - logging.info(f"Creating WanModel from {model_filename}") - from mmgp import offload - - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False, forcedConfigPath="config.json") - # offload.load_model_data(self.model, "recam.ckpt") - # self.model.cpu() - # offload.save_model(self.model, "recam.safetensors") - if self.dtype == torch.float16 and not "fp16" in model_filename: - self.model.to(self.dtype) - # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True) - if self.dtype == torch.float16: - self.vae.model.to(self.dtype) - self.model.eval().requires_grad_(False) - - self.scheduler = FlowUniPCMultistepScheduler() - - @property - def do_classifier_free_guidance(self) -> bool: - return self._guidance_scale > 1 - - def encode_image( - self, image: PipelineImageInput, height: int, width: int, num_frames: int, tile_size = 0, causal_block_size = 0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # prefix_video - prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1) - prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1) - if prefix_video.dtype == torch.uint8: - prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0 - prefix_video = prefix_video.to(self.device) - prefix_video = [self.vae.encode(prefix_video.unsqueeze(0), tile_size = tile_size)[0]] # [(c, f, h, w)] - if prefix_video[0].shape[1] % causal_block_size != 0: - truncate_len = prefix_video[0].shape[1] % causal_block_size - print("the length of prefix video is truncated for the casual block size alignment.") - prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] - predix_video_latent_length = prefix_video[0].shape[1] - return prefix_video, predix_video_latent_length - - def prepare_latents( - self, - shape: Tuple[int], - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - ) -> torch.Tensor: - return randn_tensor(shape, generator, device=device, dtype=dtype) - - def generate_timestep_matrix( - self, - num_frames, - step_template, - base_num_frames, - ar_step=5, - num_pre_ready=0, - casual_block_size=1, - shrink_interval_with_mask=False, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: - step_matrix, step_index = [], [] - update_mask, valid_interval = [], [] - num_iterations = len(step_template) + 1 - num_frames_block = num_frames // casual_block_size - base_num_frames_block = base_num_frames // casual_block_size - if base_num_frames_block < num_frames_block: - infer_step_num = len(step_template) - gen_block = base_num_frames_block - min_ar_step = infer_step_num / gen_block - assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting" - # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) - step_template = torch.cat( - [ - torch.tensor([999], dtype=torch.int64, device=step_template.device), - step_template.long(), - torch.tensor([0], 
dtype=torch.int64, device=step_template.device), - ] - ) # to handle the counter in row works starting from 1 - pre_row = torch.zeros(num_frames_block, dtype=torch.long) - if num_pre_ready > 0: - pre_row[: num_pre_ready // casual_block_size] = num_iterations - - while torch.all(pre_row >= (num_iterations - 1)) == False: - new_row = torch.zeros(num_frames_block, dtype=torch.long) - for i in range(num_frames_block): - if i == 0 or pre_row[i - 1] >= ( - num_iterations - 1 - ): # the first frame or the last frame is completely denoised - new_row[i] = pre_row[i] + 1 - else: - new_row[i] = new_row[i - 1] - ar_step - new_row = new_row.clamp(0, num_iterations) - - update_mask.append( - (new_row != pre_row) & (new_row != num_iterations) - ) # False: no need to update, True: need to update - step_index.append(new_row) - step_matrix.append(step_template[new_row]) - pre_row = new_row - - # for long video we split into several sequences, base_num_frames is set to the model max length (for training) - terminal_flag = base_num_frames_block - if shrink_interval_with_mask: - idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) - update_mask = update_mask[0] - update_mask_idx = idx_sequence[update_mask] - last_update_idx = update_mask_idx[-1].item() - terminal_flag = last_update_idx + 1 - # for i in range(0, len(update_mask)): - for curr_mask in update_mask: - if terminal_flag < num_frames_block and curr_mask[terminal_flag]: - terminal_flag += 1 - valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) - - step_update_mask = torch.stack(update_mask, dim=0) - step_index = torch.stack(step_index, dim=0) - step_matrix = torch.stack(step_matrix, dim=0) - - if casual_block_size > 1: - step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] - - return step_matrix, step_index, step_update_mask, valid_interval - - @torch.no_grad() - def generate( - self, - prompt: Union[str, List[str]], - negative_prompt: Union[str, List[str]] = "", - image: PipelineImageInput = None, - height: int = 480, - width: int = 832, - num_frames: int = 97, - num_inference_steps: int = 50, - shift: float = 1.0, - guidance_scale: float = 5.0, - seed: float = 0.0, - overlap_history: int = 17, - addnoise_condition: int = 0, - base_num_frames: int = 97, - ar_step: int = 5, - causal_block_size: int = 1, - causal_attention: bool = False, - fps: int = 24, - VAE_tile_size = 0, - joint_pass = False, - callback = None, - ): - generator = torch.Generator(device=self.device) - generator.manual_seed(seed) - # if base_num_frames > base_num_frames: - # causal_block_size = 0 - self._guidance_scale = guidance_scale - - i2v_extra_kwrags = {} - prefix_video = None - predix_video_latent_length = 0 - if image: - frame_width, frame_height = image.size - scale = min(height / frame_height, width / frame_width) - height = (int(frame_height * scale) // 16) * 16 - width = (int(frame_width * scale) // 16) * 16 - - prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames, tile_size=VAE_tile_size, causal_block_size=causal_block_size) - - latent_length = (num_frames - 1) // 4 + 1 - latent_height = height // 8 - latent_width = width // 8 - - prompt_embeds = 
self.text_encoder([prompt], self.device) - prompt_embeds = [u.to(self.dtype).to(self.device) for u in prompt_embeds] - if self.do_classifier_free_guidance: - negative_prompt_embeds = self.text_encoder([negative_prompt], self.device) - negative_prompt_embeds = [u.to(self.dtype).to(self.device) for u in negative_prompt_embeds] - - - - self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) - init_timesteps = self.scheduler.timesteps - fps_embeds = [fps] * prompt_embeds[0].shape[0] - fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] - transformer_dtype = self.dtype - # with torch.cuda.amp.autocast(dtype=self.dtype), torch.no_grad(): - if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: - # short video generation - latent_shape = [16, latent_length, latent_height, latent_width] - latents = self.prepare_latents( - latent_shape, dtype=torch.float32, device=self.device, generator=generator - ) - latents = [latents] - if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size - ) - sample_schedulers = [] - for _ in range(latent_length): - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) - sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * latent_length - - if callback != None: - callback(-1, None, True) - - freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False) - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length]) - * noise_factor - ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - kwrags = { - "x" : torch.stack([latent_model_input[0]]), - "t" : timestep, - "freqs" :freqs, - "fps" : fps_embeds, - # "causal_block_size" : causal_block_size, - "callback" : callback, - "pipeline" : self - } - kwrags.update(i2v_extra_kwrags) - - - if not self.do_classifier_free_guidance: - noise_pred = self.model( - context=prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred= noise_pred.to(torch.float32) - else: - if joint_pass: - noise_pred_cond, noise_pred_uncond = self.model( - context=prompt_embeds, - context2=negative_prompt_embeds, - **kwrags, - ) - if self._interrupt: - return None - else: - noise_pred_cond = self.model( - context=prompt_embeds, - **kwrags, - )[0] - if self._interrupt: 
- return None - noise_pred_uncond = self.model( - context=negative_prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred_cond= noise_pred_cond.to(torch.float32) - noise_pred_uncond= noise_pred_uncond.to(torch.float32) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) - del noise_pred_cond, noise_pred_uncond - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], - return_dict=False, - generator=generator, - )[0] - sample_schedulers_counter[idx] += 1 - if callback is not None: - callback(i, latents[0], False) - - x0 = latents[0].unsqueeze(0) - videos = self.vae.decode(x0, tile_size= VAE_tile_size) - videos = (videos / 2 + 0.5).clamp(0, 1) - videos = [video for video in videos] - videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] - videos = [video.cpu().numpy().astype(np.uint8) for video in videos] - return videos - else: - # long video generation - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length - overlap_history_frames = (overlap_history - 1) // 4 + 1 - n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 - print(f"n_iter:{n_iter}") - output_video = None - for i in range(n_iter): - if output_video is not None: # i !=0 - prefix_video = output_video[:, -overlap_history:].to(self.device) - prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] - if prefix_video[0].shape[1] % causal_block_size != 0: - truncate_len = prefix_video[0].shape[1] % causal_block_size - print("the length of prefix video is truncated for the casual block size alignment.") - prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] - predix_video_latent_length = prefix_video[0].shape[1] - finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames - left_frame_num = latent_length - finished_frame_num - base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) - else: # i == 0 - base_num_frames_iter = base_num_frames - latent_shape = [16, base_num_frames_iter, latent_height, latent_width] - latents = self.prepare_latents( - latent_shape, dtype=torch.float32, device=self.device, generator=generator - ) - latents = [latents] - if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - base_num_frames_iter, - init_timesteps, - base_num_frames_iter, - ar_step, - predix_video_latent_length, - causal_block_size, - ) - sample_schedulers = [] - for _ in range(base_num_frames_iter): - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) - sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * base_num_frames_iter - if callback != None: - callback(-1, None, True) - - freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False) - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, 
valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - * (1.0 - noise_factor) - + torch.randn_like( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - ) - * noise_factor - ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - kwrags = { - "x" : torch.stack([latent_model_input[0]]), - "t" : timestep, - "freqs" :freqs, - "fps" : fps_embeds, - "causal_block_size" : causal_block_size, - "causal_attention" : causal_attention, - "callback" : callback, - "pipeline" : self - } - kwrags.update(i2v_extra_kwrags) - - if not self.do_classifier_free_guidance: - noise_pred = self.model( - context=prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred= noise_pred.to(torch.float32) - else: - if joint_pass: - noise_pred_cond, noise_pred_uncond = self.model( - context=prompt_embeds, - context2=negative_prompt_embeds, - **kwrags, - ) - if self._interrupt: - return None - else: - noise_pred_cond = self.model( - context=prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred_uncond = self.model( - context=negative_prompt_embeds, - )[0] - if self._interrupt: - return None - noise_pred_cond= noise_pred_cond.to(torch.float32) - noise_pred_uncond= noise_pred_uncond.to(torch.float32) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) - del noise_pred_cond, noise_pred_uncond - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], - return_dict=False, - generator=generator, - )[0] - sample_schedulers_counter[idx] += 1 - if callback is not None: - callback(i, latents[0].squeeze(0), False) - - x0 = latents[0].unsqueeze(0) - videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]] - if output_video is None: - output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w - else: - output_video = torch.cat( - [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 - ) # c, f, h, w - return output_video diff --git a/models/wan/text2video fuse attempt.py b/models/wan/text2video fuse attempt.py deleted file mode 100644 index 8af94583d..000000000 --- a/models/wan/text2video fuse attempt.py +++ /dev/null @@ -1,698 +0,0 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
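Reviewer note on the sampling loops removed above: the conditional and unconditional predictions are combined with standard classifier-free guidance, and the long-video branch stitches successive windows by dropping the frames that overlap the previous window before concatenating. A minimal sketch of those two steps, with illustrative names that are not part of the repo:

```python
import torch

def cfg_combine(noise_pred_cond, noise_pred_uncond, guidance_scale):
    # Classifier-free guidance: push the prediction away from the unconditional
    # branch and toward the text-conditioned branch by `guidance_scale`.
    return noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

def stitch_windows(decoded_windows, overlap_history):
    # Long-video stitching: keep the first window whole, then append each later
    # window minus its first `overlap_history` frames (tensors are (C, F, H, W)).
    output_video = decoded_windows[0]
    for clip in decoded_windows[1:]:
        output_video = torch.cat([output_video, clip[:, overlap_history:]], dim=1)
    return output_video
```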
-import gc -import logging -import math -import os -import random -import sys -import types -from contextlib import contextmanager -from functools import partial -from mmgp import offload -import torch -import torch.nn as nn -import torch.cuda.amp as amp -import torch.distributed as dist -from tqdm import tqdm -from PIL import Image -import torchvision.transforms.functional as TF -import torch.nn.functional as F -from .distributed.fsdp import shard_model -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from wan.modules.posemb_layers import get_rotary_pos_embed -from .utils.vace_preprocessor import VaceVideoProcessor - - -def optimized_scale(positive_flat, negative_flat): - - # Calculate dot production - dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) - - # Squared norm of uncondition - squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 - - # st_star = v_cond^T * v_uncond / ||v_uncond||^2 - st_star = dot_product / squared_norm - - return st_star - - -class WanT2V: - - def __init__( - self, - config, - checkpoint_dir, - rank=0, - model_filename = None, - text_encoder_filename = None, - quantizeTransformer = False, - dtype = torch.bfloat16 - ): - self.device = torch.device(f"cuda") - self.config = config - self.rank = rank - self.dtype = dtype - self.num_train_timesteps = config.num_train_timesteps - self.param_dtype = config.param_dtype - - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - device=torch.device('cpu'), - checkpoint_path=text_encoder_filename, - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn= None) - - self.vae_stride = config.vae_stride - self.patch_size = config.patch_size - - - self.vae = WanVAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) - - logging.info(f"Creating WanModel from {model_filename}") - from mmgp import offload - - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False) - # offload.load_model_data(self.model, "recam.ckpt") - # self.model.cpu() - # offload.save_model(self.model, "recam.safetensors") - if self.dtype == torch.float16 and not "fp16" in model_filename: - self.model.to(self.dtype) - # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True) - if self.dtype == torch.float16: - self.vae.model.to(self.dtype) - self.model.eval().requires_grad_(False) - - - self.sample_neg_prompt = config.sample_neg_prompt - - if "Vace" in model_filename: - self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), - min_area=480*832, - max_area=480*832, - min_fps=config.sample_fps, - max_fps=config.sample_fps, - zero_start=True, - seq_len=32760, - keep_last=True) - - self.adapt_vace_model() - - self.scheduler = FlowUniPCMultistepScheduler() - - def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0): - if ref_images is None: - ref_images = [None] * len(frames) - else: - assert len(frames) == len(ref_images) - - if masks is None: - latents = self.vae.encode(frames, tile_size = tile_size) - else: - inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] - reactive = [i * m + 0 * (1 - m) for i, m in 
zip(frames, masks)] - inactive = self.vae.encode(inactive, tile_size = tile_size) - reactive = self.vae.encode(reactive, tile_size = tile_size) - latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] - - cat_latents = [] - for latent, refs in zip(latents, ref_images): - if refs is not None: - if masks is None: - ref_latent = self.vae.encode(refs, tile_size = tile_size) - else: - ref_latent = self.vae.encode(refs, tile_size = tile_size) - ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] - assert all([x.shape[1] == 1 for x in ref_latent]) - latent = torch.cat([*ref_latent, latent], dim=1) - cat_latents.append(latent) - return cat_latents - - def vace_encode_masks(self, masks, ref_images=None): - if ref_images is None: - ref_images = [None] * len(masks) - else: - assert len(masks) == len(ref_images) - - result_masks = [] - for mask, refs in zip(masks, ref_images): - c, depth, height, width = mask.shape - new_depth = int((depth + 3) // self.vae_stride[0]) - height = 2 * (int(height) // (self.vae_stride[1] * 2)) - width = 2 * (int(width) // (self.vae_stride[2] * 2)) - - # reshape - mask = mask[0, :, :, :] - mask = mask.view( - depth, height, self.vae_stride[1], width, self.vae_stride[1] - ) # depth, height, 8, width, 8 - mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width - mask = mask.reshape( - self.vae_stride[1] * self.vae_stride[2], depth, height, width - ) # 8*8, depth, height, width - - # interpolation - mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) - - if refs is not None: - length = len(refs) - mask_pad = torch.zeros_like(mask[:, :length, :, :]) - mask = torch.cat((mask_pad, mask), dim=1) - result_masks.append(mask) - return result_masks - - def vace_latent(self, z, m): - return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - - def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None): - image_sizes = [] - trim_video = len(keep_frames) - - for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)): - prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1] - num_frames = total_frames - prepend_count - if sub_src_mask is not None and sub_src_video is not None: - src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) - # src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255]) - # src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255]) - src_video[i] = src_video[i].to(device) - src_mask[i] = src_mask[i].to(device) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) - 
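For readers skimming the deleted `vace_encode_masks` above: it folds each 8x8 spatial patch of the pixel-space mask into channels and resamples the temporal axis down to the latent frame count. A simplified sketch of that transform, assuming H and W are divisible by 8 (names are illustrative):

```python
import torch
import torch.nn.functional as F

def mask_to_latent_grid(mask: torch.Tensor, vae_stride=(4, 8, 8)) -> torch.Tensor:
    # mask: (1, F, H, W) binary mask in pixel space.
    _, depth, height, width = mask.shape
    new_depth = (depth + 3) // vae_stride[0]          # latent frame count
    lat_h, lat_w = height // vae_stride[1], width // vae_stride[2]
    m = mask[0].view(depth, lat_h, vae_stride[1], lat_w, vae_stride[2])
    m = m.permute(2, 4, 0, 1, 3)                      # (8, 8, F, lat_h, lat_w)
    m = m.reshape(vae_stride[1] * vae_stride[2], depth, lat_h, lat_w)
    # Temporal resampling to the latent frame count, spatial size unchanged.
    return F.interpolate(m.unsqueeze(0), size=(new_depth, lat_h, lat_w),
                         mode="nearest-exact").squeeze(0)
```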
image_sizes.append(src_video[i].shape[2:]) - elif sub_src_video is None: - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1) - else: - src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - image_sizes.append(image_size) - else: - src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) - src_video[i] = src_video[i].to(device) - src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - image_sizes.append(src_video[i].shape[2:]) - for k, keep in enumerate(keep_frames): - if not keep: - src_video[i][:, k:k+1] = 0 - src_mask[i][:, k:k+1] = 1 - - for i, ref_images in enumerate(src_ref_images): - if ref_images is not None: - image_size = image_sizes[i] - for j, ref_img in enumerate(ref_images): - if ref_img is not None: - ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) - if ref_img.shape[-2:] != image_size: - canvas_height, canvas_width = image_size - ref_height, ref_width = ref_img.shape[-2:] - white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] - scale = min(canvas_height / ref_height, canvas_width / ref_width) - new_height = int(ref_height * scale) - new_width = int(ref_width * scale) - resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) - top = (canvas_height - new_height) // 2 - left = (canvas_width - new_width) // 2 - white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image - ref_img = white_canvas - src_ref_images[i][j] = ref_img.to(device) - return src_video, src_mask, src_ref_images - - def decode_latent(self, zs, ref_images=None, tile_size= 0 ): - if ref_images is None: - ref_images = [None] * len(zs) - else: - assert len(zs) == len(ref_images) - - trimed_zs = [] - for z, refs in zip(zs, ref_images): - if refs is not None: - z = z[:, len(refs):, :, :] - trimed_zs.append(z) - - return self.vae.decode(trimed_zs, tile_size= tile_size) - - def generate_timestep_matrix( - self, - num_frames, - step_template, - base_num_frames, - ar_step=5, - num_pre_ready=0, - casual_block_size=1, - shrink_interval_with_mask=False, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: - step_matrix, step_index = [], [] - update_mask, valid_interval = [], [] - num_iterations = len(step_template) + 1 - num_frames_block = num_frames // casual_block_size - base_num_frames_block = base_num_frames // casual_block_size - if base_num_frames_block < num_frames_block: - 
infer_step_num = len(step_template) - gen_block = base_num_frames_block - min_ar_step = infer_step_num / gen_block - assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting" - # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) - step_template = torch.cat( - [ - torch.tensor([999], dtype=torch.int64, device=step_template.device), - step_template.long(), - torch.tensor([0], dtype=torch.int64, device=step_template.device), - ] - ) # to handle the counter in row works starting from 1 - pre_row = torch.zeros(num_frames_block, dtype=torch.long) - if num_pre_ready > 0: - pre_row[: num_pre_ready // casual_block_size] = num_iterations - - while torch.all(pre_row >= (num_iterations - 1)) == False: - new_row = torch.zeros(num_frames_block, dtype=torch.long) - for i in range(num_frames_block): - if i == 0 or pre_row[i - 1] >= ( - num_iterations - 1 - ): # the first frame or the last frame is completely denoised - new_row[i] = pre_row[i] + 1 - else: - new_row[i] = new_row[i - 1] - ar_step - new_row = new_row.clamp(0, num_iterations) - - update_mask.append( - (new_row != pre_row) & (new_row != num_iterations) - ) # False: no need to update, True: need to update - step_index.append(new_row) - step_matrix.append(step_template[new_row]) - pre_row = new_row - - # for long video we split into several sequences, base_num_frames is set to the model max length (for training) - terminal_flag = base_num_frames_block - if shrink_interval_with_mask: - idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) - update_mask = update_mask[0] - update_mask_idx = idx_sequence[update_mask] - last_update_idx = update_mask_idx[-1].item() - terminal_flag = last_update_idx + 1 - # for i in range(0, len(update_mask)): - for curr_mask in update_mask: - if terminal_flag < num_frames_block and curr_mask[terminal_flag]: - terminal_flag += 1 - valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) - - step_update_mask = torch.stack(update_mask, dim=0) - step_index = torch.stack(step_index, dim=0) - step_matrix = torch.stack(step_matrix, dim=0) - - if casual_block_size > 1: - step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] - - return step_matrix, step_index, step_update_mask, valid_interval - - def generate(self, - input_prompt, - input_frames= None, - input_masks = None, - input_ref_images = None, - source_video=None, - target_camera=None, - context_scale=1.0, - size=(1280, 720), - frame_num=81, - shift=5.0, - sample_solver='unipc', - sampling_steps=50, - guide_scale=5.0, - n_prompt="", - seed=-1, - offload_model=True, - callback = None, - enable_RIFLEx = None, - VAE_tile_size = 0, - joint_pass = False, - slg_layers = None, - slg_start = 0.0, - slg_end = 1.0, - cfg_star_switch = True, - cfg_zero_step = 5, - ): - r""" - Generates video frames from text prompt using diffusion process. - - Args: - input_prompt (`str`): - Text prompt for content generation - size (tupele[`int`], *optional*, defaults to (1280,720)): - Controls video resolution, (width,height). 
- frame_num (`int`, *optional*, defaults to 81): - How many frames to sample from a video. The number should be 4n+1 - shift (`float`, *optional*, defaults to 5.0): - Noise schedule shift parameter. Affects temporal dynamics - sample_solver (`str`, *optional*, defaults to 'unipc'): - Solver used to sample the video. - sampling_steps (`int`, *optional*, defaults to 40): - Number of diffusion sampling steps. Higher values improve quality but slow generation - guide_scale (`float`, *optional*, defaults 5.0): - Classifier-free guidance scale. Controls prompt adherence vs. creativity - n_prompt (`str`, *optional*, defaults to ""): - Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` - seed (`int`, *optional*, defaults to -1): - Random seed for noise generation. If -1, use random seed. - offload_model (`bool`, *optional*, defaults to True): - If True, offloads models to CPU during generation to save VRAM - - Returns: - torch.Tensor: - Generated video frames tensor. Dimensions: (C, N H, W) where: - - C: Color channels (3 for RGB) - - N: Number of frames (81) - - H: Frame height (from size) - - W: Frame width from size) - """ - # preprocess - - if n_prompt == "": - n_prompt = self.sample_neg_prompt - seed = seed if seed >= 0 else random.randint(0, sys.maxsize) - seed_g = torch.Generator(device=self.device) - seed_g.manual_seed(seed) - - frame_num = max(17, frame_num) # must match causal_block_size for value of 5 - frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 ) - num_frames = frame_num - addnoise_condition = 20 - causal_attention = True - fps = 16 - ar_step = 5 - - - - context = self.text_encoder([input_prompt], self.device) - context_null = self.text_encoder([n_prompt], self.device) - if target_camera != None: - size = (source_video.shape[2], source_video.shape[1]) - source_video = source_video.to(dtype=self.dtype , device=self.device) - source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.) 
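Side note on the preprocessing just above: source frames arrive as 8-bit RGB and are mapped into the [-1, 1] range the VAE expects. A tiny sketch of that conversion (illustrative helper, not in the repo):

```python
import torch

def to_model_range(frames_uint8: torch.Tensor) -> torch.Tensor:
    # (F, H, W, C) uint8 in [0, 255] -> (C, F, H, W) float in [-1, 1],
    # mirroring the permute(3, 0, 1, 2).div_(127.5).sub_(1.) sequence above.
    return frames_uint8.float().permute(3, 0, 1, 2).div_(127.5).sub_(1.0)
```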
- source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device) - del source_video - # Process target camera (recammaster) - from wan.utils.cammmaster_tools import get_camera_embedding - cam_emb = get_camera_embedding(target_camera) - cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) - - if input_frames != None: - # vace context encode - input_frames = [u.to(self.device) for u in input_frames] - input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] - input_masks = [u.to(self.device) for u in input_masks] - - z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size) - m0 = self.vace_encode_masks(input_masks, input_ref_images) - z = self.vace_latent(z0, m0) - - target_shape = list(z0[0].shape) - target_shape[0] = int(target_shape[0] / 2) - else: - F = frame_num - target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, - size[1] // self.vae_stride[1], - size[0] // self.vae_stride[2]) - - seq_len = math.ceil((target_shape[2] * target_shape[3]) / - (self.patch_size[1] * self.patch_size[2]) * - target_shape[1]) - - context = [u.to(self.dtype) for u in context] - context_null = [u.to(self.dtype) for u in context_null] - - noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ] - - # evaluation mode - - # if sample_solver == 'unipc': - # sample_scheduler = FlowUniPCMultistepScheduler( - # num_train_timesteps=self.num_train_timesteps, - # shift=1, - # use_dynamic_shifting=False) - # sample_scheduler.set_timesteps( - # sampling_steps, device=self.device, shift=shift) - # timesteps = sample_scheduler.timesteps - # elif sample_solver == 'dpm++': - # sample_scheduler = FlowDPMSolverMultistepScheduler( - # num_train_timesteps=self.num_train_timesteps, - # shift=1, - # use_dynamic_shifting=False) - # sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) - # timesteps, _ = retrieve_timesteps( - # sample_scheduler, - # device=self.device, - # sigmas=sampling_sigmas) - # else: - # raise NotImplementedError("Unsupported solver.") - - # sample videos - latents = noise - del noise - batch_size =len(latents) - if target_camera != None: - shape = list(latents[0].shape[1:]) - shape[0] *= 2 - freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) - else: - freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx) - # arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback} - # arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback} - # arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback} - - i2v_extra_kwrags = {} - - if target_camera != None: - recam_dict = {'cam_emb': cam_emb} - i2v_extra_kwrags.update(recam_dict) - - if input_frames != None: - vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale} - i2v_extra_kwrags.update(vace_dict) - - - latent_length = (num_frames - 1) // 4 + 1 - latent_height = height // 8 - latent_width = width // 8 - if ar_step == 0: - causal_block_size = 1 - fps_embeds = [fps] #* prompt_embeds[0].shape[0] - fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] - - self.scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) - init_timesteps = self.scheduler.timesteps - base_num_frames_iter = latent_length - latent_shape = [16, base_num_frames_iter, latent_height, latent_width] - - prefix_video = None - 
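Before the `generate_timestep_matrix` call that follows, it may help to see the per-block scheduling rule that helper implements: the first latent block (or any block whose predecessor is fully denoised) advances one denoising step per iteration, while every other block trails its predecessor by `ar_step` steps. A minimal sketch of a single row update, lifted from the logic of the deleted helper above (illustrative only):

```python
import torch

def advance_row(pre_row: torch.Tensor, ar_step: int, num_iterations: int) -> torch.Tensor:
    # pre_row[i] = how many denoising iterations block i has completed so far.
    new_row = torch.zeros_like(pre_row)
    for i in range(len(pre_row)):
        if i == 0 or pre_row[i - 1] >= num_iterations - 1:
            # First block, or the previous block is fully denoised: step forward.
            new_row[i] = pre_row[i] + 1
        else:
            # Otherwise trail the previous block by `ar_step` steps.
            new_row[i] = new_row[i - 1] - ar_step
    return new_row.clamp(0, num_iterations)
```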
predix_video_latent_length = 0 - - if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - base_num_frames_iter, - init_timesteps, - base_num_frames_iter, - ar_step, - predix_video_latent_length, - causal_block_size, - ) - sample_schedulers = [] - for _ in range(base_num_frames_iter): - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) - sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * base_num_frames_iter - - updated_num_steps= len(step_matrix) - - if callback != None: - callback(-1, None, True, override_num_inference_steps = updated_num_steps) - if self.model.enable_teacache: - self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) - # if callback != None: - # callback(-1, None, True) - - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - * (1.0 - noise_factor) - + torch.randn_like( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - ) - * noise_factor - ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - kwrags = { - "x" : torch.stack([latent_model_input[0]]), - "t" : timestep, - "freqs" :freqs, - "fps" : fps_embeds, - "causal_block_size" : causal_block_size, - "causal_attention" : causal_attention, - "callback" : callback, - "pipeline" : self, - "current_step" : i, - } - kwrags.update(i2v_extra_kwrags) - - if not self.do_classifier_free_guidance: - noise_pred = self.model( - context=context, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred= noise_pred.to(torch.float32) - else: - if joint_pass: - noise_pred_cond, noise_pred_uncond = self.model( - context=context, - context2=context_null, - **kwrags, - ) - if self._interrupt: - return None - else: - noise_pred_cond = self.model( - context=context, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred_uncond = self.model( - context=context_null, - )[0] - if self._interrupt: - return None - noise_pred_cond= noise_pred_cond.to(torch.float32) - noise_pred_uncond= noise_pred_uncond.to(torch.float32) - noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond) - del noise_pred_cond, noise_pred_uncond - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], - return_dict=False, - generator=seed_g, - )[0] - sample_schedulers_counter[idx] += 1 - if callback is not None: - callback(i, latents[0].squeeze(0), False) - - # for i, t in 
enumerate(tqdm(timesteps)): - # if target_camera != None: - # latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )] - # else: - # latent_model_input = latents - # slg_layers_local = None - # if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps): - # slg_layers_local = slg_layers - # timestep = [t] - # offload.set_step_no_for_lora(self.model, i) - # timestep = torch.stack(timestep) - - # if joint_pass: - # noise_pred_cond, noise_pred_uncond = self.model( - # latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both) - # if self._interrupt: - # return None - # else: - # noise_pred_cond = self.model( - # latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0] - # if self._interrupt: - # return None - # noise_pred_uncond = self.model( - # latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0] - # if self._interrupt: - # return None - - # # del latent_model_input - - # # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ - # noise_pred_text = noise_pred_cond - # if cfg_star_switch: - # positive_flat = noise_pred_text.view(batch_size, -1) - # negative_flat = noise_pred_uncond.view(batch_size, -1) - - # alpha = optimized_scale(positive_flat,negative_flat) - # alpha = alpha.view(batch_size, 1, 1, 1) - - # if (i <= cfg_zero_step): - # noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred... - # else: - # noise_pred_uncond *= alpha - # noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond) - # del noise_pred_uncond - - # temp_x0 = sample_scheduler.step( - # noise_pred[:, :target_shape[1]].unsqueeze(0), - # t, - # latents[0].unsqueeze(0), - # return_dict=False, - # generator=seed_g)[0] - # latents = [temp_x0.squeeze(0)] - # del temp_x0 - - # if callback is not None: - # callback(i, latents[0], False) - - x0 = latents - - if input_frames == None: - videos = self.vae.decode(x0, VAE_tile_size) - else: - videos = self.decode_latent(x0, input_ref_images, VAE_tile_size) - - del latents - del sample_scheduler - - return videos[0] if self.rank == 0 else None - - def adapt_vace_model(self): - model = self.model - modules_dict= { k: m for k, m in model.named_modules()} - for model_layer, vace_layer in model.vace_layers_mapping.items(): - module = modules_dict[f"vace_blocks.{vace_layer}"] - target = modules_dict[f"blocks.{model_layer}"] - setattr(target, "vace", module ) - delattr(model, "vace_blocks") - - \ No newline at end of file diff --git a/wgp.py b/wgp.py index 5267ca35b..d1eae3b58 100644 --- a/wgp.py +++ b/wgp.py @@ -5075,7 +5075,7 @@ def remove_temp_filenames(temp_filenames_list): any_guide_padding = model_def.get("pad_guide_video", False) from shared.utils.utils import prepare_video_guide_and_mask src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]), - [video_mask_processed] + ([] if video_mask_processed2 is None else [video_mask_processed2]), + [video_mask_processed] + ([] if video_guide_processed2 is None else [video_mask_processed2]), None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide, image_size, current_video_length, latent_size, any_mask, any_guide_padding, guide_inpaint_color, @@ -5097,9 +5097,12 @@ def remove_temp_filenames(temp_filenames_list): src_faces = src_faces[:, 
:src_video.shape[1]] if video_guide is not None or len(frames_to_inject_parsed) > 0: if args.save_masks: - if src_video is not None: save_video( src_video, "masked_frames.mp4", fps) - if src_video2 is not None: save_video( src_video2, "masked_frames2.mp4", fps) - if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) + if src_video is not None: + save_video( src_video, "masked_frames.mp4", fps) + if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) + if src_video2 is not None: + save_video( src_video2, "masked_frames2.mp4", fps) + if any_mask: save_video( src_mask2, "masks2.mp4", fps, value_range=(0, 1)) if video_guide is not None: preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame) refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no) From ba764e8d60552fee29190898028e46ee67a61fd5 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 25 Sep 2025 23:21:51 +0200 Subject: [PATCH 041/155] more spoil --- README.md | 10 ++++- defaults/animate.json | 6 ++- models/flux/sampling.py | 4 +- models/qwen/pipeline_qwenimage.py | 4 +- models/wan/any2video.py | 51 ++++++++++++++++------- models/wan/modules/model.py | 15 +++---- models/wan/wan_handler.py | 32 +++++++++++---- shared/utils/loras_mutipliers.py | 5 ++- wgp.py | 67 ++++++++++++++++++------------- 9 files changed, 129 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index 9b8f70d45..d9be0fdf7 100644 --- a/README.md +++ b/README.md @@ -20,21 +20,27 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : -### September 25 2025: WanGP v8.73 - Here Are ~~Two~~Three New Contenders in the Vace Arena ! +### September 25 2025: WanGP v8.73 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release -So in today's release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages: +So in ~~today's~~ this release you will find two Wannabe Vaces, each of which covers only a subset of Vace features but offers some interesting advantages: - **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model to either *Replace* a person in a Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*. In order to use Wan 2.2 Animate you will first need to stop by the *Mat Anyone* embedded tool to extract the *Video Mask* of the person from which you want to extract the motion. +With version WanGP 8.74, there is an extra option that allows you to apply *Relighting* when Replacing a person. Also, you can now Animate a person without providing a Video Mask to target the source of the motion (with the risk that it will be less precise). - **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in clothes changing) and voila !
The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package. Also because I wanted to spoil you: - **Qwen Edit Plus**: also known as the *Qwen Edit 25th September Update* which is specialized in combining multiple Objects / People. There is also new support for *Pose transfer* & *Recolorisation*. All of this is made easy to use in WanGP. Right now you will find only the quantized version since HF crashes when uploading the unquantized version. +- **T2V Video 2 Video Masking**: ever wanted to apply a Lora, a process (for instance Upsampling) or a Text Prompt on only a (moving) part of a Source Video? Look no further: I have added *Masked Video 2 Video* (which also works in image2image) to the *Text 2 Video* models. As usual you just need to use *Matanyone* to create the mask. + + *Update 8.71*: fixed Fast Lucy Edit that didn't contain the lora *Update 8.72*: shadow drop of Qwen Edit Plus *Update 8.73*: Qwen Preview & InfiniteTalk Start image +*Update 8.74*: Animate Relighting / Nomask mode, t2v Masked Video to Video ### September 15 2025: WanGP v8.6 - Attack of the Clones diff --git a/defaults/animate.json b/defaults/animate.json index fee26d76a..bf45f4a92 100644 --- a/defaults/animate.json +++ b/defaults/animate.json @@ -2,12 +2,16 @@ "model": { "name": "Wan2.2 Animate", "architecture": "animate", - "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'animation' or 'replacement' mode.", + "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'Animation' or 'Replacement' mode. Sliding Windows of at least 81 frames are recommended to obtain the best Style continuity.", "URLs": [ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors", "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors" ], + "preload_URLs" : + [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_relighting_lora.safetensors" + ], "group": "wan2_2" } } \ No newline at end of file diff --git a/models/flux/sampling.py b/models/flux/sampling.py index 92ee59099..939543c64 100644 --- a/models/flux/sampling.py +++ b/models/flux/sampling.py @@ -292,7 +292,7 @@ def denoise( callback(-1, None, True) original_image_latents = None if img_cond_seq is None else img_cond_seq.clone() - + original_timesteps = timesteps morph, first_step = False, 0 if img_msk_latents is not None: randn = torch.randn_like(original_image_latents) @@ -309,7 +309,7 @@ def denoise( updated_num_steps= len(timesteps) -1 if callback != None: from shared.utils.loras_mutipliers import update_loras_slists - update_loras_slists(model, loras_slists, updated_num_steps) + update_loras_slists(model, loras_slists, len(original_timesteps)) callback(-1, None, True, override_num_inference_steps = updated_num_steps) from mmgp import offload # this is ignored for schnell diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py index 134cc51b5..02fd473b4 100644 --- a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -825,7 +825,7 @@ def __call__( ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - + original_timesteps =
timesteps # handle guidance if self.transformer.guidance_embeds: guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) @@ -860,7 +860,7 @@ def __call__( updated_num_steps= len(timesteps) if callback != None: from shared.utils.loras_mutipliers import update_loras_slists - update_loras_slists(self.transformer, loras_slists, updated_num_steps) + update_loras_slists(self.transformer, loras_slists, len(original_timesteps)) callback(-1, None, True, override_num_inference_steps = updated_num_steps) diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 6b4ae629d..285340d39 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -34,6 +34,7 @@ from shared.utils.basic_flowmatch import FlowMatchScheduler from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor, fit_image_into_canvas from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask +from shared.utils.audio_video import save_video from mmgp import safetensors2 from shared.utils.audio_video import save_video @@ -434,7 +435,7 @@ def generate(self, start_step_no = 0 ref_images_count = 0 trim_frames = 0 - extended_overlapped_latents = clip_image_start = clip_image_end = None + extended_overlapped_latents = clip_image_start = clip_image_end = image_mask_latents = None no_noise_latents_injection = infinitetalk timestep_injection = False lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 @@ -585,7 +586,7 @@ def generate(self, kwargs['cam_emb'] = cam_emb # Video 2 Video - if denoising_strength < 1. and input_frames != None: + if "G" in video_prompt_type and input_frames != None: height, width = input_frames.shape[-2:] source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) injection_denoising_step = 0 @@ -616,6 +617,14 @@ def generate(self, if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:] injection_denoising_step = 0 + if input_masks is not None and not "U" in video_prompt_type: + image_mask_latents = torch.nn.functional.interpolate(input_masks, size= source_latents.shape[-2:], mode="nearest").unsqueeze(0) + if image_mask_latents.shape[2] !=1: + image_mask_latents = torch.cat([ image_mask_latents[:,:, :1], torch.nn.functional.interpolate(image_mask_latents, size= (source_latents.shape[-3]-1, *source_latents.shape[-2:]), mode="nearest") ], dim=2) + image_mask_latents = torch.where(image_mask_latents>=0.5, 1., 0. 
)[:1].to(self.device) + # save_video(image_mask_latents.squeeze(0), "mama.mp4", value_range=(0,1) ) + # image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0) + # Phantom if phantom: lat_input_ref_images_neg = None @@ -737,7 +746,7 @@ def generate(self, denoising_extra = "" from shared.utils.loras_mutipliers import update_loras_slists, get_model_switch_steps - phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(timesteps, updated_num_steps, guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold ) + phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(original_timesteps,guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold ) if len(phases_description) > 0: set_header_text(phases_description) guidance_switch_done = guidance_switch2_done = False if guide_phases > 1: denoising_extra = f"Phase 1/{guide_phases} High Noise" if self.model2 is not None else f"Phase 1/{guide_phases}" @@ -748,8 +757,8 @@ def update_guidance(step_no, t, guide_scale, new_guide_scale, guidance_switch_do denoising_extra = f"Phase {phase_no}/{guide_phases} {'Low Noise' if trans == self.model2 else 'High Noise'}" if self.model2 is not None else f"Phase {phase_no}/{guide_phases}" callback(step_no-1, denoising_extra = denoising_extra) return guide_scale, guidance_switch_done, trans, denoising_extra - update_loras_slists(self.model, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) - if self.model2 is not None: update_loras_slists(self.model2, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) + update_loras_slists(self.model, loras_slists, len(original_timesteps), phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) + if self.model2 is not None: update_loras_slists(self.model2, loras_slists, len(original_timesteps), phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) callback(-1, None, True, override_num_inference_steps = updated_num_steps, denoising_extra = denoising_extra) def clear(): @@ -762,6 +771,7 @@ def clear(): scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g} # b, c, lat_f, lat_h, lat_w latents = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) + if "G" in video_prompt_type: randn = latents if apg_switch != 0: apg_momentum = -0.75 apg_norm_threshold = 55 @@ -784,23 +794,21 @@ def clear(): timestep = torch.full((target_shape[-3],), t, dtype=torch.int64, device=latents.device) timestep[:source_latents.shape[2]] = 0 - kwargs.update({"t": timestep, "current_step": start_step_no + i}) + kwargs.update({"t": timestep, "current_step_no": i, "real_step_no": start_step_no + i }) kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None if denoising_strength < 1 and i <= injection_denoising_step: sigma = t / 1000 - noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) if inject_from_start: - new_latents = latents.clone() - new_latents[:,:, :source_latents.shape[2] ] = noise[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents + noisy_image = latents.clone() + noisy_image[:,:, :source_latents.shape[2] ] = randn[:, :, 
:source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents for latent_no, keep_latent in enumerate(latent_keep_frames): if not keep_latent: - new_latents[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1] - latents = new_latents - new_latents = None + noisy_image[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1] + latents = noisy_image + noisy_image = None else: - latents = noise * sigma + (1 - sigma) * source_latents - noise = None + latents = randn * sigma + (1 - sigma) * source_latents if extended_overlapped_latents != None: if no_noise_latents_injection: @@ -940,6 +948,13 @@ def clear(): latents, **scheduler_kwargs)[0] + + if image_mask_latents is not None: + sigma = 0 if i == len(timesteps)-1 else timesteps[i+1]/1000 + noisy_image = randn * sigma + (1 - sigma) * source_latents + latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents + + if callback is not None: latents_preview = latents if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] @@ -999,3 +1014,11 @@ def adapt_animate_model(self, model): setattr(target, "face_adapter_fuser_blocks", module ) delattr(model, "face_adapter") + def get_loras_transformer(self, get_model_recursive_prop, base_model_type, model_type, video_prompt_type, model_mode, **kwargs): + if base_model_type == "animate": + if "1" in video_prompt_type: + preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") + if len(preloadURLs) > 0: + return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1] + return [], [] + diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index c98087b5d..cd024704c 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -1220,7 +1220,8 @@ def forward( y=None, freqs = None, pipeline = None, - current_step = 0, + current_step_no = 0, + real_step_no = 0, x_id= 0, max_steps = 0, slg_layers=None, @@ -1310,9 +1311,9 @@ def forward( del causal_mask offload.shared_state["embed_sizes"] = grid_sizes - offload.shared_state["step_no"] = current_step + offload.shared_state["step_no"] = real_step_no offload.shared_state["max_steps"] = max_steps - if current_step == 0 and x_id == 0: clear_caches() + if current_step_no == 0 and x_id == 0: clear_caches() # arguments kwargs = dict( @@ -1336,7 +1337,7 @@ def forward( if standin_ref is not None: standin_cache_enabled = False kwargs["standin_phase"] = 2 - if current_step == 0 or not standin_cache_enabled : + if current_step_no == 0 or not standin_cache_enabled : standin_x = self.patch_embedding(standin_ref).to(modulation_dtype).flatten(2).transpose(1, 2) standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)).to(modulation_dtype) ) standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype) @@ -1401,7 +1402,7 @@ def forward( skips_steps_cache = self.cache if skips_steps_cache != None: if skips_steps_cache.cache_type == "mag": - if current_step <= skips_steps_cache.start_step: + if real_step_no <= skips_steps_cache.start_step: should_calc = True elif skips_steps_cache.one_for_all and x_id != 0: # not joint pass, not main pas, one for all assert len(x_list) == 1 @@ -1410,7 +1411,7 @@ def forward( x_should_calc = [] for i in range(1 if skips_steps_cache.one_for_all else len(x_list)): cur_x_id = i if joint_pass else x_id - cur_mag_ratio = skips_steps_cache.mag_ratios[current_step * 2 + cur_x_id] # conditional and unconditional in one list + cur_mag_ratio = 
skips_steps_cache.mag_ratios[real_step_no * 2 + cur_x_id] # conditional and unconditional in one list skips_steps_cache.accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step skips_steps_cache.accumulated_steps[cur_x_id] += 1 # skip steps plus 1 cur_skip_err = np.abs(1-skips_steps_cache.accumulated_ratio[cur_x_id]) # skip error of current steps @@ -1430,7 +1431,7 @@ def forward( if x_id != 0: should_calc = skips_steps_cache.should_calc else: - if current_step <= skips_steps_cache.start_step or current_step == skips_steps_cache.num_steps-1: + if real_step_no <= skips_steps_cache.start_step or real_step_no == skips_steps_cache.num_steps-1: should_calc = True skips_steps_cache.accumulated_rel_l1_distance = 0 else: diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 574c990a0..c5875445b 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -124,12 +124,20 @@ def query_model_def(base_model_type, model_def): if base_model_type in ["t2v"]: extra_model_def["guide_custom_choices"] = { - "choices":[("Use Text Prompt Only", ""),("Video to Video guided by Text Prompt", "GUV")], + "choices":[("Use Text Prompt Only", ""), + ("Video to Video guided by Text Prompt", "GUV"), + ("Video to Video guided by Text Prompt and Restricted to the Area of the Video Mask", "GVA")], "default": "", - "letters_filter": "GUV", + "letters_filter": "GUVA", "label": "Video to Video" } + extra_model_def["mask_preprocessing"] = { + "selection":[ "", "A"], + "visible": False + } + + if base_model_type in ["infinitetalk"]: extra_model_def["no_background_removal"] = True extra_model_def["all_image_refs_are_background_ref"] = True @@ -160,14 +168,22 @@ def query_model_def(base_model_type, model_def): if base_model_type in ["animate"]: extra_model_def["guide_custom_choices"] = { "choices":[ - ("Animate Person in Reference Image using Motion of Person in Control Video", "PVBXAKI"), - ("Replace Person in Control Video Person in Reference Image", "PVBAI"), + ("Animate Person in Reference Image using Motion of Whole Control Video", "PVBKI"), + ("Animate Person in Reference Image using Motion of Targeted Person in Control Video", "PVBXAKI"), + ("Replace Person in Control Video by Person in Reference Image", "PVBAI"), + ("Replace Person in Control Video by Person in Reference Image and Apply Relighting Process", "PVBAI1"), ], - "default": "KI", - "letters_filter": "PVBXAKI", + "default": "PVBKI", + "letters_filter": "PVBXAKI1", "label": "Type of Process", "show_label" : False, } + + extra_model_def["mask_preprocessing"] = { + "selection":[ "", "A", "XA"], + "visible": False + } + extra_model_def["video_guide_outpainting"] = [0,1] extra_model_def["keep_frames_video_guide_not_supported"] = True extra_model_def["extract_guide_from_window_start"] = True @@ -480,8 +496,8 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "video_prompt_type": "UV", }) elif base_model_type in ["animate"]: - ui_defaults.update({ - "video_prompt_type": "PVBXAKI", + ui_defaults.update({ + "video_prompt_type": "PVBKI", "mask_expand": 20, "audio_prompt_type": "R", }) diff --git a/shared/utils/loras_mutipliers.py b/shared/utils/loras_mutipliers.py index 58cc9a993..e0cec6a45 100644 --- a/shared/utils/loras_mutipliers.py +++ b/shared/utils/loras_mutipliers.py @@ -99,7 +99,7 @@ def is_float(element: any) -> bool: return loras_list_mult_choices_nums, slists_dict, "" -def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, 
phase_switch_step2 = None ): +def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None): from mmgp import offload sz = len(slists_dict["phase1"]) slists = [ expand_slist(slists_dict, i, num_inference_steps, phase_switch_step, phase_switch_step2 ) for i in range(sz) ] @@ -108,7 +108,8 @@ def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_st -def get_model_switch_steps(timesteps, total_num_steps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ): +def get_model_switch_steps(timesteps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ): + total_num_steps = len(timesteps) model_switch_step = model_switch_step2 = None for i, t in enumerate(timesteps): if guide_phases >=2 and model_switch_step is None and t <= switch_threshold: model_switch_step = i diff --git a/wgp.py b/wgp.py index d1eae3b58..def856ff6 100644 --- a/wgp.py +++ b/wgp.py @@ -63,7 +63,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.73" +WanGP_version = "8.74" settings_version = 2.36 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -3699,13 +3699,15 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr any_mask = input_mask_path != None video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps) if len(video) == 0: return None + frame_height, frame_width, _ = video[0].shape + num_frames = len(video) if any_mask: mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps) - frame_height, frame_width, _ = video[0].shape - - num_frames = min(len(video), len(mask_video)) + num_frames = min(num_frames, len(mask_video)) if num_frames == 0: return None - video, mask_video = video[:num_frames], mask_video[:num_frames] + video = video[:num_frames] + if any_mask: + mask_video = mask_video[:num_frames] from preprocessing.face_preprocessor import FaceProcessor face_processor = FaceProcessor() @@ -6970,39 +6972,49 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t all_guide_processes ="PDESLCMUVB" -def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): +def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): + model_type = state["model_type"] + model_def = get_model_def(model_type) old_video_prompt_type = video_prompt_type - video_prompt_type = del_in_sequence(video_prompt_type, all_guide_processes) + if filter_type == "alt": + guide_custom_choices = model_def.get("guide_custom_choices",{}) + letter_filter = guide_custom_choices.get("letters_filter","") + else: + letter_filter = all_guide_processes + video_prompt_type = del_in_sequence(video_prompt_type, letter_filter) video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) visible = "V" in video_prompt_type - model_type = state["model_type"] - model_def = get_model_def(model_type) any_outpainting= image_mode in model_def.get("video_guide_outpainting", []) mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type image_outputs = image_mode > 0 
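On the `get_model_switch_steps` change shown above (shared/utils/loras_mutipliers.py): the model/phase switch step is now derived from the timestep schedule itself rather than passed in as a separate step count. A hedged sketch of that lookup, with the second threshold handled symmetrically (assumed, since that part of the hunk is not visible here):

```python
def find_switch_steps(timesteps, guide_phases, switch_threshold, switch2_threshold):
    # Walk the decreasing timestep schedule and record the first step index
    # at which each switch threshold is reached.
    model_switch_step = model_switch_step2 = None
    for i, t in enumerate(timesteps):
        if guide_phases >= 2 and model_switch_step is None and t <= switch_threshold:
            model_switch_step = i
        if guide_phases >= 3 and model_switch_step2 is None and t <= switch2_threshold:
            model_switch_step2 = i
    return model_switch_step, model_switch_step2
```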
keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False) image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ) + # mask_video_input_visible = image_mode == 0 and mask_visible mask_preprocessing = model_def.get("mask_preprocessing", None) if mask_preprocessing is not None: mask_selector_visible = mask_preprocessing.get("visible", True) else: mask_selector_visible = True - return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible) - - -def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode): - model_def = get_model_def(state["model_type"]) - guide_custom_choices = model_def.get("guide_custom_choices",{}) - video_prompt_type = del_in_sequence(video_prompt_type, guide_custom_choices.get("letters_filter","")) - video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt) - control_video_visible = "V" in video_prompt_type ref_images_visible = "I" in video_prompt_type - denoising_strength_visible = "G" in video_prompt_type - return video_prompt_type, gr.update(visible = control_video_visible and image_mode ==0), gr.update(visible = control_video_visible and image_mode >=1), gr.update(visible = ref_images_visible ), gr.update(visible = denoising_strength_visible ) - -# def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide): -# video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0] -# return refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide) + return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible), gr.update(visible = ref_images_visible ) + + +# def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): +# old_video_prompt_type = video_prompt_type +# model_def = get_model_def(state["model_type"]) +# guide_custom_choices = model_def.get("guide_custom_choices",{}) +# video_prompt_type = del_in_sequence(video_prompt_type, guide_custom_choices.get("letters_filter","")) +# video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt) +# image_outputs = image_mode > 0 +# control_video_visible = "V" in video_prompt_type +# ref_images_visible = "I" in video_prompt_type +# denoising_strength_visible = "G" in 
video_prompt_type +# mask_expand_visible = control_video_visible and "A" in video_prompt_type and not "U" in video_prompt_type +# mask_video_input_visible = image_mode == 0 and mask_expand_visible +# image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ) +# keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False) + +# return video_prompt_type, gr.update(visible = control_video_visible and image_mode ==0), gr.update(visible = control_video_visible and image_mode >=1), gr.update(visible = ref_images_visible ), gr.update(visible = denoising_strength_visible ), gr.update(visible = mask_video_input_visible ), gr.update(visible = mask_expand_visible), image_mask_guide, image_guide, image_mask, gr.update(visible = keep_frames_video_guide_visible) def refresh_preview(state): gen = get_gen_info(state) @@ -8294,9 +8306,10 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai image_prompt_type_radio.change(fn=refresh_image_prompt_type_radio, inputs=[state, image_prompt_type, image_prompt_type_radio], outputs=[image_prompt_type, image_start_row, image_end_row, video_source, keep_frames_video_source, image_prompt_type_endcheckbox], show_progress="hidden" ) image_prompt_type_endcheckbox.change(fn=refresh_image_prompt_type_endcheckbox, inputs=[state, image_prompt_type, image_prompt_type_radio, image_prompt_type_endcheckbox], outputs=[image_prompt_type, image_end_row] ) video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs,image_mode], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col], show_progress="hidden") - video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand], show_progress="hidden") - video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode], outputs = [video_prompt_type, video_guide, image_guide, image_refs_row, denoising_strength ], show_progress="hidden") - video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_mask, image_mask_guide, image_guide, image_mask, mask_expand], show_progress="hidden") + video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State(""), video_prompt_type, video_prompt_type_video_guide, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row], show_progress="hidden") + 
video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State("alt"),video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row], show_progress="hidden") + # video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, image_refs_row, denoising_strength, video_mask, mask_expand, image_mask_guide, image_guide, image_mask, keep_frames_video_guide ], show_progress="hidden") + video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_mask, image_mask_guide, image_guide, image_mask, mask_expand], show_progress="hidden") video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type]) multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt, image_end], show_progress="hidden") video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) From bb7e84719dc148b7faef04c33836957f275e87a8 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 26 Sep 2025 00:17:11 +0200 Subject: [PATCH 042/155] better packaging of matanyone --- preprocessing/matanyone/app.py | 3 ++- wgp.py | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py index df71b40a9..a5d557099 100644 --- a/preprocessing/matanyone/app.py +++ b/preprocessing/matanyone/app.py @@ -697,7 +697,8 @@ def load_unload_models(selected): model.samcontroler.sam_controler.model.to("cpu").to(torch.bfloat16).to(arg_device) model_in_GPU = True from .matanyone.model.matanyone import MatAnyone - matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone") + # matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone") + matanyone_model = MatAnyone.from_pretrained("ckpts/mask") # pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model } # offload.profile(pipe) matanyone_model = matanyone_model.to("cpu").eval() diff --git a/wgp.py b/wgp.py index def856ff6..e00660c89 100644 --- a/wgp.py +++ b/wgp.py @@ -2519,6 +2519,7 @@ def download_mmaudio(): } process_files_def(**enhancer_def) +download_shared_done = False def download_models(model_filename = None, model_type= None, module_type = False, submodel_no = 1): def computeList(filename): if filename == None: @@ -2536,7 +2537,7 @@ def computeList(filename): "repoId" : "DeepBeepMeep/Wan2.1", "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "roformer", "pyannote", "det_align", "" ], "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"], - 
["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], + ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors", "model.safetensors", "config.json"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], ["config.json", "pytorch_model.bin", "preprocessor_config.json"], ["model_bs_roformer_ep_317_sdr_12.9755.ckpt", "model_bs_roformer_ep_317_sdr_12.9755.yaml", "download_checks.json"], @@ -2562,6 +2563,9 @@ def computeList(filename): process_files_def(**enhancer_def) download_mmaudio() + global download_shared_done + download_shared_done = True + if model_filename is None: return def download_file(url,filename): @@ -9138,6 +9142,8 @@ def set_new_tab(tab_state, new_tab_no): tab_state["tab_no"] = 0 return gr.Tabs(selected="video_gen") else: + if not download_shared_done: + download_models() vmc_event_handler(True) tab_state["tab_no"] = new_tab_no return gr.Tabs() From 6dfd17315246f2408e3232260bcf6c5a66cec8b4 Mon Sep 17 00:00:00 2001 From: deepbeepmeep <84379123+deepbeepmeep@users.noreply.github.com> Date: Sat, 27 Sep 2025 12:57:04 +0200 Subject: [PATCH 043/155] Update README.md added docker info --- README.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/README.md b/README.md index d9be0fdf7..3ee2d88dc 100644 --- a/README.md +++ b/README.md @@ -148,6 +148,32 @@ git pull pip install -r requirements.txt ``` +## 🐳 Docker: + +**For Debian-based systems (Ubuntu, Debian, etc.):** + +```bash +./run-docker-cuda-deb.sh +``` + +This automated script will: + +- Detect your GPU model and VRAM automatically +- Select optimal CUDA architecture for your GPU +- Install NVIDIA Docker runtime if needed +- Build a Docker image with all dependencies +- Run WanGP with optimal settings for your hardware + +**Docker environment includes:** + +- NVIDIA CUDA 12.4.1 with cuDNN support +- PyTorch 2.6.0 with CUDA 12.4 support +- SageAttention compiled for your specific GPU architecture +- Optimized environment variables for performance (TF32, threading, etc.) +- Automatic cache directory mounting for faster subsequent runs +- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally + +**Supported GPUs:** RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more. 
## 📦 Installation From 79df3aae64c2c3b28286d921a6e7b13ee02a375e Mon Sep 17 00:00:00 2001 From: deepbeepmeep <84379123+deepbeepmeep@users.noreply.github.com> Date: Sat, 27 Sep 2025 13:00:08 +0200 Subject: [PATCH 044/155] Update README.md --- README.md | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/README.md b/README.md index 35fc0dc59..8261540ce 100644 --- a/README.md +++ b/README.md @@ -122,33 +122,6 @@ See full changelog: **[Changelog](docs/CHANGELOG.md)** ## 🚀 Quick Start -### 🐳 Docker: - -**For Debian-based systems (Ubuntu, Debian, etc.):** - -```bash -./run-docker-cuda-deb.sh -``` - -This automated script will: - -- Detect your GPU model and VRAM automatically -- Select optimal CUDA architecture for your GPU -- Install NVIDIA Docker runtime if needed -- Build a Docker image with all dependencies -- Run WanGP with optimal settings for your hardware - -**Docker environment includes:** - -- NVIDIA CUDA 12.4.1 with cuDNN support -- PyTorch 2.6.0 with CUDA 12.4 support -- SageAttention compiled for your specific GPU architecture -- Optimized environment variables for performance (TF32, threading, etc.) -- Automatic cache directory mounting for faster subsequent runs -- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally - -**Supported GPUs:** RTX 50XX, RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more. - **One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/) **Manual installation:** From 5ce8fc3d535c9285d23c7ca06b3016527ba59743 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sat, 27 Sep 2025 15:14:55 +0200 Subject: [PATCH 045/155] various fixes --- README.md | 7 ++++--- models/qwen/pipeline_qwenimage.py | 4 ++-- models/wan/any2video.py | 11 ++++++++++- models/wan/modules/model.py | 16 +++++++++++----- models/wan/wan_handler.py | 9 +++++++++ shared/attention.py | 13 +++++++++++++ shared/sage2_core.py | 6 ++++++ wgp.py | 16 +++++++++++----- 8 files changed, 66 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 8261540ce..939cb7f04 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,9 @@ See full changelog: **[Changelog](docs/CHANGELOG.md)** ## 🚀 Quick Start -**One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/) +**One-click installation:** +- Get started instantly with [Pinokio App](https://pinokio.computer/) +- Use Redtash1 [One Click Install with Sage](https://github.com/Redtash1/Wan2GP-Windows-One-Click-Install-With-Sage) **Manual installation:** ```bash @@ -136,8 +138,7 @@ pip install -r requirements.txt **Run the application:** ```bash -python wgp.py # Text-to-video (default) -python wgp.py --i2v # Image-to-video +python wgp.py ``` **Update the application:** diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py index 02fd473b4..c458852fe 100644 --- a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -232,7 +232,7 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] - if self.processor is not None and image is not None: + if self.processor is not None and image is not None and len(image) > 0: img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" if isinstance(image, list): base_img_prompt = "" @@ -972,7 +972,7 @@ def cfg_predictions( noise_pred, neg_noise_pred, guidance, t): if callback is not None: 
preview = self._unpack_latents(latents, height, width, self.vae_scale_factor) - preview = preview.squeeze(0) + preview = preview.transpose(0,2).squeeze(0) callback(i, preview, False) diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 285340d39..7dc4b19c8 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -595,7 +595,7 @@ def generate(self, color_reference_frame = input_frames[:, -1:].clone() if prefix_frames_count > 0: overlapped_frames_num = prefix_frames_count - overlapped_latents_frames_num = (overlapped_latents_frames_num -1 // 4) + 1 + overlapped_latents_frames_num = (overlapped_frames_num -1 // 4) + 1 # overlapped_latents_frames_num = overlapped_latents.shape[2] # overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1 else: @@ -735,11 +735,20 @@ def generate(self, if callback != None: callback(-1, None, True) + + clear_caches() offload.shared_state["_chipmunk"] = False chipmunk = offload.shared_state.get("_chipmunk", False) if chipmunk: self.model.setup_chipmunk() + offload.shared_state["_radial"] = offload.shared_state["_attention"]=="radial" + radial = offload.shared_state.get("_radial", False) + if radial: + radial_cache = get_cache("radial") + from shared.radial_attention.attention import fill_radial_cache + fill_radial_cache(radial_cache, len(self.model.blocks), *target_shape[1:]) + # init denoising updated_num_steps= len(timesteps) diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index cd024704c..82a313cec 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -315,6 +315,8 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks x_ref_attn_map = None chipmunk = offload.shared_state.get("_chipmunk", False) + radial = offload.shared_state.get("_radial", False) + if chipmunk and self.__class__ == WanSelfAttention: q = q.transpose(1,2) k = k.transpose(1,2) @@ -322,12 +324,17 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks attn_layers = offload.shared_state["_chipmunk_layers"] x = attn_layers[self.block_no](q, k, v) x = x.transpose(1,2) + elif radial and self.__class__ == WanSelfAttention: + qkv_list = [q,k,v] + del q,k,v + radial_cache = get_cache("radial") + no_step_no = offload.shared_state["step_no"] + x = radial_cache[self.block_no](qkv_list=qkv_list, timestep_no=no_step_no) elif block_mask == None: qkv_list = [q,k,v] del q,k,v - x = pay_attention( - qkv_list, - window_size=self.window_size) + + x = pay_attention( qkv_list, window_size=self.window_size) else: with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): @@ -1311,9 +1318,8 @@ def forward( del causal_mask offload.shared_state["embed_sizes"] = grid_sizes - offload.shared_state["step_no"] = real_step_no + offload.shared_state["step_no"] = current_step_no offload.shared_state["max_steps"] = max_steps - if current_step_no == 0 and x_id == 0: clear_caches() # arguments kwargs = dict( diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index c5875445b..1a30fffbc 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -184,6 +184,15 @@ def query_model_def(base_model_type, model_def): "visible": False } + # extra_model_def["image_ref_choices"] = { + # "choices": [("None", ""), + # ("People / Objects", "I"), + # ("Landscape followed by People / Objects (if any)", "KI"), + # ], + # "visible": False, + # "letters_filter": "KFI", + # } + extra_model_def["video_guide_outpainting"] = [0,1] 
extra_model_def["keep_frames_video_guide_not_supported"] = True extra_model_def["extract_guide_from_window_start"] = True diff --git a/shared/attention.py b/shared/attention.py index cc6ece02d..efaa22370 100644 --- a/shared/attention.py +++ b/shared/attention.py @@ -42,6 +42,11 @@ def sageattn_varlen_wrapper( except ImportError: sageattn_varlen_wrapper = None +try: + from spas_sage_attn import block_sparse_sage2_attn_cuda +except ImportError: + block_sparse_sage2_attn_cuda = None + try: from .sage2_core import sageattn as sageattn2, is_sage2_supported @@ -62,6 +67,8 @@ def sageattn2_wrapper( return o +from sageattn import sageattn_blackwell as sageattn3 + try: from sageattn import sageattn_blackwell as sageattn3 except ImportError: @@ -144,6 +151,9 @@ def get_attention_modes(): ret.append("sage") if sageattn2 != None and version("sageattention").startswith("2") : ret.append("sage2") + if block_sparse_sage2_attn_cuda != None and version("sageattention").startswith("2") : + ret.append("radial") + if sageattn3 != None: # and version("sageattention").startswith("3") : ret.append("sage3") @@ -159,6 +169,8 @@ def get_supported_attention_modes(): if not sage2_supported: if "sage2" in ret: ret.remove("sage2") + if "radial" in ret: + ret.remove("radial") if major < 7: if "sage" in ret: @@ -225,6 +237,7 @@ def pay_attention( if attn == "chipmunk": from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG + if attn == "radial": attn ="sage2" if b > 1 and k_lens != None and attn in ("sage2", "sage3", "sdpa"): assert attention_mask == None diff --git a/shared/sage2_core.py b/shared/sage2_core.py index 04b77f8bb..33cd98cbf 100644 --- a/shared/sage2_core.py +++ b/shared/sage2_core.py @@ -28,18 +28,24 @@ try: from sageattention import _qattn_sm80 + if not hasattr(_qattn_sm80, "qk_int8_sv_f16_accum_f32_attn"): + _qattn_sm80 = torch.ops.sageattention_qattn_sm80 SM80_ENABLED = True except: SM80_ENABLED = False try: from sageattention import _qattn_sm89 + if not hasattr(_qattn_sm89, "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"): + _qattn_sm89 = torch.ops.sageattention_qattn_sm89 SM89_ENABLED = True except: SM89_ENABLED = False try: from sageattention import _qattn_sm90 + if not hasattr(_qattn_sm90, "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"): + _qattn_sm90 = torch.ops.sageattention_qattn_sm90 SM90_ENABLED = True except: SM90_ENABLED = False diff --git a/wgp.py b/wgp.py index e00660c89..6a257e814 100644 --- a/wgp.py +++ b/wgp.py @@ -63,7 +63,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.74" +WanGP_version = "8.75" settings_version = 2.36 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -1394,6 +1394,12 @@ def _parse_args(): help="View form generation / refresh time" ) + parser.add_argument( + "--betatest", + action="store_true", + help="test unreleased features" + ) + parser.add_argument( "--vram-safety-coefficient", type=float, @@ -4367,7 +4373,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri if image_start is None or not "I" in prompt_enhancer: image_start = [None] * num_prompts else: - image_start = [img[0] for img in image_start] + image_start = [convert_image(img[0]) for img in image_start] if len(image_start) == 1: image_start = image_start * num_prompts else: @@ -8712,9 +8718,9 @@ def check(mode): ("Flash" + 
check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"), ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"), ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"), - ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"), - ("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3"), - ], + ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2")]\ + + ([("Radial" + check("radial")+ ": x? faster but ? quality - requires Sparge & Sage 2 Attn (usually complex to set up on Windows without WSL)", "radial")] if args.betatest else [])\ + + [("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3")], value= attention_mode, label="Attention Type", interactive= not lock_ui_attention From d62748a226e4eab35cdf500231e1234893a5ea09 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sat, 27 Sep 2025 15:22:31 +0200 Subject: [PATCH 046/155] missed files --- shared/radial_attention/attention.py | 49 ++ shared/radial_attention/attn_mask.py | 379 ++++++++++++++++ shared/radial_attention/inference.py | 41 ++ shared/radial_attention/sparse_transformer.py | 424 ++++++++++++++++++ shared/radial_attention/utils.py | 16 + 5 files changed, 909 insertions(+) create mode 100644 shared/radial_attention/attention.py create mode 100644 shared/radial_attention/attn_mask.py create mode 100644 shared/radial_attention/inference.py create mode 100644 shared/radial_attention/sparse_transformer.py create mode 100644 shared/radial_attention/utils.py diff --git a/shared/radial_attention/attention.py b/shared/radial_attention/attention.py new file mode 100644 index 000000000..ab4a5a293 --- /dev/null +++ b/shared/radial_attention/attention.py @@ -0,0 +1,49 @@ +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange +from .attn_mask import RadialAttention, MaskMap + +def fill_radial_cache(radial_cache, nb_layers, lat_t, lat_h, lat_w): + MaskMap._log_mask = None + + for i in range(nb_layers): + radial_cache[i] = WanSparseAttnProcessor2_0(i, lat_t, lat_h, lat_w) + +class WanSparseAttnProcessor2_0: + mask_map = None + dense_timestep = 0 + dense_block = 0 + decay_factor = 1.0 + sparse_type = "radial" # default to radial attention, can be changed to "dense" for dense attention + use_sage_attention = True + + def __init__(self, layer_idx, lat_t, lat_h, lat_w): + self.layer_idx = layer_idx + self.mask_map = MaskMap(video_token_num=lat_t * lat_h * lat_w // 4 , num_frame=lat_t) + def __call__( + self, + qkv_list, + timestep_no = 0, + ) -> torch.Tensor: + query, key, value = qkv_list + + batch_size = query.shape[0] + # transform (batch_size, seq_len, num_heads, head_dim) to (seq_len * batch_size, num_heads, head_dim) + query = rearrange(query, "b s h d -> (b s) h d") + key = rearrange(key, "b s h d -> (b s) h d") + value = rearrange(value, "b s h d -> (b s) h d") + if timestep_no < self.dense_timestep or 
self.layer_idx < self.dense_block or self.sparse_type == "dense": + hidden_states = RadialAttention( + query=query, key=key, value=value, mask_map=self.mask_map, sparsity_type="dense", block_size=128, decay_factor=self.decay_factor, model_type="wan", pre_defined_mask=None, use_sage_attention=self.use_sage_attention + ) + else: + # apply radial attention + hidden_states = RadialAttention( + query=query, key=key, value=value, mask_map=self.mask_map, sparsity_type="radial", block_size=128, decay_factor=self.decay_factor, model_type="wan", pre_defined_mask=None, use_sage_attention=self.use_sage_attention + ) + # transform back to (batch_size, num_heads, seq_len, head_dim) + hidden_states = rearrange(hidden_states, "(b s) h d -> b s h d", b=batch_size) + + return hidden_states diff --git a/shared/radial_attention/attn_mask.py b/shared/radial_attention/attn_mask.py new file mode 100644 index 000000000..4c29f1c32 --- /dev/null +++ b/shared/radial_attention/attn_mask.py @@ -0,0 +1,379 @@ +import torch +# import flashinfer +import matplotlib.pyplot as plt +# from sparse_sageattn import sparse_sageattn +from einops import rearrange, repeat +from sageattention import sageattn +from spas_sage_attn import block_sparse_sage2_attn_cuda + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + +from spas_sage_attn import block_sparse_sage2_attn_cuda + +def sparge_mask_convert(mask: torch.Tensor, block_size: int = 128, arch="sm") -> torch.Tensor: + assert block_size in [128, 64], "Radial Attention only supports block size of 128 or 64" + assert mask.shape[0] == mask.shape[1], "Input mask must be square." + + if block_size == 128: + if arch == "sm90": + new_mask = torch.repeat_interleave(mask, 2, dim=0) + else: + new_mask = torch.repeat_interleave(mask, 2, dim=1) + + elif block_size == 64: + if arch == "sm90": + num_row, num_col = mask.shape + reshaped_mask = mask.view(num_row, num_col // 2, 2) + new_mask = torch.max(reshaped_mask, dim=2).values + else: + num_row, num_col = mask.shape + reshaped_mask = mask.view(num_row // 2, 2, num_col) + new_mask = torch.max(reshaped_mask, dim=1).values + + return new_mask + +def get_indptr_from_mask(mask, query): + # query shows the device of the indptr + # indptr (torch.Tensor) - the block index pointer of the block-sparse matrix on row dimension, + # shape `(MB + 1,)`, where `MB` is the number of blocks in the row dimension. + # The first element is always 0, and the last element is the number of blocks in the row dimension. + # The rest of the elements are the number of blocks in each row. + # the mask is already a block sparse mask + indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32) + indptr[0] = 0 + row_counts = mask.sum(dim=1).flatten() # Ensure 1D output [num_blocks_row] + indptr[1:] = torch.cumsum(row_counts, dim=0) + return indptr + +def get_indices_from_mask(mask, query): + # indices (torch.Tensor) - the block indices of the block-sparse matrix on column dimension, + # shape `(nnz,),` where `nnz` is the number of non-zero blocks. + # The elements in `indices` array should be less than `NB`: the number of blocks in the column dimension. 
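+    # Illustrative example (2x2 block-level mask) of the CSR-style layout produced
+    # by get_indptr_from_mask above together with this function:
+    #   mask = [[True, False],
+    #           [True, True ]]
+    #   indptr  -> [0, 1, 3]   (cumulative number of non-zero blocks per row)
+    #   indices -> [0, 0, 1]   (column index of each non-zero block, row-major)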
+ nonzero_indices = torch.nonzero(mask) + indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device) + return indices + +def shrinkMaskStrict(mask, block_size=128): + seqlen = mask.shape[0] + block_num = seqlen // block_size + mask = mask[:block_num * block_size, :block_num * block_size].view(block_num, block_size, block_num, block_size) + col_densities = mask.sum(dim = 1) / block_size + # we want the minimum non-zero column density in the block + non_zero_densities = col_densities > 0 + high_density_cols = col_densities > 1/3 + frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9) + block_mask = frac_high_density_cols > 0.6 + block_mask[0:0] = True + block_mask[-1:-1] = True + return block_mask + +def pad_qkv(input_tensor, block_size=128): + """ + Pad the input tensor to be a multiple of the block size. + input shape: (seqlen, num_heads, hidden_dim) + """ + seqlen, num_heads, hidden_dim = input_tensor.shape + # Calculate the necessary padding + padding_length = (block_size - (seqlen % block_size)) % block_size + # Create a padded tensor with zeros + padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype) + # Copy the original tensor into the padded tensor + padded_tensor[:seqlen, :, :] = input_tensor + + return padded_tensor + +def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query): + assert(sparse_type in ["radial"]) + dist = abs(i - j) + group = dist.bit_length() + threshold = 128 # hardcoded threshold for now, which is equal to block-size + decay_length = 2 ** token_per_frame.bit_length() / 2 ** group + if decay_length >= threshold: + return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + + split_factor = int(threshold / decay_length) + modular = dist % split_factor + if modular == 0: + return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + else: + return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + +def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None): + assert(sparse_type in ["radial"]) + dist = abs(i - j) + if model_type == "wan": + if dist < 1: + return token_per_frame + if dist == 1: + return token_per_frame // 2 + elif model_type == "hunyuan": + if dist <= 1: + return token_per_frame + else: + raise ValueError(f"Unknown model type: {model_type}") + group = dist.bit_length() + decay_length = 2 ** token_per_frame.bit_length() / 2 ** group * decay_factor + threshold = block_size + if decay_length >= threshold: + return decay_length + else: + return threshold + +def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None): + """ + A more memory friendly version, we generate the attention mask of each frame pair at a time, + shrinks it, and stores it into the final result + """ + final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool) + token_per_frame = video_token_num // num_frame + video_text_border = video_token_num // block_size + + col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1) + row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1) + final_log_mask[video_text_border:] = True + final_log_mask[:, video_text_border:] = True + for i in range(num_frame): + for j in 
range(num_frame): + local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + if j == 0 and model_type == "wan": # this is attention sink + local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + else: + window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type) + local_mask = torch.abs(col_indices - row_indices) <= window_width + split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query) + local_mask = torch.logical_and(local_mask, split_mask) + + remainder_row = (i * token_per_frame) % block_size + remainder_col = (j * token_per_frame) % block_size + # get the padded size + all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size + all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size + padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool) + padded_local_mask[remainder_row:remainder_row + token_per_frame, remainder_col:remainder_col + token_per_frame] = local_mask + # shrink the mask + block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size) + # set the block mask to the final log mask + block_row_start = (i * token_per_frame) // block_size + block_col_start = (j * token_per_frame) // block_size + block_row_end = block_row_start + block_mask.shape[0] + block_col_end = block_col_start + block_mask.shape[1] + final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or( + final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask) + print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}") + return final_log_mask + +class MaskMap: + _log_mask = None + + def __init__(self, video_token_num=25440, num_frame=16): + self.video_token_num = video_token_num + self.num_frame = num_frame + + def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None): + if MaskMap._log_mask is None: + MaskMap._log_mask = torch.ones((query.shape[0] // block_size, query.shape[0] // block_size), device=query.device, dtype=torch.bool) + MaskMap._log_mask = gen_log_mask_shrinked(query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size) + return MaskMap._log_mask + +def SpargeSageAttnBackend(query, key, value, mask_map=None, video_mask=None, pre_defined_mask=None, block_size=128): + if video_mask.all(): + # dense case + kv_border = pre_defined_mask[0].sum() if pre_defined_mask is not None else key.shape[0] + output_video = sageattn( + query[:mask_map.video_token_num, :, :].unsqueeze(0), + key[:kv_border, :, :].unsqueeze(0), + value[:kv_border, :, :].unsqueeze(0), + tensor_layout="NHD", + )[0] + + if pre_defined_mask is not None: + output_text = flashinfer.single_prefill_with_kv_cache( + q=query[mask_map.video_token_num:, :, :], + k=key[:pre_defined_mask[0].sum(), :, :], + v=value[:pre_defined_mask[0].sum(), :, :], + causal=False, + return_lse=False, + ) + return torch.cat([output_video, output_text], dim=0) + else: + return output_video + + # sparse-sageattention only supports (b, h, s, d) layout, need rearrange first + query_hnd = rearrange(query.unsqueeze(0), "b s h d -> b h s d") + key_hnd = rearrange(key.unsqueeze(0), "b s h d -> b h s d") + value_hnd = 
rearrange(value.unsqueeze(0), "b s h d -> b h s d") + arch = get_cuda_arch_versions()[query.device.index] + converted_mask = repeat(sparge_mask_convert(mask=video_mask, block_size=block_size, arch=arch), "s t -> b h s t", b=query_hnd.shape[0], h=query_hnd.shape[1]) + + converted_mask = converted_mask.to(torch.int8) + if pre_defined_mask is None: + # wan case + output = block_sparse_sage2_attn_cuda( + query_hnd[:, :, :mask_map.video_token_num, :], + key_hnd[:, :, :mask_map.video_token_num, :], + value_hnd[:, :, :mask_map.video_token_num, :], + mask_id=converted_mask, + tensor_layout="HND", + ) + + # rearrange back to (s, h, d), we know that b = 1 + output = rearrange(output, "b h s d -> s (b h) d", b=1) + return output + + query_video = query_hnd[:, :, :mask_map.video_token_num, :] + key_video = key_hnd + value_video = value_hnd + kv_border = (pre_defined_mask[0].sum() + 63) // 64 + converted_mask[:, :, :, kv_border:] = False + output_video = block_sparse_sage2_attn_cuda( + query_video, + key_video, + value_video, + mask_id=converted_mask[:, :, :mask_map.video_token_num // block_size, :].contiguous(), + tensor_layout="HND", + ) + + # rearrange back to (s, h, d), we know that b = 1 + output_video = rearrange(output_video, "b h s d -> s (b h) d", b=1) + + # gt = sparse_sageattn( + # query_video, + # key_video, + # value_video, + # mask_id=None, + # is_causal=False, + # tensor_layout="HND", + # )[0] + + + + # import pdb; pdb.set_trace() + + output_text = flashinfer.single_prefill_with_kv_cache( + q=query[mask_map.video_token_num:, :, :], + k=key[:pre_defined_mask[0].sum(), :, :], + v=value[:pre_defined_mask[0].sum(), :, :], + causal=False, + return_lse=False, + ) + + return torch.cat([output_video, output_text], dim=0) + + +def FlashInferBackend(query, key, value, mask_map=None, pre_defined_mask=None, bsr_wrapper=None): + if pre_defined_mask is not None: + video_video_o, video_video_o_lse = bsr_wrapper.run( + query[:mask_map.video_token_num, :, :], + key[:mask_map.video_token_num, :, :], + value[:mask_map.video_token_num, :, :], + return_lse=True + ) + # perform non-causal flashinfer on the text tokens + video_text_o, video_text_o_lse = flashinfer.single_prefill_with_kv_cache( + q=query[:mask_map.video_token_num, :, :], + k=key[mask_map.video_token_num:, :, :], + v=value[mask_map.video_token_num:, :, :], + causal=False, + return_lse=True, + custom_mask=pre_defined_mask[:mask_map.video_token_num, mask_map.video_token_num:] + ) + + # merge the two results + o_video, _ = flashinfer.merge_state(v_a=video_video_o, s_a=video_video_o_lse, v_b=video_text_o, s_b=video_text_o_lse) + + o_text = flashinfer.single_prefill_with_kv_cache( + q=query[mask_map.video_token_num:, :, :], + k=key, + v=value, + causal=False, + return_lse=False, + custom_mask=pre_defined_mask[mask_map.video_token_num:, :] + ) + + return torch.cat([o_video, o_text], dim=0) + else: + o = bsr_wrapper.run( + query[:mask_map.video_token_num, :, :], + key[:mask_map.video_token_num, :, :], + value[:mask_map.video_token_num, :, :] + ) + return o + +def RadialAttention(query, key, value, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_type=None, pre_defined_mask=None, use_sage_attention=False): + orig_seqlen, num_head, hidden_dim = query.shape + + if sparsity_type == "dense": + video_mask = torch.ones((mask_map.video_token_num // block_size, mask_map.video_token_num // block_size), device=query.device, dtype=torch.bool) + else: + video_mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, 
decay_factor=decay_factor, model_type=model_type) if mask_map else None + + backend = "sparse_sageattn" if use_sage_attention else "flashinfer" + + if backend == "flashinfer": + video_mask = video_mask[:mask_map.video_token_num // block_size, :mask_map.video_token_num // block_size] + # perform block-sparse attention on the video tokens + workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8) + bsr_wrapper = flashinfer.BlockSparseAttentionWrapper( + workspace_buffer, + backend="fa2", + ) + + indptr = get_indptr_from_mask(video_mask, query) + indices = get_indices_from_mask(video_mask, query) + + bsr_wrapper.plan( + indptr=indptr, + indices=indices, + M=mask_map.video_token_num, + N=mask_map.video_token_num, + R=block_size, + C=block_size, + num_qo_heads=num_head, + num_kv_heads=num_head, + head_dim=hidden_dim, + q_data_type=query.dtype, + kv_data_type=key.dtype, + o_data_type=query.dtype, + ) + + return FlashInferBackend(query, key, value, mask_map, pre_defined_mask, bsr_wrapper) + elif backend == "sparse_sageattn": + return SpargeSageAttnBackend(query, key, value, mask_map, video_mask, pre_defined_mask, block_size=block_size) + +if __name__ == "__main__": + query = torch.randn(1, 2, 4, 64).cuda() + # mask = torch.tensor([ + # [True, False, True, False], + # [False, True, False, True], + # [True, False, False, True], + # [False, True, True, False] + # ], dtype=torch.bool) + # indices = get_indices_from_mask(mask, query) + # indptr = get_indptr_from_mask(mask, query) + # print("Indices: ", indices) + # print("Indptr: ", indptr) + video_token_num = 3840 * 30 + num_frame = 30 + token_per_frame = video_token_num / num_frame + padded_video_token_num = ((video_token_num + 1) // 128 + 1) * 128 + print("padded: ", padded_video_token_num) + temporal_mask = gen_log_mask_shrinked(query, padded_video_token_num, video_token_num, num_frame, sparse_type="radial", decay_factor=1, model_type="hunyuan") + plt.figure(figsize=(10, 8), dpi=500) + + plt.imshow(temporal_mask.cpu().numpy()[:, :], cmap='hot') + plt.colorbar() + plt.title("Temporal Mask") + + plt.savefig("temporal_mask.png", + dpi=300, + bbox_inches='tight', + pad_inches=0.1) + + plt.close() + # save the mask tensor + torch.save(temporal_mask, "temporal_mask.pt") diff --git a/shared/radial_attention/inference.py b/shared/radial_attention/inference.py new file mode 100644 index 000000000..04ef186ab --- /dev/null +++ b/shared/radial_attention/inference.py @@ -0,0 +1,41 @@ +import torch +from diffusers.models.attention_processor import Attention +from diffusers.models.attention import AttentionModuleMixin +from .attention import WanSparseAttnProcessor +from .attn_mask import MaskMap + +def setup_radial_attention( + pipe, + height, + width, + num_frames, + dense_layers=0, + dense_timesteps=0, + decay_factor=1.0, + sparsity_type="radial", + use_sage_attention=False, +): + + num_frames = 1 + num_frames // (pipe.vae_scale_factor_temporal * pipe.transformer.config.patch_size[0]) + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + frame_size = int(height // mod_value) * int(width // mod_value) + + AttnModule = WanSparseAttnProcessor + AttnModule.dense_block = dense_layers + AttnModule.dense_timestep = dense_timesteps + AttnModule.mask_map = MaskMap(video_token_num=frame_size * num_frames, num_frame=num_frames) + AttnModule.decay_factor = decay_factor + AttnModule.sparse_type = sparsity_type + AttnModule.use_sage_attention = use_sage_attention + + print(f"Replacing Wan attention with 
{sparsity_type} attention") + print(f"video token num: {AttnModule.mask_map.video_token_num}, num frames: {num_frames}") + print(f"dense layers: {dense_layers}, dense timesteps: {dense_timesteps}, decay factor: {decay_factor}") + + for layer_idx, m in enumerate(pipe.transformer.blocks): + m.attn1.processor.layer_idx = layer_idx + + for _, m in pipe.transformer.named_modules(): + if isinstance(m, AttentionModuleMixin) and hasattr(m.processor, 'layer_idx'): + layer_idx = m.processor.layer_idx + m.set_processor(AttnModule(layer_idx)) \ No newline at end of file diff --git a/shared/radial_attention/sparse_transformer.py b/shared/radial_attention/sparse_transformer.py new file mode 100644 index 000000000..96820d0f3 --- /dev/null +++ b/shared/radial_attention/sparse_transformer.py @@ -0,0 +1,424 @@ +# borrowed from svg-project/Sparse-VideoGen + +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from diffusers.models.transformers.transformer_wan import WanTransformerBlock, WanTransformer3DModel +from diffusers import WanPipeline +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +logger = logging.get_logger(__name__) +from typing import Any, Callable, Dict, List, Optional, Union +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +import torch.distributed as dist + +try: + from xfuser.core.distributed import ( + get_ulysses_parallel_world_size, + get_ulysses_parallel_rank, + get_sp_group + ) +except: + pass + +class WanTransformerBlock_Sparse(WanTransformerBlock): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + numeral_timestep: Optional[int] = None, + ) -> torch.Tensor: + if temb.ndim == 4: + # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + # batch_size, seq_len, 1, inner_dim + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, numerical_timestep=numeral_timestep) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states).contiguous() + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = hidden_states + attn_output + + # 3. 
Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + + return hidden_states + +class WanTransformer3DModel_Sparse(WanTransformer3DModel): + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + numeral_timestep: Optional[int] = None, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + rotary_emb = self.rope(hidden_states) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) + if timestep.ndim == 2: + ts_seq_len = timestep.shape[1] + timestep = timestep.flatten() # batch_size * seq_len + else: + ts_seq_len = None + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len + ) + + if ts_seq_len is not None: + # batch_size, seq_len, 6, inner_dim + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + else: + # batch_size, 6, inner_dim + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + # split video latents on dim TS + hidden_states = torch.chunk(hidden_states, get_ulysses_parallel_world_size(), dim=-2)[get_ulysses_parallel_rank()] + rotary_emb = ( + torch.chunk(rotary_emb[0], get_ulysses_parallel_world_size(), dim=1)[get_ulysses_parallel_rank()], + torch.chunk(rotary_emb[1], get_ulysses_parallel_world_size(), dim=1)[get_ulysses_parallel_rank()], + ) + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, numeral_timestep=numeral_timestep + ) + else: + for block in self.blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + numeral_timestep=numeral_timestep, + ) + + # 5. 
Output norm, projection & unpatchify + if temb.ndim == 3: + # batch_size, seq_len, inner_dim (wan 2.2 ti2v) + shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + # batch_size, inner_dim + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + hidden_states = get_sp_group().all_gather(hidden_states, dim=-2) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + +class WanPipeline_Sparse(WanPipeline): + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. 
+ num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): + The dtype to use for the torch.amp.autocast. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. 
Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + numeral_timestep=i, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + numeral_timestep=i, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) + +def replace_sparse_forward(): + WanTransformerBlock.forward = 
WanTransformerBlock_Sparse.forward + WanTransformer3DModel.forward = WanTransformer3DModel_Sparse.forward + WanPipeline.__call__ = WanPipeline_Sparse.__call__ \ No newline at end of file diff --git a/shared/radial_attention/utils.py b/shared/radial_attention/utils.py new file mode 100644 index 000000000..92b3c3e13 --- /dev/null +++ b/shared/radial_attention/utils.py @@ -0,0 +1,16 @@ +import os +import random +import numpy as np +import torch + +def set_seed(seed): + """ + Set the random seed for reproducibility. + """ + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False \ No newline at end of file From a4d62fc1cd3e37b5f4dcb27aa189a4c155ca715e Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sat, 27 Sep 2025 16:30:13 +0200 Subject: [PATCH 047/155] removed test line that cause sage crash --- shared/attention.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/shared/attention.py b/shared/attention.py index efaa22370..c31087c24 100644 --- a/shared/attention.py +++ b/shared/attention.py @@ -67,8 +67,6 @@ def sageattn2_wrapper( return o -from sageattn import sageattn_blackwell as sageattn3 - try: from sageattn import sageattn_blackwell as sageattn3 except ImportError: From 63db8448f00a718e68a0605214bd044b08cbfbe9 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sun, 28 Sep 2025 18:13:09 +0200 Subject: [PATCH 048/155] Animate upgraded --- models/wan/any2video.py | 35 ++- models/wan/lynx/__init__.py | 0 models/wan/lynx/flash_attention.py | 60 ++++ models/wan/lynx/full_attention_processor.py | 305 +++++++++++++++++++ models/wan/lynx/inference_utils.py | 39 +++ models/wan/lynx/light_attention_processor.py | 194 ++++++++++++ models/wan/lynx/model_utils_wan.py | 264 ++++++++++++++++ models/wan/lynx/navit_utils.py | 16 + models/wan/lynx/resampler.py | 185 +++++++++++ models/wan/modules/model.py | 78 ++++- models/wan/wan_handler.py | 48 ++- preprocessing/arc/face_encoder.py | 97 ++++++ preprocessing/arc/face_utils.py | 63 ++++ requirements.txt | 2 + wgp.py | 108 +++++-- 15 files changed, 1456 insertions(+), 38 deletions(-) create mode 100644 models/wan/lynx/__init__.py create mode 100644 models/wan/lynx/flash_attention.py create mode 100644 models/wan/lynx/full_attention_processor.py create mode 100644 models/wan/lynx/inference_utils.py create mode 100644 models/wan/lynx/light_attention_processor.py create mode 100644 models/wan/lynx/model_utils_wan.py create mode 100644 models/wan/lynx/navit_utils.py create mode 100644 models/wan/lynx/resampler.py create mode 100644 preprocessing/arc/face_encoder.py create mode 100644 preprocessing/arc/face_utils.py diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 7dc4b19c8..b0fc24659 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -354,6 +354,7 @@ def generate(self, pre_video_frame = None, video_prompt_type= "", original_input_ref_images = [], + face_arc_embeds = None, **bbargs ): @@ -428,6 +429,7 @@ def generate(self, multitalk = model_type in ["multitalk", "infinitetalk", "vace_multitalk_14B", "i2v_2_2_multitalk"] infinitetalk = model_type in ["infinitetalk"] standin = model_type in ["standin", "vace_standin_14B"] + lynx = model_type in ["lynx_lite", "lynx_full", "vace_lynx_lite_14B", "vace_lynx_full_14B"] recam = model_type in ["recam_1.3B"] ti2v = model_type in ["ti2v_2_2", "lucy_edit"] lucy_edit= model_type in ["lucy_edit"] @@ -535,6 +537,7 @@ 
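The radial-attention integration above works by monkey-patching at runtime: `replace_sparse_forward()` reassigns `WanTransformerBlock.forward`, `WanTransformer3DModel.forward` and `WanPipeline.__call__` to their `_Sparse` counterparts, and `shared/radial_attention/utils.py` contributes `set_seed` for deterministic runs. A minimal usage sketch under stated assumptions: the checkpoint id is illustrative, and `replace_sparse_forward` must already have been imported and called from wherever this patch defines it (that hunk is not shown here).

```python
# Sketch only: assumes replace_sparse_forward() from the radial-attention module added
# by this patch has been called beforehand, so a stock diffusers WanPipeline picks up
# the sparse block/transformer forwards and the patched __call__ transparently.
import torch
from diffusers import WanPipeline

from shared.radial_attention.utils import set_seed  # path as added by the patch

set_seed(1234)  # seeds python, numpy, torch (CPU + CUDA) and forces deterministic cuDNN

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # illustrative checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")
video = pipe(prompt="a corgi running on a beach", num_frames=81, num_inference_steps=30).frames[0]
```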
def generate(self, pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0) input_frames = input_frames * input_masks if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames + # input_frames = input_frames[:, :1].expand(-1, input_frames.shape[1], -1, -1) if prefix_frames_count > 0: input_frames[:, :prefix_frames_count] = input_video input_masks[:, :prefix_frames_count] = 1 @@ -549,7 +552,8 @@ def generate(self, clip_image_start = image_ref.squeeze(1) lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1) y = torch.concat([msk, lat_y]) - kwargs.update({ 'y': y, 'pose_latents': pose_latents, 'face_pixel_values' : input_faces.unsqueeze(0)}) + kwargs.update({ 'y': y, 'pose_latents': pose_latents}) + face_pixel_values = input_faces.unsqueeze(0) lat_y = msk = msk_control = msk_ref = pose_pixels = None ref_images_before = True ref_images_count = 1 @@ -699,6 +703,23 @@ def generate(self, kwargs["freqs"] = freqs + # Lynx + if lynx: + from .lynx.resampler import Resampler + from accelerate import init_empty_weights + lynx_lite = model_type in ["lynx_lite", "vace_lynx_lite_14B"] + with init_empty_weights(): + arc_resampler = Resampler( depth=4, dim=1280, dim_head=64, embedding_dim=512, ff_mult=4, heads=20, num_queries=16, output_dim=2048 if lynx_lite else 5120 ) + offload.load_model_data(arc_resampler, os.path.join("ckpts", "wan2.1_lynx_lite_arc_resampler.safetensors" if lynx_lite else "wan2.1_lynx_full_arc_resampler.safetensors")) + arc_resampler.to(self.device) + arcface_embed = face_arc_embeds[None,None,:].to(device=self.device, dtype=torch.float) + ip_hidden_states = arc_resampler(arcface_embed).to(self.dtype) + ip_hidden_states_uncond = arc_resampler(torch.zeros_like(arcface_embed)).to(self.dtype) + arc_resampler = None + gc.collect() + torch.cuda.empty_cache() + kwargs["lynx_ip_scale"] = 1. + #Standin if standin: from preprocessing.face_preprocessor import FaceProcessor @@ -848,6 +869,18 @@ def clear(): "context" : [context, context_null, context_null], "audio_scale": [audio_scale, None, None ] } + elif animate: + gen_args = { + "x" : [latent_model_input, latent_model_input], + "context" : [context, context_null], + "face_pixel_values": [face_pixel_values, None] + } + elif lynx: + gen_args = { + "x" : [latent_model_input, latent_model_input], + "context" : [context, context_null], + "lynx_ip_embeds": [ip_hidden_states, ip_hidden_states_uncond] + } elif multitalk and audio_proj != None: if guide_scale == 1: gen_args = { diff --git a/models/wan/lynx/__init__.py b/models/wan/lynx/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/models/wan/lynx/flash_attention.py b/models/wan/lynx/flash_attention.py new file mode 100644 index 000000000..3bb807db8 --- /dev/null +++ b/models/wan/lynx/flash_attention.py @@ -0,0 +1,60 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 + +import torch +try: + from flash_attn_interface import flash_attn_varlen_func # flash attn 3 +except: + from flash_attn.flash_attn_interface import flash_attn_varlen_func # flash attn 2 + +def flash_attention(query, key, value, q_lens, kv_lens, causal=False): + """ + Args: + query, key, value: [B, H, T, D_h] + q_lens: list[int] per sequence query length + kv_lens: list[int] per sequence key/value length + causal: whether to use causal mask + + Returns: + output: [B, H, T_q, D_h] + """ + B, H, T_q, D_h = query.shape + T_k = key.shape[2] + device = query.device + + # Flatten: [B, H, T, D] -> [total_tokens, H, D] + q = query.permute(0, 2, 1, 3).reshape(B * T_q, H, D_h) + k = key.permute(0, 2, 1, 3).reshape(B * T_k, H, D_h) + v = value.permute(0, 2, 1, 3).reshape(B * T_k, H, D_h) + + # Prepare cu_seqlens: prefix sum + q_lens_tensor = torch.tensor(q_lens, device=device, dtype=torch.int32) + kv_lens_tensor = torch.tensor(kv_lens, device=device, dtype=torch.int32) + + cu_seqlens_q = torch.zeros(len(q_lens_tensor) + 1, device=device, dtype=torch.int32) + cu_seqlens_k = torch.zeros(len(kv_lens_tensor) + 1, device=device, dtype=torch.int32) + + cu_seqlens_q[1:] = torch.cumsum(q_lens_tensor, dim=0) + cu_seqlens_k[1:] = torch.cumsum(kv_lens_tensor, dim=0) + + max_seqlen_q = int(q_lens_tensor.max().item()) + max_seqlen_k = int(kv_lens_tensor.max().item()) + + # Call FlashAttention varlen kernel + out = flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + causal=causal + ) + if not torch.is_tensor(out): # flash attn 3 + out = out[0] + + # Restore shape: [total_q, H, D_h] -> [B, H, T_q, D_h] + out = out.view(B, T_q, H, D_h).permute(0, 2, 1, 3).contiguous() + + return out \ No newline at end of file diff --git a/models/wan/lynx/full_attention_processor.py b/models/wan/lynx/full_attention_processor.py new file mode 100644 index 000000000..fd0edb7a8 --- /dev/null +++ b/models/wan/lynx/full_attention_processor.py @@ -0,0 +1,305 @@ +# Copyright (c) 2025 The Wan Team and The HuggingFace Team. +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# +# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025. +# +# Original file was released under Apache License 2.0, with the full license text +# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt. +# +# This modified file is released under the same license. 
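Stepping back to the varlen wrapper added in `models/wan/lynx/flash_attention.py` just above: it accepts padded `[B, H, T, D]` tensors plus per-sample lengths, flattens them into FlashAttention's packed layout, builds the `cu_seqlens` prefix sums, and reshapes the result back. A shape-contract sketch, assuming `flash_attn` (v2 or v3) is installed and a CUDA device is available:

```python
import torch
from models.wan.lynx.flash_attention import flash_attention  # path as added by the patch

B, H, T_q, T_k, D = 2, 8, 128, 96, 64
q = torch.randn(B, H, T_q, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, T_k, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, T_k, D, device="cuda", dtype=torch.bfloat16)

# One length per batch element; full padded lengths are used here, which matches the
# contiguous reshape the wrapper performs before calling the varlen kernel.
out = flash_attention(q, k, v, q_lens=[T_q] * B, kv_lens=[T_k] * B, causal=False)
print(out.shape)  # torch.Size([2, 8, 128, 64])
```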
+ + +from typing import Optional, List +import torch +import torch.nn as nn +from diffusers.models.attention_processor import Attention +from modules.common.flash_attention import flash_attention +from modules.common.navit_utils import vector_to_list, list_to_vector, merge_token_lists + +def new_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + **cross_attention_kwargs, +) -> torch.Tensor: + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) +Attention.forward = new_forward + +class WanIPAttnProcessor(nn.Module): + def __init__( + self, + cross_attention_dim: int, + dim: int, + n_registers: int, + bias: bool, + ): + super().__init__() + self.to_k_ip = nn.Linear(cross_attention_dim, dim, bias=bias) + self.to_v_ip = nn.Linear(cross_attention_dim, dim, bias=bias) + if n_registers > 0: + self.registers = nn.Parameter(torch.randn(1, n_registers, cross_attention_dim) / dim**0.5) + else: + self.registers = None + + def forward( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + ip_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + q_lens: List[int] = None, + kv_lens: List[int] = None, + ip_lens: List[int] = None, + ip_scale: float = 1.0, + ): + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + encoder_hidden_states_img = encoder_hidden_states[:, :257] + encoder_hidden_states = encoder_hidden_states[:, 257:] + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if ip_hidden_states is not None: + + if self.registers is not None: + ip_hidden_states_list = vector_to_list(ip_hidden_states, ip_lens, 1) + ip_hidden_states_list = merge_token_lists(ip_hidden_states_list, [self.registers] * len(ip_hidden_states_list), 1) + ip_hidden_states, ip_lens = list_to_vector(ip_hidden_states_list, 1) + + ip_query = query + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + if attn.norm_q is not None: + ip_query = attn.norm_q(ip_query) + if attn.norm_k is not None: + ip_key = attn.norm_k(ip_key) + ip_query = ip_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + ip_key = ip_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + ip_value = ip_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + ip_hidden_states = flash_attention( + ip_query, ip_key, ip_value, + q_lens=q_lens, + kv_lens=ip_lens + ) + ip_hidden_states = ip_hidden_states.transpose(1, 2).flatten(2, 3) + ip_hidden_states = ip_hidden_states.type_as(ip_query) + + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if rotary_emb is not None: + + def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): + x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) + x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) + return x_out.type_as(hidden_states) + + query = apply_rotary_emb(query, 
rotary_emb) + key = apply_rotary_emb(key, rotary_emb) + + hidden_states = flash_attention( + query, key, value, + q_lens=q_lens, + kv_lens=kv_lens, + ) + + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if ip_hidden_states is not None: + hidden_states = hidden_states + ip_scale * ip_hidden_states + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + +def register_ip_adapter( + model, + cross_attention_dim=None, + n_registers=0, + init_method="zero", + dtype=torch.float32, +): + attn_procs = {} + transformer_sd = model.state_dict() + for layer_idx, block in enumerate(model.blocks): + name = f"blocks.{layer_idx}.attn2.processor" + layer_name = name.split(".processor")[0] + dim = transformer_sd[layer_name + ".to_k.weight"].shape[1] + attn_procs[name] = WanIPAttnProcessor( + cross_attention_dim=dim if cross_attention_dim is None else cross_attention_dim, + dim=dim, + n_registers=n_registers, + bias=True, + ) + if init_method == "zero": + torch.nn.init.zeros_(attn_procs[name].to_k_ip.weight) + torch.nn.init.zeros_(attn_procs[name].to_k_ip.bias) + torch.nn.init.zeros_(attn_procs[name].to_v_ip.weight) + torch.nn.init.zeros_(attn_procs[name].to_v_ip.bias) + elif init_method == "clone": + layer_name = name.split(".processor")[0] + weights = { + "to_k_ip.weight": transformer_sd[layer_name + ".to_k.weight"], + "to_k_ip.bias": transformer_sd[layer_name + ".to_k.bias"], + "to_v_ip.weight": transformer_sd[layer_name + ".to_v.weight"], + "to_v_ip.bias": transformer_sd[layer_name + ".to_v.bias"], + } + attn_procs[name].load_state_dict(weights) + elif init_method == "random": + pass + else: + raise ValueError(f"{init_method} is not supported.") + block.attn2.processor = attn_procs[name] + ip_layers = torch.nn.ModuleList(attn_procs.values()) + ip_layers.to(device=model.device, dtype=dtype) + return model, ip_layers + + +class WanRefAttnProcessor(nn.Module): + def __init__( + self, + dim: int, + bias: bool, + ): + super().__init__() + self.to_k_ref = nn.Linear(dim, dim, bias=bias) + self.to_v_ref = nn.Linear(dim, dim, bias=bias) + + def forward( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + q_lens: List[int] = None, + kv_lens: List[int] = None, + ref_feature: Optional[tuple] = None, + ref_scale: float = 1.0, + ): + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + encoder_hidden_states_img = encoder_hidden_states[:, :257] + encoder_hidden_states = encoder_hidden_states[:, 257:] + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if ref_feature is None: + ref_hidden_states = None + else: + ref_hidden_states, ref_lens = ref_feature + ref_query = query + ref_key = self.to_k_ref(ref_hidden_states) + ref_value = self.to_v_ref(ref_hidden_states) + if attn.norm_q is not None: + ref_query = attn.norm_q(ref_query) + if attn.norm_k is not None: + ref_key = attn.norm_k(ref_key) + ref_query = ref_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + ref_key = ref_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + ref_value = ref_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + ref_hidden_states = flash_attention( + ref_query, ref_key, ref_value, + 
q_lens=q_lens, + kv_lens=ref_lens + ) + ref_hidden_states = ref_hidden_states.transpose(1, 2).flatten(2, 3) + ref_hidden_states = ref_hidden_states.type_as(ref_query) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if rotary_emb is not None: + + def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): + x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) + x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) + return x_out.type_as(hidden_states) + + query = apply_rotary_emb(query, rotary_emb) + key = apply_rotary_emb(key, rotary_emb) + + hidden_states = flash_attention( + query, key, value, + q_lens=q_lens, + kv_lens=kv_lens, + ) + + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if ref_hidden_states is not None: + hidden_states = hidden_states + ref_scale * ref_hidden_states + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states + + +def register_ref_adapter( + model, + init_method="zero", + dtype=torch.float32, +): + attn_procs = {} + transformer_sd = model.state_dict() + for layer_idx, block in enumerate(model.blocks): + name = f"blocks.{layer_idx}.attn1.processor" + layer_name = name.split(".processor")[0] + dim = transformer_sd[layer_name + ".to_k.weight"].shape[0] + attn_procs[name] = WanRefAttnProcessor( + dim=dim, + bias=True, + ) + if init_method == "zero": + torch.nn.init.zeros_(attn_procs[name].to_k_ref.weight) + torch.nn.init.zeros_(attn_procs[name].to_k_ref.bias) + torch.nn.init.zeros_(attn_procs[name].to_v_ref.weight) + torch.nn.init.zeros_(attn_procs[name].to_v_ref.bias) + elif init_method == "clone": + weights = { + "to_k_ref.weight": transformer_sd[layer_name + ".to_k.weight"], + "to_k_ref.bias": transformer_sd[layer_name + ".to_k.bias"], + "to_v_ref.weight": transformer_sd[layer_name + ".to_v.weight"], + "to_v_ref.bias": transformer_sd[layer_name + ".to_v.bias"], + } + attn_procs[name].load_state_dict(weights) + elif init_method == "random": + pass + else: + raise ValueError(f"{init_method} is not supported.") + block.attn1.processor = attn_procs[name] + ref_layers = torch.nn.ModuleList(attn_procs.values()) + ref_layers.to(device=model.device, dtype=dtype) + return model, ref_layers diff --git a/models/wan/lynx/inference_utils.py b/models/wan/lynx/inference_utils.py new file mode 100644 index 000000000..c1129f3bf --- /dev/null +++ b/models/wan/lynx/inference_utils.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Union + +import torch +import numpy as np + +from PIL import Image +from dataclasses import dataclass, field + + +dtype_mapping = { + "bf16": torch.bfloat16, + "fp16": torch.float16, + "fp32": torch.float32 +} + + +@dataclass +class SubjectInfo: + name: str = "" + image_pil: Optional[Image.Image] = None + landmarks: Optional[Union[np.ndarray, torch.Tensor]] = None + face_embeds: Optional[Union[np.ndarray, torch.Tensor]] = None + + +@dataclass +class VideoStyleInfo: # key names should match those used in style.yaml file + style_name: str = 'none' + num_frames: int = 81 + seed: int = -1 + guidance_scale: float = 5.0 + guidance_scale_i: float = 2.0 + num_inference_steps: int = 50 + width: int = 832 + height: int = 480 + prompt: str = '' + negative_prompt: str = '' \ No newline at end of file diff --git a/models/wan/lynx/light_attention_processor.py b/models/wan/lynx/light_attention_processor.py new file mode 100644 index 000000000..a3c7516b1 --- /dev/null +++ b/models/wan/lynx/light_attention_processor.py @@ -0,0 +1,194 @@ +# Copyright (c) 2025 The Wan Team and The HuggingFace Team. +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# +# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025. +# +# Original file was released under Apache License 2.0, with the full license text +# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt. +# +# This modified file is released under the same license. + +from typing import Tuple, Union, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from diffusers.models.normalization import RMSNorm + + +def register_ip_adapter_wan( + model, + hidden_size=5120, + cross_attention_dim=2048, + dtype=torch.float32, + init_method="zero", + layers=None +): + attn_procs = {} + transformer_sd = model.state_dict() + + if layers is None: + layers = list(range(0, len(model.blocks))) + elif isinstance(layers, int): # Only interval provided + layers = list(range(0, len(model.blocks), layers)) + + for i, block in enumerate(model.blocks): + if i not in layers: + continue + + name = f"blocks.{i}.attn2.processor" + attn_procs[name] = IPAWanAttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim + ) + + if init_method == "zero": + torch.nn.init.zeros_(attn_procs[name].to_k_ip.weight) + torch.nn.init.zeros_(attn_procs[name].to_v_ip.weight) + elif init_method == "clone": + layer_name = name.split(".processor")[0] + weights = { + "to_k_ip.weight": transformer_sd[layer_name + ".to_k.weight"], + "to_v_ip.weight": transformer_sd[layer_name + ".to_v.weight"], + } + attn_procs[name].load_state_dict(weights) + else: + raise ValueError(f"{init_method} is not supported.") + + block.attn2.processor = attn_procs[name] + + ip_layers = torch.nn.ModuleList(attn_procs.values()) + ip_layers.to(model.device, dtype=dtype) + + return model, ip_layers + + +class IPAWanAttnProcessor2_0(torch.nn.Module): + def __init__( + self, + hidden_size, + cross_attention_dim=None, + scale=1.0, + bias=False, + ): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("IPAWanAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + + self.to_k_ip = 
nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=bias) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=bias) + + torch.nn.init.zeros_(self.to_k_ip.weight) + torch.nn.init.zeros_(self.to_v_ip.weight) + if bias: + torch.nn.init.zeros_(self.to_k_ip.bias) + torch.nn.init.zeros_(self.to_v_ip.bias) + + self.norm_rms_k = RMSNorm(hidden_size, eps=1e-5, elementwise_affine=False) + + + def __call__( + self, + attn, + hidden_states: torch.Tensor, + image_embed: torch.Tensor = None, + ip_scale: float = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + if attn.add_k_proj is not None: + encoder_hidden_states_img = encoder_hidden_states[:, :257] + encoder_hidden_states = encoder_hidden_states[:, 257:] + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + # ============================================================= + batch_size = image_embed.size(0) + + ip_hidden_states = image_embed + ip_query = query # attn.to_q(hidden_states.clone()) + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + if attn.norm_q is not None: + ip_query = attn.norm_q(ip_query) + ip_key = self.norm_rms_k(ip_key) + + ip_inner_dim = ip_key.shape[-1] + ip_head_dim = ip_inner_dim // attn.heads + + ip_query = ip_query.view(batch_size, -1, attn.heads, ip_head_dim).transpose(1, 2) + ip_key = ip_key.view(batch_size, -1, attn.heads, ip_head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, ip_head_dim).transpose(1, 2) + + ip_hidden_states = F.scaled_dot_product_attention( + ip_query, ip_key, ip_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * ip_head_dim) + ip_hidden_states = ip_hidden_states.to(ip_query.dtype) + # =========================================================================== + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if rotary_emb is not None: + + def apply_rotary_emb_inner(hidden_states: torch.Tensor, freqs: torch.Tensor): + x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) + x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) + return x_out.type_as(hidden_states) + + query = apply_rotary_emb_inner(query, rotary_emb) + key = apply_rotary_emb_inner(key, rotary_emb) + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img = attn.add_k_proj(encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + hidden_states_img = F.scaled_dot_product_attention( + query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False + ) + hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + 
hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + # Add IPA residual + ip_scale = ip_scale or self.scale + hidden_states = hidden_states + ip_scale * ip_hidden_states + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + + return hidden_states \ No newline at end of file diff --git a/models/wan/lynx/model_utils_wan.py b/models/wan/lynx/model_utils_wan.py new file mode 100644 index 000000000..2fd117046 --- /dev/null +++ b/models/wan/lynx/model_utils_wan.py @@ -0,0 +1,264 @@ +# Copyright (c) 2025 The Wan Team and The HuggingFace Team. +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 +# +# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025. +# +# Original file was released under Apache License 2.0, with the full license text +# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt. +# +# This modified file is released under the same license. + +import os +import math +import torch + +from typing import List, Optional, Tuple, Union + +from diffusers.models.embeddings import get_3d_rotary_pos_embed +from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid + +from diffusers import FlowMatchEulerDiscreteScheduler + +""" Utility functions for Wan2.1-T2I model """ + + +def cal_mean_and_std(vae, device, dtype): + # https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py#L483 + + # latents_mean + latents_mean = torch.tensor(vae.config.latents_mean).view( + 1, vae.config.z_dim, 1, 1, 1).to(device, dtype) + + # latents_std + latents_std = 1.0 / torch.tensor(vae.config.latents_std).view( + 1, vae.config.z_dim, 1, 1, 1).to(device, dtype) + + return latents_mean, latents_std + + +def encode_video(vae, video): + # Input video is formatted as [B, F, C, H, W] + video = video.to(vae.device, dtype=vae.dtype) + video = video.unsqueeze(0) if len(video.shape) == 4 else video + + # VAE takes [B, C, F, H, W] + video = video.permute(0, 2, 1, 3, 4) + latent_dist = vae.encode(video).latent_dist + + mean, std = cal_mean_and_std(vae, vae.device, vae.dtype) + + # NOTE: this implementation is weird but I still borrow it from the official + # codes. A standard normailization should be (x - mean) / std, but the std + # here is computed by `1 / vae.config.latents_std` from cal_mean_and_std(). 
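For the lighter lynx path, `register_ip_adapter_wan` in `models/wan/lynx/light_attention_processor.py` above swaps an `IPAWanAttnProcessor2_0` into selected cross-attention blocks, zero-initialising `to_k_ip`/`to_v_ip` so the identity branch starts as a no-op. A hedged registration sketch; the diffusers checkpoint id is illustrative and the hidden size / cross-attention dim mirror the 14B "lite" values used elsewhere in this patch:

```python
import torch
from diffusers import WanTransformer3DModel
from models.wan.lynx.light_attention_processor import register_ip_adapter_wan  # path as added by the patch

transformer = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",  # illustrative checkpoint id
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Attach zero-initialised IP-adapter projections to every 2nd block; returns the
# patched transformer plus a ModuleList holding only the new (trainable) layers.
transformer, ip_layers = register_ip_adapter_wan(
    transformer,
    hidden_size=5120,           # Wan 14B inner dim
    cross_attention_dim=2048,   # matches the "lite" Resampler output_dim
    dtype=torch.bfloat16,
    init_method="zero",
    layers=2,                   # int -> register every 2nd block
)
print(sum(p.numel() for p in ip_layers.parameters()))
```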
+ + # Sample latent with scaling + video_latent = (latent_dist.sample() - mean) * std + + # Return as [B, C, F, H, W] as required by Wan + return video_latent + + +def decode_latent(vae, latent): + # Input latent is formatted as [B, C, F, H, W] + latent = latent.to(vae.device, dtype=vae.dtype) + latent = latent.unsqueeze(0) if len(latent.shape) == 4 else latent + + mean, std = cal_mean_and_std(vae, latent.device, latent.dtype) + latent = latent / std + mean + + # VAE takes [B, C, F, H, W] + frames = vae.decode(latent).sample + + # Return as [B, F, C, H, W] + frames = frames.permute(0, 2, 1, 3, 4).float() + + return frames + + +def encode_prompt( + tokenizer, + text_encoder, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 226, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + text_input_ids=None, + prompt_attention_mask=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if tokenizer is not None: + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_attention_mask = text_inputs.attention_mask + else: + if text_input_ids is None: + raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") + + seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() + prompt_embeds = text_encoder(text_input_ids.to(device), prompt_attention_mask.to(device)).last_hidden_state + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] + prompt_embeds = torch.stack( + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + ) + + # duplicate text embeddings for each generation per prompt, using mps friendly method + _, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return prompt_embeds, prompt_attention_mask + + +def compute_prompt_embeddings( + tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False +): + if requires_grad: + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + else: + with torch.no_grad(): + prompt_embeds = encode_prompt( + tokenizer, + text_encoder, + prompt, + num_videos_per_prompt=1, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + return prompt_embeds[0] + + +def prepare_rotary_positional_embeddings( + height: int, + width: int, + num_frames: int, + vae_scale_factor_spatial: int = 8, + patch_size: int = 2, + attention_head_dim: int = 64, + device: Optional[torch.device] = None, + base_height: int = 480, + base_width: int = 720, +) -> Tuple[torch.Tensor, torch.Tensor]: + grid_height = height // (vae_scale_factor_spatial * patch_size) + grid_width = width // (vae_scale_factor_spatial * patch_size) + base_size_width = base_width // (vae_scale_factor_spatial * patch_size) + base_size_height = base_height // (vae_scale_factor_spatial * patch_size) + + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) + freqs_cos, freqs_sin = get_3d_rotary_pos_embed( + 
embed_dim=attention_head_dim, + crops_coords=grid_crops_coords, + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + ) + + freqs_cos = freqs_cos.to(device=device) + freqs_sin = freqs_sin.to(device=device) + return freqs_cos, freqs_sin + + +def prepare_sigmas( + scheduler: FlowMatchEulerDiscreteScheduler, + sigmas: torch.Tensor, + batch_size: int, + num_train_timesteps: int, + flow_weighting_scheme: str = "none", + flow_logit_mean: float = 0.0, + flow_logit_std: float = 1.0, + flow_mode_scale: float = 1.29, + device: torch.device = torch.device("cpu"), + generator: Optional[torch.Generator] = None, +) -> torch.Tensor: + if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): + weights = compute_density_for_timestep_sampling( + weighting_scheme=flow_weighting_scheme, + batch_size=batch_size, + logit_mean=flow_logit_mean, + logit_std=flow_logit_std, + mode_scale=flow_mode_scale, + device=device, + generator=generator, + ) + indices = (weights * num_train_timesteps).long() + else: + raise ValueError(f"Unsupported scheduler type {type(scheduler)}") + + return sigmas[indices] + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, + batch_size: int, + logit_mean: float = None, + logit_std: float = None, + mode_scale: float = None, + device: torch.device = torch.device("cpu"), + generator: Optional[torch.Generator] = None, +) -> torch.Tensor: + r""" + Compute the density for sampling the timesteps when doing SD3 training. + + Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. + + SD3 paper reference: https://arxiv.org/abs/2403.03206v1. + """ + if weighting_scheme == "logit_normal": + # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator) + u = torch.nn.functional.sigmoid(u) + elif weighting_scheme == "mode": + u = torch.rand(size=(batch_size,), device=device, generator=generator) + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + else: + u = torch.rand(size=(batch_size,), device=device, generator=generator) + return u + + +def expand_tensor_dims(tensor: torch.Tensor, ndim: int) -> torch.Tensor: + assert len(tensor.shape) <= ndim + return tensor.reshape(tensor.shape + (1,) * (ndim - len(tensor.shape))) + + +def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + r"""Forward process of flow matching.""" + return (1.0 - t) * x0 + t * n + + +def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor: + r"""Loss target for flow matching.""" + return n - x0 + + +def prepare_loss_weights( + scheduler: FlowMatchEulerDiscreteScheduler, + sigmas: Optional[torch.Tensor] = None, + flow_weighting_scheme: str = "none", + +) -> torch.Tensor: + + assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler) + + from diffusers.training_utils import compute_loss_weighting_for_sd3 + + return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme) \ No newline at end of file diff --git a/models/wan/lynx/navit_utils.py b/models/wan/lynx/navit_utils.py new file mode 100644 index 000000000..593e7f0d1 --- /dev/null +++ b/models/wan/lynx/navit_utils.py @@ -0,0 +1,16 @@ +# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 + +import torch + +def vector_to_list(tensor, lens, dim): + return list(torch.split(tensor, lens, dim=dim)) + +def list_to_vector(tensor_list, dim): + lens = [tensor.shape[dim] for tensor in tensor_list] + tensor = torch.cat(tensor_list, dim) + return tensor, lens + +def merge_token_lists(list1, list2, dim): + assert(len(list1) == len(list2)) + return [torch.cat((t1, t2), dim) for t1, t2 in zip(list1, list2)] \ No newline at end of file diff --git a/models/wan/lynx/resampler.py b/models/wan/lynx/resampler.py new file mode 100644 index 000000000..e225a062b --- /dev/null +++ b/models/wan/lynx/resampler.py @@ -0,0 +1,185 @@ + +# Copyright (c) 2022 Phil Wang +# Copyright (c) 2023 Anas Awadalla, Irena Gao, Joshua Gardner, Jack Hessel, Yusuf Hanafy, Wanrong Zhu, Kalyani Marathe, Yonatan Bitton, Samir Gadre, Jenia Jitsev, Simon Kornblith, Pang Wei Koh, Gabriel Ilharco, Mitchell Wortsman, Ludwig Schmidt. +# Copyright (c) 2023 Tencent AI Lab +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates + +# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025. +# SPDX-License-Identifier: Apache-2.0 + +# Original file (IP-Adapter) was released under Apache License 2.0, with the full license text +# available at https://github.com/tencent-ailab/IP-Adapter/blob/main/LICENSE. + +# Original file (open_flamingo and flamingo-pytorch) was released under MIT License: + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
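The three helpers in `models/wan/lynx/navit_utils.py` above implement the pack/unpack pattern the full attention processor relies on: variable-length token sets are concatenated into one tensor plus a length list, split back per sample, and optionally extended with shared register tokens. A small self-contained example:

```python
import torch
from models.wan.lynx.navit_utils import vector_to_list, list_to_vector, merge_token_lists  # path as added by the patch

# Two samples with different token counts, packed along the token dimension (dim=1).
a = torch.randn(1, 16, 2048)
b = torch.randn(1, 21, 2048)
packed, lens = list_to_vector([a, b], dim=1)        # packed: [1, 37, 2048], lens: [16, 21]

# Unpack, append 4 shared "register" tokens to every sample, then repack.
registers = torch.randn(1, 4, 2048)
parts = vector_to_list(packed, lens, dim=1)
parts = merge_token_lists(parts, [registers] * len(parts), dim=1)
packed, lens = list_to_vector(parts, dim=1)
print(packed.shape, lens)                           # torch.Size([1, 45, 2048]) [20, 25]
```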
+ + +import math + +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + max_seq_len: int = 257, # CLIP tokens + CLS token + apply_pos_emb: bool = False, + num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence + **kwargs, + ): + super().__init__() + self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.to_latents_from_mean_pooled_seq = ( + nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * num_latents_mean_pooled), + Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), + ) + if num_latents_mean_pooled > 0 + else None + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + self.pretrained = kwargs.pop("pretrained", None) + self.freeze = kwargs.pop("freeze", False) + self.skip_last_layer = kwargs.pop("skip_last_layer", False) + + if self.freeze and self.pretrained is None: + raise ValueError("Trying to freeze the model but no pretrained provided!") + + + def forward(self, x): + if self.pos_emb is not None: + n, device = x.shape[1], x.device + pos_emb = self.pos_emb(torch.arange(n, device=device)) + x 
= x + pos_emb + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + if self.to_latents_from_mean_pooled_seq: + meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) + meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) + latents = torch.cat((meanpooled_latents, latents), dim=-2) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index 82a313cec..42af87737 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -355,13 +355,38 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks class WanT2VCrossAttention(WanSelfAttention): - def forward(self, xlist, context, grid_sizes, *args, **kwargs): + def forward(self, xlist, context, grid_sizes, lynx_ip_embeds = None, lynx_ip_scale = 0, *args, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] """ - x, _ = self.text_cross_attention( xlist, context) + x, q = self.text_cross_attention( xlist, context, return_q=lynx_ip_embeds is not None) + if lynx_ip_embeds is not None and self.to_k_ip is not None: + if self.registers is not None: + from ..lynx.navit_utils import vector_to_list, merge_token_lists, list_to_vector + ip_hidden_states_list = vector_to_list(lynx_ip_embeds, lynx_ip_embeds.shape[1], 1) + ip_hidden_states_list = merge_token_lists(ip_hidden_states_list, [self.registers] * len(ip_hidden_states_list), 1) + lynx_ip_embeds, ip_lens = list_to_vector(ip_hidden_states_list, 1) + ip_hidden_states_list = None + ip_key = self.to_k_ip(lynx_ip_embeds) + ip_value = self.to_v_ip(lynx_ip_embeds) + # if self.norm_q is not None: ip_query = self.norm_q(q) + if self.norm_rms_k is None: + ip_key = self.norm_k(ip_key) + else: + ip_key = self.norm_rms_k(ip_key) + ip_inner_dim = ip_key.shape[-1] + ip_head_dim = ip_inner_dim // self.num_heads + batch_size = q.shape[0] + # ip_query = ip_query.view(batch_size, -1, attn.heads, ip_head_dim) + ip_key = ip_key.view(batch_size, -1, self.num_heads, ip_head_dim) + ip_value = ip_value.view(batch_size, -1, self.num_heads, ip_head_dim) + qkv_list = [q, ip_key, ip_value] + del q, ip_key, ip_value + ip_hidden_states = pay_attention(qkv_list).reshape(*x.shape) + x.add_(ip_hidden_states, alpha= lynx_ip_scale) + x = self.o(x) return x @@ -382,7 +407,7 @@ def __init__(self, # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() - def forward(self, xlist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens ): + def forward(self, xlist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens, *args, **kwargs ): r""" Args: x(Tensor): Shape [B, L1, C] @@ -510,6 +535,8 @@ def forward( ref_images_count=0, standin_phase=-1, motion_vec = None, + lynx_ip_embeds = None, + lynx_ip_scale = 0, ): r""" Args: @@ -568,7 +595,7 @@ def forward( y = y.to(attention_dtype) ylist= [y] del y - x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype) + x += self.cross_attn(ylist, context, grid_sizes, audio_proj = audio_proj, audio_scale = audio_scale, audio_context_lens = audio_context_lens, lynx_ip_embeds=lynx_ip_embeds, lynx_ip_scale=lynx_ip_scale).to(dtype) if multitalk_audio != None: # cross attn of multitalk audio @@ -914,6 +941,7 @@ def __init__(self, 
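`Resampler` above is the module `any2video.py` loads as `arc_resampler`: a Perceiver-style cross-attention stack that turns a single 512-d ArcFace embedding into a fixed set of IP-adapter tokens (16 queries of width 2048 for lynx "lite", 5120 for "full"). A shape-only sketch with random weights; in the real path the weights are loaded from `ckpts/wan2.1_lynx_*_arc_resampler.safetensors`:

```python
import torch
from models.wan.lynx.resampler import Resampler  # path as added by the patch

# Hyper-parameters copied from the lynx "lite" branch in any2video.py; the defaults keep
# num_latents_mean_pooled=0, so the branch that would need einops' Rearrange is not hit.
resampler = Resampler(
    depth=4, dim=1280, dim_head=64, heads=20,
    num_queries=16, embedding_dim=512, output_dim=2048, ff_mult=4,
)

arcface_embed = torch.randn(1, 1, 512)   # [batch, tokens, 512] identity embedding
ip_tokens = resampler(arcface_embed)
print(ip_tokens.shape)                   # torch.Size([1, 16, 2048])
```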
norm_output_audio=True, standin= False, motion_encoder_dim=0, + lynx=None, ): super().__init__() @@ -1044,6 +1072,32 @@ def __init__(self, block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128) block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128) + if lynx is not None: + from ..lynx.light_attention_processor import IPAWanAttnProcessor2_0 + lynx_full = lynx=="full" + if lynx_full: + lynx_cross_dim = 5120 + lynx_layers = len(self.blocks) + else: + lynx_cross_dim = 2048 + lynx_layers = 20 + for i, block in enumerate(self.blocks): + if i < lynx_layers: + attn = IPAWanAttnProcessor2_0(hidden_size=dim, cross_attention_dim = lynx_cross_dim,) + block.cross_attn.to_k_ip = attn.to_k_ip + block.cross_attn.to_v_ip = attn.to_v_ip + else: + block.cross_attn.to_k_ip = None + block.cross_attn.to_v_ip = None + if lynx_full: + block.cross_attn.registers = nn.Parameter(torch.randn(1, 16, lynx_cross_dim) / dim**0.5) + block.cross_attn.norm_rms_k = None + else: + block.cross_attn.registers = None + block.cross_attn.norm_rms_k = attn.norm_rms_k + + + if animate: self.pose_patch_embedding = nn.Conv3d( 16, dim, kernel_size=patch_size, stride=patch_size @@ -1247,6 +1301,8 @@ def forward( standin_ref = None, pose_latents=None, face_pixel_values=None, + lynx_ip_embeds = None, + lynx_ip_scale = 0, ): # patch_dtype = self.patch_embedding.weight.dtype @@ -1286,11 +1342,11 @@ def forward( y = None motion_vec_list = [] - for i, x in enumerate(x_list): + for i, (x, one_face_pixel_values) in enumerate(zip(x_list, face_pixel_values)): # animate embeddings motion_vec = None if pose_latents is not None: - x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values) + x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(one_face_pixel_values) if one_face_pixel_values is None else one_face_pixel_values) motion_vec_list.append(motion_vec) if chipmunk: x = x.unsqueeze(-1) @@ -1330,6 +1386,7 @@ def forward( audio_proj=audio_proj, audio_context_lens=audio_context_lens, ref_images_count=ref_images_count, + lynx_ip_scale= lynx_ip_scale, ) _flag_df = t.dim() == 2 @@ -1348,6 +1405,11 @@ def forward( standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)).to(modulation_dtype) ) standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype) standin_e = standin_ref = None + + if lynx_ip_embeds is None: + lynx_ip_embeds_list = [None] * len(x_list) + else: + lynx_ip_embeds_list = lynx_ip_embeds if self.inject_sample_info and fps!=None: fps = torch.tensor(fps, dtype=torch.long, device=device) @@ -1498,9 +1560,9 @@ def forward( continue x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs) else: - for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list)): + for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec, lynx_ip_embeds) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list, lynx_ip_embeds_list)): if should_calc: - x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec,**kwargs) + x_list[i] = block(x, context = context, hints= 
hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec, lynx_ip_embeds= lynx_ip_embeds, **kwargs) del x context = hints = audio_embedding = None diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 1a30fffbc..e1829f03a 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -17,6 +17,9 @@ def test_multitalk(base_model_type): def test_standin(base_model_type): return base_model_type in ["standin", "vace_standin_14B"] +def test_lynx(base_model_type): + return base_model_type in ["lynx_lite", "vace_lynx_lite_14B", "lynx_full", "vace_lynx_full_14B"] + def test_wan_5B(base_model_type): return base_model_type in ["ti2v_2_2", "lucy_edit"] class family_handler(): @@ -82,6 +85,7 @@ def query_model_def(base_model_type, model_def): extra_model_def["i2v_class"] = i2v extra_model_def["multitalk_class"] = test_multitalk(base_model_type) extra_model_def["standin_class"] = test_standin(base_model_type) + extra_model_def["lynx_class"] = test_lynx(base_model_type) vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"] extra_model_def["vace_class"] = vace_class @@ -170,15 +174,26 @@ def query_model_def(base_model_type, model_def): "choices":[ ("Animate Person in Reference Image using Motion of Whole Control Video", "PVBKI"), ("Animate Person in Reference Image using Motion of Targeted Person in Control Video", "PVBXAKI"), - ("Replace Person in Control Video by Person in Reference Image", "PVBAI"), - ("Replace Person in Control Video by Person in Reference Image and Apply Relighting Process", "PVBAI1"), + ("Replace Person in Control Video by Person in Ref Image", "PVBAIH#"), + ("Replace Person in Control Video by Person in Ref Image. 
See Through Mask", "PVBAI#"), ], "default": "PVBKI", - "letters_filter": "PVBXAKI1", + "letters_filter": "PVBXAKIH#", "label": "Type of Process", + "scale": 3, "show_label" : False, } + extra_model_def["custom_video_selection"] = { + "choices":[ + ("None", ""), + ("Apply Relighting", "1"), + ], + "label": "Custom Process", + "show_label" : False, + "scale": 1, + } + extra_model_def["mask_preprocessing"] = { "selection":[ "", "A", "XA"], "visible": False @@ -229,7 +244,7 @@ def query_model_def(base_model_type, model_def): extra_model_def["forced_guide_mask_inputs"] = True extra_model_def["return_image_refs_tensor"] = True - if base_model_type in ["standin"]: + if base_model_type in ["standin", "lynx_lite", "lynx_full"]: extra_model_def["fit_into_canvas_image_refs"] = 0 extra_model_def["image_ref_choices"] = { "choices": [ @@ -298,7 +313,7 @@ def query_model_def(base_model_type, model_def): @staticmethod def query_supported_types(): return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B", - "t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", + "t2v_1.3B", "standin", "lynx_lite", "lynx_full", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", "recam_1.3B", "animate", "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] @@ -312,8 +327,8 @@ def query_family_maps(): } models_comp_map = { - "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B"], - "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin"], + "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_full_14B"], + "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin", "lynx_lite", "lynx_full"], "i2v" : [ "fantasy", "multitalk", "flf2v_720p" ], "i2v_2_2" : ["i2v_2_2_multitalk"], "fantasy": ["multitalk"], @@ -423,7 +438,7 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): ui_defaults["video_prompt_type"] = video_prompt_type if settings_version < 2.31: - if base_model_type in "recam_1.3B": + if base_model_type in ["recam_1.3B"]: video_prompt_type = ui_defaults.get("video_prompt_type", "") if not "V" in video_prompt_type: video_prompt_type += "UV" @@ -438,6 +453,14 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): if test_class_i2v(base_model_type) and len(image_prompt_type) == 0: ui_defaults["image_prompt_type"] = "S" + + if settings_version < 2.37: + if base_model_type in ["animate"]: + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if "1" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("1", "#1") + ui_defaults["video_prompt_type"] = video_prompt_type + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ @@ -481,6 +504,15 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "video_prompt_type": "I", "remove_background_images_ref" : 1, }) + elif base_model_type in ["lynx_lite", "lynx_full"]: + ui_defaults.update({ + "guidance_scale": 5.0, + "flow_shift": 7, # 11 for 720p + "sliding_window_overlap" : 9, + "video_prompt_type": "QI", + "remove_background_images_ref" : 1, + }) + elif base_model_type in ["phantom_1.3B", "phantom_14B"]: ui_defaults.update({ "guidance_scale": 7.5, diff --git a/preprocessing/arc/face_encoder.py b/preprocessing/arc/face_encoder.py new file mode 100644 index 000000000..0fb916219 --- /dev/null 
+++ b/preprocessing/arc/face_encoder.py @@ -0,0 +1,97 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torchvision.transforms as T + +import numpy as np + +from insightface.utils import face_align +from insightface.app import FaceAnalysis +from facexlib.recognition import init_recognition_model + + +__all__ = [ + "FaceEncoderArcFace", + "get_landmarks_from_image", +] + + +detector = None + +def get_landmarks_from_image(image): + """ + Detect landmarks with insightface. + + Args: + image (np.ndarray or PIL.Image): + The input image in RGB format. + + Returns: + 5 2D keypoints, only one face will be returned. + """ + global detector + if detector is None: + detector = FaceAnalysis() + detector.prepare(ctx_id=0, det_size=(640, 640)) + + in_image = np.array(image).copy() + + faces = detector.get(in_image) + if len(faces) == 0: + raise ValueError("No face detected in the image") + + # Get the largest face + face = max(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1])) + + # Return the 5 keypoints directly + keypoints = face.kps # 5 x 2 + + return keypoints + + +class FaceEncoderArcFace(): + """ Official ArcFace, no_grad-only """ + + def __repr__(self): + return "ArcFace" + + + def init_encoder_model(self, device, eval_mode=True): + self.device = device + self.encoder_model = init_recognition_model('arcface', device=device) + + if eval_mode: + self.encoder_model.eval() + + + @torch.no_grad() + def input_preprocessing(self, in_image, landmarks, image_size=112): + assert landmarks is not None, "landmarks are not provided!" + + in_image = np.array(in_image) + landmark = np.array(landmarks) + + face_aligned = face_align.norm_crop(in_image, landmark=landmark, image_size=image_size) + + image_transform = T.Compose([ + T.ToTensor(), + T.Normalize([0.5], [0.5]), + ]) + face_aligned = image_transform(face_aligned).unsqueeze(0).to(self.device) + + return face_aligned + + + @torch.no_grad() + def __call__(self, in_image, need_proc=False, landmarks=None, image_size=112): + + if need_proc: + in_image = self.input_preprocessing(in_image, landmarks, image_size) + else: + assert isinstance(in_image, torch.Tensor) + + in_image = in_image[:, [2, 1, 0], :, :].contiguous() + image_embeds = self.encoder_model(in_image) # [B, 512], normalized + + return image_embeds \ No newline at end of file diff --git a/preprocessing/arc/face_utils.py b/preprocessing/arc/face_utils.py new file mode 100644 index 000000000..41a454546 --- /dev/null +++ b/preprocessing/arc/face_utils.py @@ -0,0 +1,63 @@ +# Copyright (c) 2022 Insightface Team +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates + +# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025. +# SPDX-License-Identifier: Apache-2.0 + +# Original file (insightface) was released under MIT License: + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+ +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import numpy as np +import cv2 +from PIL import Image + +def estimate_norm(lmk, image_size=112, arcface_dst=None): + from skimage import transform as trans + assert lmk.shape == (5, 2) + assert image_size%112==0 or image_size%128==0 + if image_size%112==0: + ratio = float(image_size)/112.0 + diff_x = 0 + else: + ratio = float(image_size)/128.0 + diff_x = 8.0*ratio + dst = arcface_dst * ratio + dst[:,0] += diff_x + tform = trans.SimilarityTransform() + tform.estimate(lmk, dst) + M = tform.params[0:2, :] + return M + +def get_arcface_dst(extend_face_crop=False, extend_ratio=0.8): + arcface_dst = np.array( + [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], + [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32) + if extend_face_crop: + arcface_dst[:,1] = arcface_dst[:,1] + 10 + arcface_dst = (arcface_dst - 112/2) * extend_ratio + 112/2 + return arcface_dst + +def align_face(image_pil, face_kpts, extend_face_crop=False, extend_ratio=0.8, face_size=112): + arcface_dst = get_arcface_dst(extend_face_crop, extend_ratio) + M = estimate_norm(face_kpts, face_size, arcface_dst) + image_cv2 = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) + face_image_cv2 = cv2.warpAffine(image_cv2, M, (face_size, face_size), borderValue=0.0) + face_image = Image.fromarray(cv2.cvtColor(face_image_cv2, cv2.COLOR_BGR2RGB)) + return face_image \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d6f75f74b..e567f3ae4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,6 +35,8 @@ rembg[gpu]==2.0.65 onnxruntime-gpu decord timm +# insightface==0.7.3 +# facexlib==0.3.0 # Config & orchestration omegaconf diff --git a/wgp.py b/wgp.py index 6a257e814..9143689b8 100644 --- a/wgp.py +++ b/wgp.py @@ -63,8 +63,8 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.75" -settings_version = 2.36 +WanGP_version = "8.76" +settings_version = 2.37 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -3736,7 +3736,7 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr alpha_mask[mask > 127] = 1 frame = frame * alpha_mask frame = Image.fromarray(frame) - face = face_processor.process(frame, resize_to=size, face_crop_scale = 1) + face = face_processor.process(frame, resize_to=size, face_crop_scale = 1.2,) face_list.append(face) face_processor = None @@ -3863,7 +3863,7 @@ def prep_prephase(frame_idx): mask = op_expand(mask, kernel, iterations=3) _, mask = cv2.threshold(mask, 127.5, 255, cv2.THRESH_BINARY) - if to_bbox and np.sum(mask == 255) > 0: + if to_bbox and np.sum(mask == 255) > 0 : #or True x0, y0, x1, y1 = mask_to_xyxy_box(mask) mask = mask * 0 mask[y0:y1, x0:x1] = 255 @@ -4766,7 +4766,7 @@ def remove_temp_filenames(temp_filenames_list): raise Exception(f"unknown cache type {skip_steps_cache_type}") trans.cache = skip_steps_cache if trans2 is not None: trans2.cache = skip_steps_cache - + face_arc_embeds = None 
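The new `preprocessing/arc` helpers above are consumed further down in `wgp.py`, where a `face_arc_embeds` tensor is built from the last reference image. A minimal sketch of how they compose (illustrative only: the portrait path and device handling are assumptions, while the function names come from the code added above):

```python
from PIL import Image
import torch

from preprocessing.arc.face_encoder import FaceEncoderArcFace, get_landmarks_from_image
from preprocessing.arc.face_utils import align_face

device = "cuda" if torch.cuda.is_available() else "cpu"

image = Image.open("portrait.jpg").convert("RGB")       # assumed input image
kps = get_landmarks_from_image(image)                   # 5x2 keypoints of the largest detected face

# Optional: similarity-aligned 112x112 crop using the same 5-point ArcFace template
aligned_face = align_face(image, kps, extend_face_crop=False, face_size=112)

encoder = FaceEncoderArcFace()
encoder.init_encoder_model(device)
embeds = encoder(image, need_proc=True, landmarks=kps)  # [1, 512], L2-normalized ArcFace embedding
face_arc_embeds = embeds.squeeze(0).cpu()               # same layout that wgp.py stores for the pipeline
```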
output_new_audio_data = None output_new_audio_filepath = None original_audio_guide = audio_guide @@ -5039,9 +5039,9 @@ def remove_temp_filenames(temp_filenames_list): context_scale = [control_net_weight /2, control_net_weight2 /2] if preprocess_type2 is not None else [control_net_weight] if not (preprocess_type == "identity" and preprocess_type2 is None and video_mask is None):send_cmd("progress", [0, get_latest_status(state, status_info)]) inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask == "inpaint" else guide_inpaint_color - video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size ) + video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size, to_bbox = "H" in video_prompt_type ) if preprocess_type2 != None: - video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size ) + video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size, to_bbox = "H" in video_prompt_type ) if video_guide_processed is not None and sample_fit_canvas is not None: image_size = video_guide_processed.shape[-2:] @@ -5069,6 +5069,17 @@ def remove_temp_filenames(temp_filenames_list): if len(image_refs) > nb_frames_positions: src_ref_images = image_refs[nb_frames_positions:] + if "Q" in video_prompt_type: + from preprocessing.arc.face_encoder import FaceEncoderArcFace, get_landmarks_from_image + image_pil = src_ref_images[-1] + face_encoder = FaceEncoderArcFace() + face_encoder.init_encoder_model(processing_device) + face_arc_embeds = face_encoder(image_pil, need_proc=True, 
landmarks=get_landmarks_from_image(image_pil)) + face_arc_embeds = face_arc_embeds.squeeze(0).cpu() + face_encoder = image_pil = None + gc.collect() + torch.cuda.empty_cache() + if remove_background_images_ref > 0: send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) @@ -5079,7 +5090,6 @@ def remove_temp_filenames(temp_filenames_list): outpainting_dims =outpainting_dims, background_ref_outpainted = model_def.get("background_ref_outpainted", True), return_tensor= model_def.get("return_image_refs_tensor", False) ) - frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False): @@ -5241,6 +5251,7 @@ def set_header_text(txt): original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [], image_refs_relative_size = image_refs_relative_size, outpainting_dims = outpainting_dims, + face_arc_embeds = face_arc_embeds, ) except Exception as e: if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0: @@ -6980,7 +6991,7 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) return video_prompt_type -all_guide_processes ="PDESLCMUVB" +all_guide_processes ="PDESLCMUVBH" def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): model_type = state["model_type"] @@ -7006,7 +7017,34 @@ def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type, else: mask_selector_visible = True ref_images_visible = "I" in video_prompt_type - return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible), gr.update(visible = ref_images_visible ) + custom_options = custom_checkbox= False + if "#" in video_prompt_type: + custom_options = True + custom_video_choices = model_def.get("custom_video_selection", None) + if custom_video_choices is not None: + custom_checkbox = len(custom_video_choices.get("choices", [])) <= 2 + + return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible), gr.update(visible = ref_images_visible ), gr.update(visible= custom_options and not custom_checkbox ), gr.update(visible= custom_options and custom_checkbox ) + +def refresh_video_prompt_type_video_custom_dropbox(state, video_prompt_type, video_prompt_type_video_custom_dropbox): + 
model_type = state["model_type"] + model_def = get_model_def(model_type) + custom_video_choices = model_def.get("custom_video_selection", None) + if not "#" in video_prompt_type or custom_video_choices is None: return gr.update() + video_prompt_type = del_in_sequence(video_prompt_type, "0123456789") + video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_custom_dropbox) + return video_prompt_type + +def refresh_video_prompt_type_video_custom_checkbox(state, video_prompt_type, video_prompt_type_video_custom_checkbox): + model_type = state["model_type"] + model_def = get_model_def(model_type) + custom_video_choices = model_def.get("custom_video_selection", None) + if not "#" in video_prompt_type or custom_video_choices is None: return gr.update() + video_prompt_type = del_in_sequence(video_prompt_type, "0123456789") + if video_prompt_type_video_custom_checkbox: + video_prompt_type = add_to_sequence(video_prompt_type, custom_video_choices["choices"][1][1]) + return video_prompt_type + # def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): @@ -7474,7 +7512,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl else: image_prompt_type_radio = gr.Radio(choices=[("", "")], value="", visible= False) if "E" in image_prompt_types_allowed: - image_prompt_type_endcheckbox = gr.Checkbox( value ="E" in image_prompt_type_value, label="End Image(s)", show_label= False, visible= any_letters(image_prompt_type_radio_value, "SVL") and not image_outputs , scale= 1) + image_prompt_type_endcheckbox = gr.Checkbox( value ="E" in image_prompt_type_value, label="End Image(s)", show_label= False, visible= any_letters(image_prompt_type_radio_value, "SVL") and not image_outputs , scale= 1, elem_classes="cbx_centered") any_end_image = True else: image_prompt_type_endcheckbox = gr.Checkbox( value =False, show_label= False, visible= False , scale= 1) @@ -7539,14 +7577,14 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl video_prompt_type_video_guide = gr.Dropdown( guide_preprocessing_choices, value=filter_letters(video_prompt_type_value, all_guide_processes, guide_preprocessing.get("default", "") ), - label= video_prompt_type_video_guide_label , scale = 2, visible= guide_preprocessing.get("visible", True) , show_label= True, + label= video_prompt_type_video_guide_label , scale = 1, visible= guide_preprocessing.get("visible", True) , show_label= True, ) any_control_video = True any_control_image = image_outputs # Alternate Control Video Preprocessing / Options if guide_custom_choices is None: - video_prompt_type_video_guide_alt = gr.Dropdown(choices=[("","")], value="", label="Control Video", visible= False, scale = 2 ) + video_prompt_type_video_guide_alt = gr.Dropdown(choices=[("","")], value="", label="Control Video", visible= False, scale = 1 ) else: video_prompt_type_video_guide_alt_label = guide_custom_choices.get("label", "Control Video Process") if image_outputs: video_prompt_type_video_guide_alt_label = video_prompt_type_video_guide_alt_label.replace("Video", "Image") @@ -7557,14 +7595,37 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl # value=filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"], guide_custom_choices.get("default", "") ), value=guide_custom_choices_value, visible = guide_custom_choices.get("visible", True), - 
label= video_prompt_type_video_guide_alt_label, show_label= guide_custom_choices.get("show_label", True), scale = 2 + label= video_prompt_type_video_guide_alt_label, show_label= guide_custom_choices.get("show_label", True), scale = guide_custom_choices.get("scale", 1), ) any_control_video = True any_control_image = image_outputs + # Custom dropdown box & checkbox + custom_video_selection = model_def.get("custom_video_selection", None) + custom_checkbox= False + if custom_video_selection is None: + video_prompt_type_video_custom_dropbox = gr.Dropdown(choices=[("","")], value="", label="Custom Dropdown", scale = 1, visible= False, show_label= True, ) + video_prompt_type_video_custom_checkbox = gr.Checkbox(value=False, label="Custom Checkbbox", scale = 1, visible= False, show_label= True, ) + + else: + custom_choices = "#" in video_prompt_type_value + custom_video_choices = custom_video_selection["choices"] + custom_checkbox = len(custom_video_choices) <= 2 + + video_prompt_type_video_custom_label = custom_video_selection.get("label", "Custom Choices") + + video_prompt_type_video_custom_dropbox = gr.Dropdown( + custom_video_choices, + value=filter_letters(video_prompt_type_value, "0123456789", custom_video_selection.get("default", "")), + scale = custom_video_selection.get("scale", 1), + label= video_prompt_type_video_custom_label , visible= not custom_checkbox and custom_choices, + show_label= custom_video_selection.get("show_label", True), + ) + video_prompt_type_video_custom_checkbox = gr.Checkbox(value= custom_video_choices[1][1] in video_prompt_type_value , label=custom_video_choices[1][0] , scale = custom_video_selection.get("scale", 1), visible=custom_checkbox and custom_choices, show_label= True, elem_classes="cbx_centered" ) + # Control Mask Preprocessing if mask_preprocessing is None: - video_prompt_type_video_mask = gr.Dropdown(choices=[("","")], value="", label="Video Mask", scale = 2, visible= False, show_label= True, ) + video_prompt_type_video_mask = gr.Dropdown(choices=[("","")], value="", label="Video Mask", scale = 1, visible= False, show_label= True, ) any_image_mask = image_outputs else: mask_preprocessing_labels_all = { @@ -7592,7 +7653,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl video_prompt_type_video_mask = gr.Dropdown( mask_preprocessing_choices, value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")), - label= video_prompt_type_video_mask_label , scale = 2, visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and mask_preprocessing.get("visible", True), + label= video_prompt_type_video_mask_label , scale = 1, visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and mask_preprocessing.get("visible", True), show_label= True, ) @@ -7604,7 +7665,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl choices=[ ("None", ""),], value=filter_letters(video_prompt_type_value, ""), visible = False, - label="Start / Reference Images", scale = 2 + label="Start / Reference Images", scale = 1 ) any_reference_image = False else: @@ -7613,7 +7674,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl choices= image_ref_choices["choices"], value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]), visible = image_ref_choices.get("visible", True), - label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 2 + 
label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 1 ) image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None)) @@ -8287,7 +8348,8 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, image_prompt_type_group, image_prompt_type_radio, image_prompt_type_endcheckbox, prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, save_lset_prompt_drop, advanced_row, speed_tab, audio_tab, mmaudio_col, quality_tab, sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, audio_guide_row, RIFLEx_setting_col, - video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux, audio_prompt_type_remux_row, + video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, video_prompt_type_video_custom_dropbox, video_prompt_type_video_custom_checkbox, + apg_col, audio_prompt_type_sources, audio_prompt_type_remux, audio_prompt_type_remux_row, video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right, video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row, @@ -8316,9 +8378,12 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai image_prompt_type_radio.change(fn=refresh_image_prompt_type_radio, inputs=[state, image_prompt_type, image_prompt_type_radio], outputs=[image_prompt_type, image_start_row, image_end_row, video_source, keep_frames_video_source, image_prompt_type_endcheckbox], show_progress="hidden" ) image_prompt_type_endcheckbox.change(fn=refresh_image_prompt_type_endcheckbox, inputs=[state, image_prompt_type, image_prompt_type_radio, image_prompt_type_endcheckbox], outputs=[image_prompt_type, image_end_row] ) video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs,image_mode], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col], show_progress="hidden") - video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State(""), video_prompt_type, video_prompt_type_video_guide, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row], show_progress="hidden") - video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State("alt"),video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, 
image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row], show_progress="hidden") + video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State(""), video_prompt_type, video_prompt_type_video_guide, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row, video_prompt_type_video_custom_dropbox, video_prompt_type_video_custom_checkbox], show_progress="hidden") + video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State("alt"),video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row, video_prompt_type_video_custom_dropbox, video_prompt_type_video_custom_checkbox], show_progress="hidden") # video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, image_refs_row, denoising_strength, video_mask, mask_expand, image_mask_guide, image_guide, image_mask, keep_frames_video_guide ], show_progress="hidden") + video_prompt_type_video_custom_dropbox.input(fn= refresh_video_prompt_type_video_custom_dropbox, inputs=[state, video_prompt_type, video_prompt_type_video_custom_dropbox], outputs = video_prompt_type) + video_prompt_type_video_custom_checkbox.input(fn= refresh_video_prompt_type_video_custom_checkbox, inputs=[state, video_prompt_type, video_prompt_type_video_custom_checkbox], outputs = video_prompt_type) + video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_mask, image_mask_guide, image_guide, image_mask, mask_expand], show_progress="hidden") video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type]) multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt, image_end], show_progress="hidden") @@ -9587,6 +9652,7 @@ def create_ui(): transition: visibility 0s linear 1s, opacity 0.3s linear 1s; /* 1s delay before showing */ } .btn_centered {margin-top:10px; text-wrap-mode: nowrap;} + .cbx_centered label {margin-top:8px; text-wrap-mode: nowrap;} """ UI_theme = server_config.get("UI_theme", "default") UI_theme = args.theme if len(args.theme) > 0 else UI_theme From 5ad0cd4ac65406b5b1785e06bbbcfb21f1ca579a Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sun, 28 Sep 2025 18:20:57 +0200 Subject: [PATCH 049/155] updated release notes --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 
939cb7f04..72a7b0df8 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : -### September 25 2025: WanGP v8.73 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release +### September 25 2025: WanGP v8.76 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release So in ~~today's~~ this release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages: - **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*. @@ -29,6 +29,8 @@ In order to use Wan 2.2 Animate you will need first to stop by the *Mat Anyone* With version WanGP 8.74, there is an extra option that allows you to apply *Relighting* when Replacing a person. Also, you can now Animate a person without providing a Video Mask to target the source of the motion (with the risk it will be less precise) +For those of you who get a mask halo effect when Animating a character, I recommend trying *SDPA attention* and the *FusioniX i2v* lora. If the issue persists (this will depend on the control video), you now have a choice between the two *Animate Mask Options* in *WanGP 8.76*. The old masking option, which was a WanGP exclusive, has been renamed *See Through Mask* because it preserves the background behind the animated character, but this sometimes creates visual artifacts. The new option, which has the shorter name, is what you may find elsewhere online. As it internally uses a much larger mask, there is no halo; however, the immediate background behind the character is not preserved and may end up looking completely different. + - **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and asks it to change it (it is specialized in clothes changing) and voila ! The nice thing about it is that is it based on the *Wan 2.2 5B* model and therefore is very fast especially if you the *FastWan* finetune that is also part of the package.
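The difference between the two *Animate Mask Options* described above comes down to how the per-frame mask is post-processed before inpainting: the see-through mode keeps the character silhouette, while the new mode appears to grow the mask to the character's full bounding box (this matches the `to_bbox` branch added to `wgp.py` earlier in this series). A rough sketch of that expansion; the helper below is illustrative and only mirrors the behaviour of the patched code:

```python
import numpy as np

def expand_mask_to_bbox(mask: np.ndarray) -> np.ndarray:
    """Replace a character mask (255 = masked) by its filled bounding box.

    Keeping the raw silhouette lets the background show through but can leave
    edge halos; masking the whole rectangle removes the halo at the cost of
    regenerating the local background around the character.
    """
    ys, xs = np.where(mask == 255)
    if xs.size == 0:
        return mask                      # nothing to expand
    boxed = np.zeros_like(mask)
    boxed[ys.min():ys.max() + 1, xs.min():xs.max() + 1] = 255
    return boxed
```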
Also because I wanted to spoil you: @@ -41,6 +43,8 @@ Also because I wanted to spoil you: *Update 8.72*: shadow drop of Qwen Edit Plus *Update 8.73*: Qwen Preview & InfiniteTalk Start image *Update 8.74*: Animate Relighting / Nomask mode , t2v Masked Video to Video +*Update 8.75*: REDACTED +*Update 8.76*: Alternate Animate masking that fixes the mask halo effect that some users have ### September 15 2025: WanGP v8.6 - Attack of the Clones From 4a84bdecf7db9a02f68961d7c4799bf9d77b089f Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sun, 28 Sep 2025 18:41:44 +0200 Subject: [PATCH 050/155] oops --- models/wan/modules/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index 42af87737..5ed2e1724 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -1342,6 +1342,7 @@ def forward( y = None motion_vec_list = [] + if face_pixel_values is None: face_pixel_values = [None] * len(x_list) for i, (x, one_face_pixel_values) in enumerate(zip(x_list, face_pixel_values)): # animate embeddings motion_vec = None From 6a913c7ecf3f904edb3bc99928fb0eb42fc460ae Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sun, 28 Sep 2025 18:49:36 +0200 Subject: [PATCH 051/155] oops --- models/wan/modules/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index 5ed2e1724..c2419341e 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -1347,7 +1347,7 @@ def forward( # animate embeddings motion_vec = None if pose_latents is not None: - x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(one_face_pixel_values) if one_face_pixel_values is None else one_face_pixel_values) + x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(face_pixel_values[0]) if one_face_pixel_values is None else one_face_pixel_values) motion_vec_list.append(motion_vec) if chipmunk: x = x.unsqueeze(-1) From 6bbb60f9e8f25cc8fa28a649f6040968605c5adf Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sun, 28 Sep 2025 19:56:54 +0200 Subject: [PATCH 052/155] rollback cfg for animate --- models/wan/modules/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index c2419341e..07e3e54e9 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -1347,7 +1347,8 @@ def forward( # animate embeddings motion_vec = None if pose_latents is not None: - x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(face_pixel_values[0]) if one_face_pixel_values is None else one_face_pixel_values) + # x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(face_pixel_values[0]) if one_face_pixel_values is None else one_face_pixel_values) + x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values[0]) # if one_face_pixel_values is None else one_face_pixel_values) motion_vec_list.append(motion_vec) if chipmunk: x = x.unsqueeze(-1) From e7433804d1e61315034fdff7ba22a6a77bf9aaa4 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sun, 28 Sep 2025 21:47:58 +0200 Subject: [PATCH 053/155] fixed infinitetalk bug --- models/wan/any2video.py | 2 +- models/wan/wan_handler.py | 20 ++++++++++++++------ wgp.py | 23 ++++++++++++----------- 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/models/wan/any2video.py b/models/wan/any2video.py index b0fc24659..9f27fb30a 100644 --- 
a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -447,7 +447,7 @@ def generate(self, if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]: any_end_frame = False if infinitetalk: - new_shot = "Q" in video_prompt_type + new_shot = "0" in video_prompt_type if input_frames is not None: image_ref = input_frames[:, 0] else: diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index e1829f03a..d3dc31e09 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -147,15 +147,15 @@ def query_model_def(base_model_type, model_def): extra_model_def["all_image_refs_are_background_ref"] = True extra_model_def["guide_custom_choices"] = { "choices":[ - ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"), + ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "0KI"), ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"), - ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "QRUV"), + ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "0RUV"), ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"), - ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "GQUV"), + ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "G0UV"), ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"), ], "default": "KI", - "letters_filter": "RGUVQKI", + "letters_filter": "RGUV0KI", "label": "Video to Video", "show_label" : False, } @@ -190,6 +190,7 @@ def query_model_def(base_model_type, model_def): ("Apply Relighting", "1"), ], "label": "Custom Process", + "letters_filter": "1", "show_label" : False, "scale": 1, } @@ -427,7 +428,7 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): ui_defaults["audio_guidance_scale"]= 1 video_prompt_type = ui_defaults.get("video_prompt_type", "") if "I" in video_prompt_type: - video_prompt_type = video_prompt_type.replace("KI", "QKI") + video_prompt_type = video_prompt_type.replace("KI", "0KI") ui_defaults["video_prompt_type"] = video_prompt_type if settings_version < 2.28: @@ -461,6 +462,13 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): video_prompt_type = video_prompt_type.replace("1", "#1") ui_defaults["video_prompt_type"] = video_prompt_type + if settings_version < 2.38: + if base_model_type in ["infinitetalk"]: + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if "Q" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("Q", "0") + ui_defaults["video_prompt_type"] = video_prompt_type + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ @@ -491,7 +499,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "sliding_window_overlap" : 9, "sliding_window_size": 81, "sample_solver" : "euler", - "video_prompt_type": "QKI", + "video_prompt_type": "0KI", "remove_background_images_ref" : 0, "adaptive_switch" : 1, }) diff --git a/wgp.py b/wgp.py index 9143689b8..d03807b1f 100644 --- a/wgp.py +++ b/wgp.py 
@@ -63,8 +63,8 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.76" -settings_version = 2.37 +WanGP_version = "8.761" +settings_version = 2.38 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -7029,20 +7029,22 @@ def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type, def refresh_video_prompt_type_video_custom_dropbox(state, video_prompt_type, video_prompt_type_video_custom_dropbox): model_type = state["model_type"] model_def = get_model_def(model_type) - custom_video_choices = model_def.get("custom_video_selection", None) - if not "#" in video_prompt_type or custom_video_choices is None: return gr.update() - video_prompt_type = del_in_sequence(video_prompt_type, "0123456789") + custom_video_selection = model_def.get("custom_video_selection", None) + if not "#" in video_prompt_type or custom_video_selection is None: return gr.update() + letters_filter = custom_video_selection.get("letters_filter", "") + video_prompt_type = del_in_sequence(video_prompt_type, letters_filter) video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_custom_dropbox) return video_prompt_type def refresh_video_prompt_type_video_custom_checkbox(state, video_prompt_type, video_prompt_type_video_custom_checkbox): model_type = state["model_type"] model_def = get_model_def(model_type) - custom_video_choices = model_def.get("custom_video_selection", None) - if not "#" in video_prompt_type or custom_video_choices is None: return gr.update() - video_prompt_type = del_in_sequence(video_prompt_type, "0123456789") + custom_video_selection = model_def.get("custom_video_selection", None) + if not "#" in video_prompt_type or custom_video_selection is None: return gr.update() + letters_filter = custom_video_selection.get("letters_filter", "") + video_prompt_type = del_in_sequence(video_prompt_type, letters_filter) if video_prompt_type_video_custom_checkbox: - video_prompt_type = add_to_sequence(video_prompt_type, custom_video_choices["choices"][1][1]) + video_prompt_type = add_to_sequence(video_prompt_type, custom_video_selection["choices"][1][1]) return video_prompt_type @@ -7613,10 +7615,9 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl custom_checkbox = len(custom_video_choices) <= 2 video_prompt_type_video_custom_label = custom_video_selection.get("label", "Custom Choices") - video_prompt_type_video_custom_dropbox = gr.Dropdown( custom_video_choices, - value=filter_letters(video_prompt_type_value, "0123456789", custom_video_selection.get("default", "")), + value=filter_letters(video_prompt_type_value, custom_video_selection.get("letters_filter", ""), custom_video_selection.get("default", "")), scale = custom_video_selection.get("scale", 1), label= video_prompt_type_video_custom_label , visible= not custom_checkbox and custom_choices, show_label= custom_video_selection.get("show_label", True), From f77ba0317f0163e925851e9f19e23da7ef6d783b Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sun, 28 Sep 2025 22:44:17 +0200 Subject: [PATCH 054/155] UI improvements --- README.md | 2 +- models/wan/any2video.py | 2 +- models/wan/wan_handler.py | 27 ++++++++++++++++++++------- wgp.py | 22 ++++++++++++---------- 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 72a7b0df8..ee432d770 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ WanGP 
supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : -### September 25 2025: WanGP v8.76 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release +### September 28 2025: WanGP v8.76 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release So in ~~today's~~ this release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages: - **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model to either *Replace* a person in an in Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also as a WanGP exclusivity, you will find support for *Outpainting*. diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 9f27fb30a..c9b1a29a4 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -1058,7 +1058,7 @@ def adapt_animate_model(self, model): def get_loras_transformer(self, get_model_recursive_prop, base_model_type, model_type, video_prompt_type, model_mode, **kwargs): if base_model_type == "animate": - if "1" in video_prompt_type: + if "#" in video_prompt_type and "1" in video_prompt_type: preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") if len(preloadURLs) > 0: return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1] diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index d3dc31e09..b1f594271 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -147,19 +147,30 @@ def query_model_def(base_model_type, model_def): extra_model_def["all_image_refs_are_background_ref"] = True extra_model_def["guide_custom_choices"] = { "choices":[ - ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "0KI"), - ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"), - ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", "0RUV"), - ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "RUV"), - ("Video to Video, amount of motion transferred depends on Denoising Strength - Sharp Transitions", "G0UV"), - ("Video to Video, amount of motion transferred depends on Denoising Strength - Smooth Transitions", "GUV"), + ("Images to Video, each Reference Image will start a new shot with a new Sliding Window", "KI"), + ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window", "RUV"), + ("Video to Video, amount of motion transferred depends on Denoising Strength", "GUV"), ], "default": "KI", - "letters_filter": "RGUV0KI", + "letters_filter": "RGUVKI", "label": "Video to Video", + "scale": 3, + "show_label" : False, + } + + extra_model_def["custom_video_selection"] = { + "choices":[ + ("Smooth Transitions", ""), + ("Sharp Transitions", "0"), + ], + "trigger": "", + "label": "Custom Process", + "letters_filter": "0", "show_label" : False, + "scale": 1, } + # 
extra_model_def["at_least_one_image_ref_needed"] = True if base_model_type in ["lucy_edit"]: extra_model_def["keep_frames_video_guide_not_supported"] = True @@ -189,7 +200,9 @@ def query_model_def(base_model_type, model_def): ("None", ""), ("Apply Relighting", "1"), ], + "trigger": "#", "label": "Custom Process", + "type": "checkbox", "letters_filter": "1", "show_label" : False, "scale": 1, diff --git a/wgp.py b/wgp.py index d03807b1f..e48070ca3 100644 --- a/wgp.py +++ b/wgp.py @@ -7017,12 +7017,13 @@ def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type, else: mask_selector_visible = True ref_images_visible = "I" in video_prompt_type - custom_options = custom_checkbox= False - if "#" in video_prompt_type: - custom_options = True - custom_video_choices = model_def.get("custom_video_selection", None) - if custom_video_choices is not None: - custom_checkbox = len(custom_video_choices.get("choices", [])) <= 2 + custom_options = custom_checkbox = False + custom_video_selection = model_def.get("custom_video_selection", None) + if custom_video_selection is not None: + custom_trigger = custom_video_selection.get("trigger","") + if len(custom_trigger) == 0 or custom_trigger in video_prompt_type: + custom_options = True + custom_checkbox = custom_video_selection.get("type","") == "checkbox" return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible), gr.update(visible = ref_images_visible ), gr.update(visible= custom_options and not custom_checkbox ), gr.update(visible= custom_options and custom_checkbox ) @@ -7030,7 +7031,7 @@ def refresh_video_prompt_type_video_custom_dropbox(state, video_prompt_type, vid model_type = state["model_type"] model_def = get_model_def(model_type) custom_video_selection = model_def.get("custom_video_selection", None) - if not "#" in video_prompt_type or custom_video_selection is None: return gr.update() + if custom_video_selection is None: return gr.update() letters_filter = custom_video_selection.get("letters_filter", "") video_prompt_type = del_in_sequence(video_prompt_type, letters_filter) video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_custom_dropbox) @@ -7040,7 +7041,7 @@ def refresh_video_prompt_type_video_custom_checkbox(state, video_prompt_type, vi model_type = state["model_type"] model_def = get_model_def(model_type) custom_video_selection = model_def.get("custom_video_selection", None) - if not "#" in video_prompt_type or custom_video_selection is None: return gr.update() + if custom_video_selection is None: return gr.update() letters_filter = custom_video_selection.get("letters_filter", "") video_prompt_type = del_in_sequence(video_prompt_type, letters_filter) if video_prompt_type_video_custom_checkbox: @@ -7610,9 +7611,10 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl video_prompt_type_video_custom_checkbox = gr.Checkbox(value=False, label="Custom Checkbbox", scale = 1, visible= False, show_label= True, ) else: - custom_choices = "#" in video_prompt_type_value custom_video_choices = 
custom_video_selection["choices"] - custom_checkbox = len(custom_video_choices) <= 2 + custom_video_trigger = custom_video_selection.get("trigger", "") + custom_choices = len(custom_video_trigger) == 0 or custom_video_trigger in video_prompt_type_value + custom_checkbox = custom_video_selection.get("type","") == "checkbox" video_prompt_type_video_custom_label = custom_video_selection.get("label", "Custom Choices") video_prompt_type_video_custom_dropbox = gr.Dropdown( From a1aa82289bde197d377e76f96a421d2c76997009 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sun, 28 Sep 2025 23:35:15 +0200 Subject: [PATCH 055/155] pinned onnx version --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e567f3ae4..354d465a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,7 +32,7 @@ loguru opencv-python>=4.12.0.88 segment-anything rembg[gpu]==2.0.65 -onnxruntime-gpu +onnxruntime-gpu==1.22 decord timm # insightface==0.7.3 From 12e81fa84051cc9680676a378b3eeb062bf4d414 Mon Sep 17 00:00:00 2001 From: SandSeppel <65561725+SandSeppel@users.noreply.github.com> Date: Mon, 29 Sep 2025 13:06:05 +0200 Subject: [PATCH 056/155] Added conda activate in README update --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index ee432d770..9ad477a8a 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,7 @@ If using Pinokio use Pinokio to update otherwise: Get in the directory where WanGP is installed and: ```bash git pull +conda activate wan2gp pip install -r requirements.txt ``` From 73b22f8a04d1d4df27e398d2b9931279737b644e Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 30 Sep 2025 10:51:25 +0200 Subject: [PATCH 057/155] combinatorics --- README.md | 9 + configs/lynx.json | 15 + configs/vace_lynx_14B.json | 17 ++ defaults/lynx.json | 18 ++ defaults/vace_lynx_14B.json | 9 + models/wan/any2video.py | 98 ++++-- models/wan/lynx/attention_processor.py | 41 +++ models/wan/lynx/flash_attention.py | 60 ---- models/wan/lynx/full_attention_processor.py | 305 ------------------- models/wan/lynx/inference_utils.py | 39 --- models/wan/lynx/light_attention_processor.py | 194 ------------ models/wan/lynx/model_utils_wan.py | 264 ---------------- models/wan/modules/model.py | 83 +++-- models/wan/wan_handler.py | 70 +++-- requirements.txt | 5 +- shared/utils/utils.py | 2 +- wgp.py | 56 ++-- 17 files changed, 325 insertions(+), 960 deletions(-) create mode 100644 configs/lynx.json create mode 100644 configs/vace_lynx_14B.json create mode 100644 defaults/lynx.json create mode 100644 defaults/vace_lynx_14B.json create mode 100644 models/wan/lynx/attention_processor.py delete mode 100644 models/wan/lynx/flash_attention.py delete mode 100644 models/wan/lynx/full_attention_processor.py delete mode 100644 models/wan/lynx/inference_utils.py delete mode 100644 models/wan/lynx/light_attention_processor.py delete mode 100644 models/wan/lynx/model_utils_wan.py diff --git a/README.md b/README.md index ee432d770..6abdaf5c2 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,15 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : +### September 30 2025: WanGP v8.9 - Combinatorics + +This new version of WanGP introduces **Wan 2.1 Lynx** the best Control Net so far to transfer *Facial Identity*. You will be amazed to recognize your friends even with a completely different hair style. 
Congrats to the *Byte Dance team* for this achievement. Lynx works quite with well *Fusionix t2v* 10 steps. + +*WanGP 8.9* also illustrate how existing WanGP features can be easily combined with new models. For instance with *Lynx* you will get out of the box *Video to Video* and *Image/Text to Image*. + +Another fun combination is *Vace* + *Lynx*, which works much better than *Vace StandIn*. I have added sliders to change the weight of Vace & Lynx to allow you to tune the effects. + + ### September 28 2025: WanGP v8.76 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release So in ~~today's~~ this release you will find two Wannabe Vace that covers each only a subset of Vace features but offers some interesting advantages: diff --git a/configs/lynx.json b/configs/lynx.json new file mode 100644 index 000000000..a973d5e68 --- /dev/null +++ b/configs/lynx.json @@ -0,0 +1,15 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 16, + "model_type": "t2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "lynx": "full" +} diff --git a/configs/vace_lynx_14B.json b/configs/vace_lynx_14B.json new file mode 100644 index 000000000..770cbeef6 --- /dev/null +++ b/configs/vace_lynx_14B.json @@ -0,0 +1,17 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 16, + "model_type": "t2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35], + "vace_in_dim": 96, + "lynx": "full" +} diff --git a/defaults/lynx.json b/defaults/lynx.json new file mode 100644 index 000000000..4c0a17f1e --- /dev/null +++ b/defaults/lynx.json @@ -0,0 +1,18 @@ +{ + "model": { + "name": "Wan2.1 Lynx 14B", + "modules": [ + [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_lynx_full_module_14B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_lynx_full_module_14B_quanto_bf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_lynx_full_module_14B_quanto_fp16_int8.safetensors" + ] + ], + "architecture": "lynx", + "description": "The Lynx ControlNet offers State of the Art Identity Preservation. You need to provide a Reference Image which is a close up of a person face to transfer this person in the Video.", + "URLs": "t2v", + "preload_URLs": [ + "wan2.1_lynx_full_arc_resampler.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/vace_lynx_14B.json b/defaults/vace_lynx_14B.json new file mode 100644 index 000000000..ba8bfbd7b --- /dev/null +++ b/defaults/vace_lynx_14B.json @@ -0,0 +1,9 @@ +{ + "model": { + "name": "Vace Lynx 14B", + "architecture": "vace_lynx_14B", + "modules": [ "vace_14B", "lynx"], + "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video. 
The Lynx version is specialized in transferring a Person Identity.", + "URLs": "t2v" + } +} \ No newline at end of file diff --git a/models/wan/any2video.py b/models/wan/any2video.py index c9b1a29a4..5beea8e74 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -286,7 +286,32 @@ def get_i2v_mask(self, lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=No msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) msk = msk.transpose(1,2)[0] return msk - + + def encode_reference_images(self, ref_images, ref_prompt="image of a face", any_guidance= False, tile_size = None): + ref_images = [convert_image_to_tensor(img).unsqueeze(1).to(device=self.device, dtype=self.dtype) for img in ref_images] + shape = ref_images[0].shape + freqs = get_rotary_pos_embed( (len(ref_images) , shape[-2] // 8, shape[-1] // 8 )) + # batch_ref_image: [B, C, F, H, W] + vae_feat = self.vae.encode(ref_images, tile_size = tile_size) + vae_feat = torch.cat( vae_feat, dim=1).unsqueeze(0) + if any_guidance: + vae_feat_uncond = self.vae.encode([ref_images[0] * 0], tile_size = tile_size) * len(ref_images) + vae_feat_uncond = torch.cat( vae_feat_uncond, dim=1).unsqueeze(0) + context = self.text_encoder([ref_prompt], self.device)[0].to(self.dtype) + context = torch.cat([context, context.new_zeros(self.model.text_len -context.size(0), context.size(1)) ]).unsqueeze(0) + clear_caches() + get_cache("lynx_ref_buffer").update({ 0: {}, 1: {} }) + ref_buffer = self.model( + pipeline =self, + x = [vae_feat, vae_feat_uncond] if any_guidance else [vae_feat], + context = [context, context] if any_guidance else [context], + freqs= freqs, + t=torch.stack([torch.tensor(0, dtype=torch.float)]).to(self.device), + lynx_feature_extractor = True, + ) + clear_caches() + return ref_buffer[0], (ref_buffer[1] if any_guidance else None) + def generate(self, input_prompt, input_frames= None, @@ -355,6 +380,7 @@ def generate(self, video_prompt_type= "", original_input_ref_images = [], face_arc_embeds = None, + control_scale_alt = 1., **bbargs ): @@ -401,13 +427,15 @@ def generate(self, # Text Encoder if n_prompt == "": n_prompt = self.sample_neg_prompt - context = self.text_encoder([input_prompt], self.device)[0] - context_null = self.text_encoder([n_prompt], self.device)[0] - context = context.to(self.dtype) - context_null = context_null.to(self.dtype) text_len = self.model.text_len - context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0) - context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0) + any_guidance_at_all = guide_scale > 1 or guide2_scale > 1 and guide_phases >=2 or guide3_scale > 1 and guide_phases >=3 + context = self.text_encoder([input_prompt], self.device)[0].to(self.dtype) + context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0) + if NAG_scale > 1 or any_guidance_at_all: + context_null = self.text_encoder([n_prompt], self.device)[0].to(self.dtype) + context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0) + else: + context_null = None if input_video is not None: height, width = input_video.shape[-2:] # NAG_prompt = "static, low resolution, blurry" @@ -423,13 +451,13 @@ def generate(self, # if NAG_scale > 1: context = torch.cat([context, context_NAG], dim=0) if self._interrupt: return None - vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B", 
"vace_standin_14B"] + vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B"] phantom = model_type in ["phantom_1.3B", "phantom_14B"] fantasy = model_type in ["fantasy"] multitalk = model_type in ["multitalk", "infinitetalk", "vace_multitalk_14B", "i2v_2_2_multitalk"] infinitetalk = model_type in ["infinitetalk"] standin = model_type in ["standin", "vace_standin_14B"] - lynx = model_type in ["lynx_lite", "lynx_full", "vace_lynx_lite_14B", "vace_lynx_full_14B"] + lynx = model_type in ["lynx_lite", "lynx", "vace_lynx_lite_14B", "vace_lynx_14B"] recam = model_type in ["recam_1.3B"] ti2v = model_type in ["ti2v_2_2", "lucy_edit"] lucy_edit= model_type in ["lucy_edit"] @@ -651,7 +679,9 @@ def generate(self, if vace : # vace context encode input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)]) - input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)]) + input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)]) + if model_type in ["vace_lynx_14B"]: + input_ref_images = input_ref_masks = None input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images] input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks] ref_images_before = True @@ -704,21 +734,36 @@ def generate(self, kwargs["freqs"] = freqs # Lynx - if lynx: - from .lynx.resampler import Resampler - from accelerate import init_empty_weights - lynx_lite = model_type in ["lynx_lite", "vace_lynx_lite_14B"] - with init_empty_weights(): - arc_resampler = Resampler( depth=4, dim=1280, dim_head=64, embedding_dim=512, ff_mult=4, heads=20, num_queries=16, output_dim=2048 if lynx_lite else 5120 ) - offload.load_model_data(arc_resampler, os.path.join("ckpts", "wan2.1_lynx_lite_arc_resampler.safetensors" if lynx_lite else "wan2.1_lynx_full_arc_resampler.safetensors")) - arc_resampler.to(self.device) - arcface_embed = face_arc_embeds[None,None,:].to(device=self.device, dtype=torch.float) - ip_hidden_states = arc_resampler(arcface_embed).to(self.dtype) - ip_hidden_states_uncond = arc_resampler(torch.zeros_like(arcface_embed)).to(self.dtype) - arc_resampler = None - gc.collect() - torch.cuda.empty_cache() - kwargs["lynx_ip_scale"] = 1. 
+ if lynx : + if original_input_ref_images is None or len(original_input_ref_images) == 0: + lynx = False + else: + from .lynx.resampler import Resampler + from accelerate import init_empty_weights + lynx_lite = model_type in ["lynx_lite", "vace_lynx_lite_14B"] + ip_hidden_states = ip_hidden_states_uncond = None + if True: + with init_empty_weights(): + arc_resampler = Resampler( depth=4, dim=1280, dim_head=64, embedding_dim=512, ff_mult=4, heads=20, num_queries=16, output_dim=2048 if lynx_lite else 5120 ) + offload.load_model_data(arc_resampler, os.path.join("ckpts", "wan2.1_lynx_lite_arc_resampler.safetensors" if lynx_lite else "wan2.1_lynx_full_arc_resampler.safetensors")) + arc_resampler.to(self.device) + arcface_embed = face_arc_embeds[None,None,:].to(device=self.device, dtype=torch.float) + ip_hidden_states = arc_resampler(arcface_embed).to(self.dtype) + ip_hidden_states_uncond = arc_resampler(torch.zeros_like(arcface_embed)).to(self.dtype) + arc_resampler = None + if not lynx_lite: + standin_ref_pos = -1 + image_ref = original_input_ref_images[standin_ref_pos] + from preprocessing.face_preprocessor import FaceProcessor + face_processor = FaceProcessor() + lynx_ref = face_processor.process(image_ref, resize_to = 256 ) + lynx_ref_buffer, lynx_ref_buffer_uncond = self.encode_reference_images([lynx_ref], tile_size=VAE_tile_size, any_guidance= any_guidance_at_all) + lynx_ref = None + gc.collect() + torch.cuda.empty_cache() + vace_lynx = model_type in ["vace_lynx_14B"] + kwargs["lynx_ip_scale"] = control_scale_alt + kwargs["lynx_ref_scale"] = control_scale_alt #Standin if standin: @@ -881,6 +926,9 @@ def clear(): "context" : [context, context_null], "lynx_ip_embeds": [ip_hidden_states, ip_hidden_states_uncond] } + if model_type in ["lynx", "vace_lynx_14B"]: + gen_args["lynx_ref_buffer"] = [lynx_ref_buffer, lynx_ref_buffer_uncond] + elif multitalk and audio_proj != None: if guide_scale == 1: gen_args = { diff --git a/models/wan/lynx/attention_processor.py b/models/wan/lynx/attention_processor.py new file mode 100644 index 000000000..c64291d1d --- /dev/null +++ b/models/wan/lynx/attention_processor.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025 The Wan Team and The HuggingFace Team. +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# +# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025. +# +# Original file was released under Apache License 2.0, with the full license text +# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt. +# +# This modified file is released under the same license. 
+ + +from typing import Optional, List +import torch +import torch.nn as nn +from diffusers.models.normalization import RMSNorm + +def setup_lynx_attention_layers(blocks, lynx_full, dim): + if lynx_full: + lynx_cross_dim = 5120 + lynx_layers = len(blocks) + else: + lynx_cross_dim = 2048 + lynx_layers = 20 + for i, block in enumerate(blocks): + if i < lynx_layers: + block.cross_attn.to_k_ip = nn.Linear(lynx_cross_dim, dim , bias=lynx_full) + block.cross_attn.to_v_ip = nn.Linear(lynx_cross_dim, dim , bias=lynx_full) + else: + block.cross_attn.to_k_ip = None + block.cross_attn.to_v_ip = None + if lynx_full: + block.cross_attn.registers = nn.Parameter(torch.randn(1, 16, lynx_cross_dim) / dim**0.5) + block.cross_attn.norm_rms_k = None + block.self_attn.to_k_ref = nn.Linear(dim, dim, bias=True) + block.self_attn.to_v_ref = nn.Linear(dim, dim, bias=True) + else: + block.cross_attn.registers = None + block.cross_attn.norm_rms_k = RMSNorm(dim, eps=1e-5, elementwise_affine=False) + + + diff --git a/models/wan/lynx/flash_attention.py b/models/wan/lynx/flash_attention.py deleted file mode 100644 index 3bb807db8..000000000 --- a/models/wan/lynx/flash_attention.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# SPDX-License-Identifier: Apache-2.0 - -import torch -try: - from flash_attn_interface import flash_attn_varlen_func # flash attn 3 -except: - from flash_attn.flash_attn_interface import flash_attn_varlen_func # flash attn 2 - -def flash_attention(query, key, value, q_lens, kv_lens, causal=False): - """ - Args: - query, key, value: [B, H, T, D_h] - q_lens: list[int] per sequence query length - kv_lens: list[int] per sequence key/value length - causal: whether to use causal mask - - Returns: - output: [B, H, T_q, D_h] - """ - B, H, T_q, D_h = query.shape - T_k = key.shape[2] - device = query.device - - # Flatten: [B, H, T, D] -> [total_tokens, H, D] - q = query.permute(0, 2, 1, 3).reshape(B * T_q, H, D_h) - k = key.permute(0, 2, 1, 3).reshape(B * T_k, H, D_h) - v = value.permute(0, 2, 1, 3).reshape(B * T_k, H, D_h) - - # Prepare cu_seqlens: prefix sum - q_lens_tensor = torch.tensor(q_lens, device=device, dtype=torch.int32) - kv_lens_tensor = torch.tensor(kv_lens, device=device, dtype=torch.int32) - - cu_seqlens_q = torch.zeros(len(q_lens_tensor) + 1, device=device, dtype=torch.int32) - cu_seqlens_k = torch.zeros(len(kv_lens_tensor) + 1, device=device, dtype=torch.int32) - - cu_seqlens_q[1:] = torch.cumsum(q_lens_tensor, dim=0) - cu_seqlens_k[1:] = torch.cumsum(kv_lens_tensor, dim=0) - - max_seqlen_q = int(q_lens_tensor.max().item()) - max_seqlen_k = int(kv_lens_tensor.max().item()) - - # Call FlashAttention varlen kernel - out = flash_attn_varlen_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, - causal=causal - ) - if not torch.is_tensor(out): # flash attn 3 - out = out[0] - - # Restore shape: [total_q, H, D_h] -> [B, H, T_q, D_h] - out = out.view(B, T_q, H, D_h).permute(0, 2, 1, 3).contiguous() - - return out \ No newline at end of file diff --git a/models/wan/lynx/full_attention_processor.py b/models/wan/lynx/full_attention_processor.py deleted file mode 100644 index fd0edb7a8..000000000 --- a/models/wan/lynx/full_attention_processor.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright (c) 2025 The Wan Team and The HuggingFace Team. -# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# -# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025. 
-# -# Original file was released under Apache License 2.0, with the full license text -# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt. -# -# This modified file is released under the same license. - - -from typing import Optional, List -import torch -import torch.nn as nn -from diffusers.models.attention_processor import Attention -from modules.common.flash_attention import flash_attention -from modules.common.navit_utils import vector_to_list, list_to_vector, merge_token_lists - -def new_forward( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - **cross_attention_kwargs, -) -> torch.Tensor: - return self.processor( - self, - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) -Attention.forward = new_forward - -class WanIPAttnProcessor(nn.Module): - def __init__( - self, - cross_attention_dim: int, - dim: int, - n_registers: int, - bias: bool, - ): - super().__init__() - self.to_k_ip = nn.Linear(cross_attention_dim, dim, bias=bias) - self.to_v_ip = nn.Linear(cross_attention_dim, dim, bias=bias) - if n_registers > 0: - self.registers = nn.Parameter(torch.randn(1, n_registers, cross_attention_dim) / dim**0.5) - else: - self.registers = None - - def forward( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - ip_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[torch.Tensor] = None, - q_lens: List[int] = None, - kv_lens: List[int] = None, - ip_lens: List[int] = None, - ip_scale: float = 1.0, - ): - encoder_hidden_states_img = None - if attn.add_k_proj is not None: - encoder_hidden_states_img = encoder_hidden_states[:, :257] - encoder_hidden_states = encoder_hidden_states[:, 257:] - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - if ip_hidden_states is not None: - - if self.registers is not None: - ip_hidden_states_list = vector_to_list(ip_hidden_states, ip_lens, 1) - ip_hidden_states_list = merge_token_lists(ip_hidden_states_list, [self.registers] * len(ip_hidden_states_list), 1) - ip_hidden_states, ip_lens = list_to_vector(ip_hidden_states_list, 1) - - ip_query = query - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - if attn.norm_q is not None: - ip_query = attn.norm_q(ip_query) - if attn.norm_k is not None: - ip_key = attn.norm_k(ip_key) - ip_query = ip_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) - ip_key = ip_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) - ip_value = ip_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) - ip_hidden_states = flash_attention( - ip_query, ip_key, ip_value, - q_lens=q_lens, - kv_lens=ip_lens - ) - ip_hidden_states = ip_hidden_states.transpose(1, 2).flatten(2, 3) - ip_hidden_states = ip_hidden_states.type_as(ip_query) - - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) - key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) - - if rotary_emb is not None: - - def 
apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): - x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) - x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) - return x_out.type_as(hidden_states) - - query = apply_rotary_emb(query, rotary_emb) - key = apply_rotary_emb(key, rotary_emb) - - hidden_states = flash_attention( - query, key, value, - q_lens=q_lens, - kv_lens=kv_lens, - ) - - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - hidden_states = hidden_states.type_as(query) - - if ip_hidden_states is not None: - hidden_states = hidden_states + ip_scale * ip_hidden_states - - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - -def register_ip_adapter( - model, - cross_attention_dim=None, - n_registers=0, - init_method="zero", - dtype=torch.float32, -): - attn_procs = {} - transformer_sd = model.state_dict() - for layer_idx, block in enumerate(model.blocks): - name = f"blocks.{layer_idx}.attn2.processor" - layer_name = name.split(".processor")[0] - dim = transformer_sd[layer_name + ".to_k.weight"].shape[1] - attn_procs[name] = WanIPAttnProcessor( - cross_attention_dim=dim if cross_attention_dim is None else cross_attention_dim, - dim=dim, - n_registers=n_registers, - bias=True, - ) - if init_method == "zero": - torch.nn.init.zeros_(attn_procs[name].to_k_ip.weight) - torch.nn.init.zeros_(attn_procs[name].to_k_ip.bias) - torch.nn.init.zeros_(attn_procs[name].to_v_ip.weight) - torch.nn.init.zeros_(attn_procs[name].to_v_ip.bias) - elif init_method == "clone": - layer_name = name.split(".processor")[0] - weights = { - "to_k_ip.weight": transformer_sd[layer_name + ".to_k.weight"], - "to_k_ip.bias": transformer_sd[layer_name + ".to_k.bias"], - "to_v_ip.weight": transformer_sd[layer_name + ".to_v.weight"], - "to_v_ip.bias": transformer_sd[layer_name + ".to_v.bias"], - } - attn_procs[name].load_state_dict(weights) - elif init_method == "random": - pass - else: - raise ValueError(f"{init_method} is not supported.") - block.attn2.processor = attn_procs[name] - ip_layers = torch.nn.ModuleList(attn_procs.values()) - ip_layers.to(device=model.device, dtype=dtype) - return model, ip_layers - - -class WanRefAttnProcessor(nn.Module): - def __init__( - self, - dim: int, - bias: bool, - ): - super().__init__() - self.to_k_ref = nn.Linear(dim, dim, bias=bias) - self.to_v_ref = nn.Linear(dim, dim, bias=bias) - - def forward( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[torch.Tensor] = None, - q_lens: List[int] = None, - kv_lens: List[int] = None, - ref_feature: Optional[tuple] = None, - ref_scale: float = 1.0, - ): - encoder_hidden_states_img = None - if attn.add_k_proj is not None: - encoder_hidden_states_img = encoder_hidden_states[:, :257] - encoder_hidden_states = encoder_hidden_states[:, 257:] - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - if ref_feature is None: - ref_hidden_states = None - else: - ref_hidden_states, ref_lens = ref_feature - ref_query = query - ref_key = self.to_k_ref(ref_hidden_states) - ref_value = self.to_v_ref(ref_hidden_states) - if attn.norm_q is not None: - ref_query = attn.norm_q(ref_query) - if attn.norm_k is not None: - ref_key = 
attn.norm_k(ref_key) - ref_query = ref_query.unflatten(2, (attn.heads, -1)).transpose(1, 2) - ref_key = ref_key.unflatten(2, (attn.heads, -1)).transpose(1, 2) - ref_value = ref_value.unflatten(2, (attn.heads, -1)).transpose(1, 2) - ref_hidden_states = flash_attention( - ref_query, ref_key, ref_value, - q_lens=q_lens, - kv_lens=ref_lens - ) - ref_hidden_states = ref_hidden_states.transpose(1, 2).flatten(2, 3) - ref_hidden_states = ref_hidden_states.type_as(ref_query) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) - key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) - - if rotary_emb is not None: - - def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): - x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) - x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) - return x_out.type_as(hidden_states) - - query = apply_rotary_emb(query, rotary_emb) - key = apply_rotary_emb(key, rotary_emb) - - hidden_states = flash_attention( - query, key, value, - q_lens=q_lens, - kv_lens=kv_lens, - ) - - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - hidden_states = hidden_states.type_as(query) - - if ref_hidden_states is not None: - hidden_states = hidden_states + ref_scale * ref_hidden_states - - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - return hidden_states - - -def register_ref_adapter( - model, - init_method="zero", - dtype=torch.float32, -): - attn_procs = {} - transformer_sd = model.state_dict() - for layer_idx, block in enumerate(model.blocks): - name = f"blocks.{layer_idx}.attn1.processor" - layer_name = name.split(".processor")[0] - dim = transformer_sd[layer_name + ".to_k.weight"].shape[0] - attn_procs[name] = WanRefAttnProcessor( - dim=dim, - bias=True, - ) - if init_method == "zero": - torch.nn.init.zeros_(attn_procs[name].to_k_ref.weight) - torch.nn.init.zeros_(attn_procs[name].to_k_ref.bias) - torch.nn.init.zeros_(attn_procs[name].to_v_ref.weight) - torch.nn.init.zeros_(attn_procs[name].to_v_ref.bias) - elif init_method == "clone": - weights = { - "to_k_ref.weight": transformer_sd[layer_name + ".to_k.weight"], - "to_k_ref.bias": transformer_sd[layer_name + ".to_k.bias"], - "to_v_ref.weight": transformer_sd[layer_name + ".to_v.weight"], - "to_v_ref.bias": transformer_sd[layer_name + ".to_v.bias"], - } - attn_procs[name].load_state_dict(weights) - elif init_method == "random": - pass - else: - raise ValueError(f"{init_method} is not supported.") - block.attn1.processor = attn_procs[name] - ref_layers = torch.nn.ModuleList(attn_procs.values()) - ref_layers.to(device=model.device, dtype=dtype) - return model, ref_layers diff --git a/models/wan/lynx/inference_utils.py b/models/wan/lynx/inference_utils.py deleted file mode 100644 index c1129f3bf..000000000 --- a/models/wan/lynx/inference_utils.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates -# SPDX-License-Identifier: Apache-2.0 - -from typing import Optional, Union - -import torch -import numpy as np - -from PIL import Image -from dataclasses import dataclass, field - - -dtype_mapping = { - "bf16": torch.bfloat16, - "fp16": torch.float16, - "fp32": torch.float32 -} - - -@dataclass -class SubjectInfo: - name: str = "" - image_pil: Optional[Image.Image] = None - landmarks: Optional[Union[np.ndarray, torch.Tensor]] = None - face_embeds: Optional[Union[np.ndarray, torch.Tensor]] = None - - -@dataclass -class VideoStyleInfo: # key names should match those used in style.yaml file - style_name: str = 'none' - num_frames: int = 81 - seed: int = -1 - guidance_scale: float = 5.0 - guidance_scale_i: float = 2.0 - num_inference_steps: int = 50 - width: int = 832 - height: int = 480 - prompt: str = '' - negative_prompt: str = '' \ No newline at end of file diff --git a/models/wan/lynx/light_attention_processor.py b/models/wan/lynx/light_attention_processor.py deleted file mode 100644 index a3c7516b1..000000000 --- a/models/wan/lynx/light_attention_processor.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright (c) 2025 The Wan Team and The HuggingFace Team. -# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# -# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025. -# -# Original file was released under Apache License 2.0, with the full license text -# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt. -# -# This modified file is released under the same license. - -from typing import Tuple, Union, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from diffusers.models.normalization import RMSNorm - - -def register_ip_adapter_wan( - model, - hidden_size=5120, - cross_attention_dim=2048, - dtype=torch.float32, - init_method="zero", - layers=None -): - attn_procs = {} - transformer_sd = model.state_dict() - - if layers is None: - layers = list(range(0, len(model.blocks))) - elif isinstance(layers, int): # Only interval provided - layers = list(range(0, len(model.blocks), layers)) - - for i, block in enumerate(model.blocks): - if i not in layers: - continue - - name = f"blocks.{i}.attn2.processor" - attn_procs[name] = IPAWanAttnProcessor2_0( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim - ) - - if init_method == "zero": - torch.nn.init.zeros_(attn_procs[name].to_k_ip.weight) - torch.nn.init.zeros_(attn_procs[name].to_v_ip.weight) - elif init_method == "clone": - layer_name = name.split(".processor")[0] - weights = { - "to_k_ip.weight": transformer_sd[layer_name + ".to_k.weight"], - "to_v_ip.weight": transformer_sd[layer_name + ".to_v.weight"], - } - attn_procs[name].load_state_dict(weights) - else: - raise ValueError(f"{init_method} is not supported.") - - block.attn2.processor = attn_procs[name] - - ip_layers = torch.nn.ModuleList(attn_procs.values()) - ip_layers.to(model.device, dtype=dtype) - - return model, ip_layers - - -class IPAWanAttnProcessor2_0(torch.nn.Module): - def __init__( - self, - hidden_size, - cross_attention_dim=None, - scale=1.0, - bias=False, - ): - super().__init__() - - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("IPAWanAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - - self.to_k_ip = 
nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=bias) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=bias) - - torch.nn.init.zeros_(self.to_k_ip.weight) - torch.nn.init.zeros_(self.to_v_ip.weight) - if bias: - torch.nn.init.zeros_(self.to_k_ip.bias) - torch.nn.init.zeros_(self.to_v_ip.bias) - - self.norm_rms_k = RMSNorm(hidden_size, eps=1e-5, elementwise_affine=False) - - - def __call__( - self, - attn, - hidden_states: torch.Tensor, - image_embed: torch.Tensor = None, - ip_scale: float = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - rotary_emb: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - encoder_hidden_states_img = None - if attn.add_k_proj is not None: - encoder_hidden_states_img = encoder_hidden_states[:, :257] - encoder_hidden_states = encoder_hidden_states[:, 257:] - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - # ============================================================= - batch_size = image_embed.size(0) - - ip_hidden_states = image_embed - ip_query = query # attn.to_q(hidden_states.clone()) - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - if attn.norm_q is not None: - ip_query = attn.norm_q(ip_query) - ip_key = self.norm_rms_k(ip_key) - - ip_inner_dim = ip_key.shape[-1] - ip_head_dim = ip_inner_dim // attn.heads - - ip_query = ip_query.view(batch_size, -1, attn.heads, ip_head_dim).transpose(1, 2) - ip_key = ip_key.view(batch_size, -1, attn.heads, ip_head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, ip_head_dim).transpose(1, 2) - - ip_hidden_states = F.scaled_dot_product_attention( - ip_query, ip_key, ip_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * ip_head_dim) - ip_hidden_states = ip_hidden_states.to(ip_query.dtype) - # =========================================================================== - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) - key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) - - if rotary_emb is not None: - - def apply_rotary_emb_inner(hidden_states: torch.Tensor, freqs: torch.Tensor): - x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) - x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) - return x_out.type_as(hidden_states) - - query = apply_rotary_emb_inner(query, rotary_emb) - key = apply_rotary_emb_inner(key, rotary_emb) - - # I2V task - hidden_states_img = None - if encoder_hidden_states_img is not None: - key_img = attn.add_k_proj(encoder_hidden_states_img) - key_img = attn.norm_added_k(key_img) - value_img = attn.add_v_proj(encoder_hidden_states_img) - - key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) - value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) - - hidden_states_img = F.scaled_dot_product_attention( - query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False - ) - hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) - hidden_states_img = hidden_states_img.type_as(query) - 
hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) - hidden_states = hidden_states.type_as(query) - - if hidden_states_img is not None: - hidden_states = hidden_states + hidden_states_img - - # Add IPA residual - ip_scale = ip_scale or self.scale - hidden_states = hidden_states + ip_scale * ip_hidden_states - - hidden_states = attn.to_out[0](hidden_states) - hidden_states = attn.to_out[1](hidden_states) - - return hidden_states \ No newline at end of file diff --git a/models/wan/lynx/model_utils_wan.py b/models/wan/lynx/model_utils_wan.py deleted file mode 100644 index 2fd117046..000000000 --- a/models/wan/lynx/model_utils_wan.py +++ /dev/null @@ -1,264 +0,0 @@ -# Copyright (c) 2025 The Wan Team and The HuggingFace Team. -# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates -# SPDX-License-Identifier: Apache-2.0 -# -# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025. -# -# Original file was released under Apache License 2.0, with the full license text -# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt. -# -# This modified file is released under the same license. - -import os -import math -import torch - -from typing import List, Optional, Tuple, Union - -from diffusers.models.embeddings import get_3d_rotary_pos_embed -from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid - -from diffusers import FlowMatchEulerDiscreteScheduler - -""" Utility functions for Wan2.1-T2I model """ - - -def cal_mean_and_std(vae, device, dtype): - # https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py#L483 - - # latents_mean - latents_mean = torch.tensor(vae.config.latents_mean).view( - 1, vae.config.z_dim, 1, 1, 1).to(device, dtype) - - # latents_std - latents_std = 1.0 / torch.tensor(vae.config.latents_std).view( - 1, vae.config.z_dim, 1, 1, 1).to(device, dtype) - - return latents_mean, latents_std - - -def encode_video(vae, video): - # Input video is formatted as [B, F, C, H, W] - video = video.to(vae.device, dtype=vae.dtype) - video = video.unsqueeze(0) if len(video.shape) == 4 else video - - # VAE takes [B, C, F, H, W] - video = video.permute(0, 2, 1, 3, 4) - latent_dist = vae.encode(video).latent_dist - - mean, std = cal_mean_and_std(vae, vae.device, vae.dtype) - - # NOTE: this implementation is weird but I still borrow it from the official - # codes. A standard normailization should be (x - mean) / std, but the std - # here is computed by `1 / vae.config.latents_std` from cal_mean_and_std(). 
- - # Sample latent with scaling - video_latent = (latent_dist.sample() - mean) * std - - # Return as [B, C, F, H, W] as required by Wan - return video_latent - - -def decode_latent(vae, latent): - # Input latent is formatted as [B, C, F, H, W] - latent = latent.to(vae.device, dtype=vae.dtype) - latent = latent.unsqueeze(0) if len(latent.shape) == 4 else latent - - mean, std = cal_mean_and_std(vae, latent.device, latent.dtype) - latent = latent / std + mean - - # VAE takes [B, C, F, H, W] - frames = vae.decode(latent).sample - - # Return as [B, F, C, H, W] - frames = frames.permute(0, 2, 1, 3, 4).float() - - return frames - - -def encode_prompt( - tokenizer, - text_encoder, - prompt: Union[str, List[str]], - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - text_input_ids=None, - prompt_attention_mask=None, -): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if tokenizer is not None: - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - prompt_attention_mask = text_inputs.attention_mask - else: - if text_input_ids is None: - raise ValueError("`text_input_ids` must be provided when the tokenizer is not specified.") - - seq_lens = prompt_attention_mask.gt(0).sum(dim=1).long() - prompt_embeds = text_encoder(text_input_ids.to(device), prompt_attention_mask.to(device)).last_hidden_state - - prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - return prompt_embeds, prompt_attention_mask - - -def compute_prompt_embeddings( - tokenizer, text_encoder, prompt, max_sequence_length, device, dtype, requires_grad: bool = False -): - if requires_grad: - prompt_embeds = encode_prompt( - tokenizer, - text_encoder, - prompt, - num_videos_per_prompt=1, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - else: - with torch.no_grad(): - prompt_embeds = encode_prompt( - tokenizer, - text_encoder, - prompt, - num_videos_per_prompt=1, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) - return prompt_embeds[0] - - -def prepare_rotary_positional_embeddings( - height: int, - width: int, - num_frames: int, - vae_scale_factor_spatial: int = 8, - patch_size: int = 2, - attention_head_dim: int = 64, - device: Optional[torch.device] = None, - base_height: int = 480, - base_width: int = 720, -) -> Tuple[torch.Tensor, torch.Tensor]: - grid_height = height // (vae_scale_factor_spatial * patch_size) - grid_width = width // (vae_scale_factor_spatial * patch_size) - base_size_width = base_width // (vae_scale_factor_spatial * patch_size) - base_size_height = base_height // (vae_scale_factor_spatial * patch_size) - - grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height) - freqs_cos, freqs_sin = get_3d_rotary_pos_embed( - 
embed_dim=attention_head_dim, - crops_coords=grid_crops_coords, - grid_size=(grid_height, grid_width), - temporal_size=num_frames, - ) - - freqs_cos = freqs_cos.to(device=device) - freqs_sin = freqs_sin.to(device=device) - return freqs_cos, freqs_sin - - -def prepare_sigmas( - scheduler: FlowMatchEulerDiscreteScheduler, - sigmas: torch.Tensor, - batch_size: int, - num_train_timesteps: int, - flow_weighting_scheme: str = "none", - flow_logit_mean: float = 0.0, - flow_logit_std: float = 1.0, - flow_mode_scale: float = 1.29, - device: torch.device = torch.device("cpu"), - generator: Optional[torch.Generator] = None, -) -> torch.Tensor: - if isinstance(scheduler, FlowMatchEulerDiscreteScheduler): - weights = compute_density_for_timestep_sampling( - weighting_scheme=flow_weighting_scheme, - batch_size=batch_size, - logit_mean=flow_logit_mean, - logit_std=flow_logit_std, - mode_scale=flow_mode_scale, - device=device, - generator=generator, - ) - indices = (weights * num_train_timesteps).long() - else: - raise ValueError(f"Unsupported scheduler type {type(scheduler)}") - - return sigmas[indices] - - -def compute_density_for_timestep_sampling( - weighting_scheme: str, - batch_size: int, - logit_mean: float = None, - logit_std: float = None, - mode_scale: float = None, - device: torch.device = torch.device("cpu"), - generator: Optional[torch.Generator] = None, -) -> torch.Tensor: - r""" - Compute the density for sampling the timesteps when doing SD3 training. - - Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528. - - SD3 paper reference: https://arxiv.org/abs/2403.03206v1. - """ - if weighting_scheme == "logit_normal": - # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). - u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator) - u = torch.nn.functional.sigmoid(u) - elif weighting_scheme == "mode": - u = torch.rand(size=(batch_size,), device=device, generator=generator) - u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - else: - u = torch.rand(size=(batch_size,), device=device, generator=generator) - return u - - -def expand_tensor_dims(tensor: torch.Tensor, ndim: int) -> torch.Tensor: - assert len(tensor.shape) <= ndim - return tensor.reshape(tensor.shape + (1,) * (ndim - len(tensor.shape))) - - -def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - r"""Forward process of flow matching.""" - return (1.0 - t) * x0 + t * n - - -def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor: - r"""Loss target for flow matching.""" - return n - x0 - - -def prepare_loss_weights( - scheduler: FlowMatchEulerDiscreteScheduler, - sigmas: Optional[torch.Tensor] = None, - flow_weighting_scheme: str = "none", - -) -> torch.Tensor: - - assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler) - - from diffusers.training_utils import compute_loss_weighting_for_sd3 - - return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme) \ No newline at end of file diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index 07e3e54e9..30f6172f0 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -131,7 +131,7 @@ def __init__(self, dim, eps=1e-5): self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) - def forward(self, x): + def forward(self, x, in_place= True): r""" Args: x(Tensor): Shape [B, L, C] @@ -141,7 +141,10 @@ def forward(self, x): y = y.mean(dim=-1, keepdim=True) 
y += self.eps y.rsqrt_() - x *= y + if in_place: + x *= y + else: + x = x * y x *= self.weight return x # return self._norm(x).type_as(x) * self.weight @@ -274,7 +277,7 @@ def text_cross_attention(self, xlist, context, return_q = False): else: return x, None - def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0, standin_phase =-1): + def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0, standin_phase =-1, lynx_ref_buffer = None, lynx_ref_scale = 0, sub_x_no=0): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] @@ -287,7 +290,21 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim # query, key, value function - q, k, v = self.q(x), self.k(x), self.v(x) + q = self.q(x) + ref_hidden_states = None + if not lynx_ref_buffer is None: + lynx_ref_features = lynx_ref_buffer[self.block_no] + if self.norm_q is not None: ref_query = self.norm_q(q, in_place = False) + ref_key = self.to_k_ref(lynx_ref_features) + ref_value = self.to_v_ref(lynx_ref_features) + if self.norm_k is not None: ref_key = self.norm_k(ref_key) + ref_query, ref_key, ref_value = ref_query.unflatten(2, (self.num_heads, -1)), ref_key.unflatten(2, (self.num_heads, -1)), ref_value.unflatten(2, (self.num_heads, -1)) + qkv_list = [ref_query, ref_key, ref_value ] + del ref_query, ref_key, ref_value + ref_hidden_states = pay_attention(qkv_list) + + k, v = self.k(x), self.v(x) + if standin_phase == 1: q += self.q_loras(x) k += self.k_loras(x) @@ -347,6 +364,8 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks ) del q,k,v + if ref_hidden_states is not None: + x.add_(ref_hidden_states, alpha= lynx_ref_scale) x = x.flatten(2) x = self.o(x) @@ -537,6 +556,10 @@ def forward( motion_vec = None, lynx_ip_embeds = None, lynx_ip_scale = 0, + lynx_ref_scale = 0, + lynx_feature_extractor = False, + lynx_ref_buffer = None, + sub_x_no =0, ): r""" Args: @@ -571,6 +594,7 @@ def forward( x_mod *= 1 + e[1] x_mod += e[0] x_mod = restore_latent_shape(x_mod) + if cam_emb != None: cam_emb = self.cam_encoder(cam_emb) cam_emb = cam_emb.repeat(1, 2, 1) @@ -579,8 +603,9 @@ def forward( x_mod += cam_emb xlist = [x_mod.to(attention_dtype)] + if lynx_feature_extractor: get_cache("lynx_ref_buffer")[sub_x_no][self.block_no] = xlist[0] del x_mod - y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count, standin_phase= standin_phase, ) + y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count, standin_phase= standin_phase, lynx_ref_buffer = lynx_ref_buffer, lynx_ref_scale = lynx_ref_scale, sub_x_no = sub_x_no) y = y.to(dtype) if cam_emb != None: y = self.projector(y) @@ -1073,30 +1098,9 @@ def __init__(self, block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128) if lynx is not None: - from ..lynx.light_attention_processor import IPAWanAttnProcessor2_0 + from ..lynx.attention_processor import setup_lynx_attention_layers lynx_full = lynx=="full" - if lynx_full: - lynx_cross_dim = 5120 - lynx_layers = len(self.blocks) - else: - lynx_cross_dim = 2048 - lynx_layers = 20 - for i, block in enumerate(self.blocks): - if i < lynx_layers: - attn = IPAWanAttnProcessor2_0(hidden_size=dim, cross_attention_dim = lynx_cross_dim,) - 
block.cross_attn.to_k_ip = attn.to_k_ip - block.cross_attn.to_v_ip = attn.to_v_ip - else: - block.cross_attn.to_k_ip = None - block.cross_attn.to_v_ip = None - if lynx_full: - block.cross_attn.registers = nn.Parameter(torch.randn(1, 16, lynx_cross_dim) / dim**0.5) - block.cross_attn.norm_rms_k = None - else: - block.cross_attn.registers = None - block.cross_attn.norm_rms_k = attn.norm_rms_k - - + setup_lynx_attention_layers(self.blocks, lynx_full, dim) if animate: self.pose_patch_embedding = nn.Conv3d( @@ -1302,7 +1306,10 @@ def forward( pose_latents=None, face_pixel_values=None, lynx_ip_embeds = None, - lynx_ip_scale = 0, + lynx_ip_scale = 0, + lynx_ref_scale = 0, + lynx_feature_extractor = False, + lynx_ref_buffer = None, ): # patch_dtype = self.patch_embedding.weight.dtype @@ -1389,6 +1396,8 @@ def forward( audio_context_lens=audio_context_lens, ref_images_count=ref_images_count, lynx_ip_scale= lynx_ip_scale, + lynx_ref_scale = lynx_ref_scale, + lynx_feature_extractor = lynx_feature_extractor, ) _flag_df = t.dim() == 2 @@ -1413,6 +1422,12 @@ def forward( else: lynx_ip_embeds_list = lynx_ip_embeds + if lynx_ref_buffer is None: + lynx_ref_buffer_list = [None] * len(x_list) + else: + lynx_ref_buffer_list = lynx_ref_buffer + + if self.inject_sample_info and fps!=None: fps = torch.tensor(fps, dtype=torch.long, device=device) @@ -1562,11 +1577,11 @@ def forward( continue x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs) else: - for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec, lynx_ip_embeds) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list, lynx_ip_embeds_list)): + for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec, lynx_ip_embeds,lynx_ref_buffer) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list, lynx_ip_embeds_list,lynx_ref_buffer_list)): if should_calc: - x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec, lynx_ip_embeds= lynx_ip_embeds, **kwargs) + x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec, lynx_ip_embeds= lynx_ip_embeds, lynx_ref_buffer = lynx_ref_buffer, sub_x_no =i, **kwargs) del x - context = hints = audio_embedding = None + context = hints = None if skips_steps_cache != None: if joint_pass: @@ -1590,7 +1605,9 @@ def forward( torch.sub(x_list[0], ori_hidden_states[0], out=residual) skips_steps_cache.previous_residual[x_id] = residual residual, ori_hidden_states = None, None - + if lynx_feature_extractor: + return get_cache("lynx_ref_buffer") + for i, x in enumerate(x_list): if chipmunk: x = reverse_voxel_chunk_no_padding(x.transpose(1, 2).unsqueeze(-1), x_og_shape, voxel_shape).squeeze(-1) diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index b1f594271..95a43ee18 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -18,7 +18,7 @@ def test_standin(base_model_type): return base_model_type in ["standin", "vace_standin_14B"] def test_lynx(base_model_type): - return base_model_type in ["lynx_lite", "vace_lynx_lite_14B", "lynx_full", "vace_lynx_full_14B"] 
+ return base_model_type in ["lynx_lite", "vace_lynx_lite_14B", "lynx", "vace_lynx_14B"] def test_wan_5B(base_model_type): return base_model_type in ["ti2v_2_2", "lucy_edit"] @@ -86,7 +86,7 @@ def query_model_def(base_model_type, model_def): extra_model_def["multitalk_class"] = test_multitalk(base_model_type) extra_model_def["standin_class"] = test_standin(base_model_type) extra_model_def["lynx_class"] = test_lynx(base_model_type) - vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"] + vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B"] extra_model_def["vace_class"] = vace_class if base_model_type in ["animate"]: @@ -140,6 +140,7 @@ def query_model_def(base_model_type, model_def): "selection":[ "", "A"], "visible": False } + extra_model_def["v2i_switch_supported"] = True if base_model_type in ["infinitetalk"]: @@ -213,15 +214,6 @@ def query_model_def(base_model_type, model_def): "visible": False } - # extra_model_def["image_ref_choices"] = { - # "choices": [("None", ""), - # ("People / Objects", "I"), - # ("Landscape followed by People / Objects (if any)", "KI"), - # ], - # "visible": False, - # "letters_filter": "KFI", - # } - extra_model_def["video_guide_outpainting"] = [0,1] extra_model_def["keep_frames_video_guide_not_supported"] = True extra_model_def["extract_guide_from_window_start"] = True @@ -257,17 +249,51 @@ def query_model_def(base_model_type, model_def): extra_model_def["guide_inpaint_color"] = 127.5 extra_model_def["forced_guide_mask_inputs"] = True extra_model_def["return_image_refs_tensor"] = True + extra_model_def["v2i_switch_supported"] = True + if base_model_type in ["vace_lynx_14B"]: + extra_model_def["set_video_prompt_type"]="Q" + extra_model_def["control_net_weight_alt_name"] = "Lynx" + extra_model_def["image_ref_choices"] = { "choices": [("None", ""),("Person Face", "I")], "letters_filter": "I"} + extra_model_def["no_background_removal"] = True + extra_model_def["fit_into_canvas_image_refs"] = 0 + - if base_model_type in ["standin", "lynx_lite", "lynx_full"]: + + if base_model_type in ["standin"]: + extra_model_def["v2i_switch_supported"] = True + + if base_model_type in ["lynx_lite", "lynx"]: extra_model_def["fit_into_canvas_image_refs"] = 0 + extra_model_def["guide_custom_choices"] = { + "choices":[("Use Reference Image which is a Person Face", ""), + ("Video to Video guided by Text Prompt & Reference Image", "GUV"), + ("Video to Video on the Area of the Video Mask", "GVA")], + "default": "", + "letters_filter": "GUVA", + "label": "Video to Video", + "show_label" : False, + } + + extra_model_def["mask_preprocessing"] = { + "selection":[ "", "A"], + "visible": False + } + extra_model_def["image_ref_choices"] = { "choices": [ ("No Reference Image", ""), ("Reference Image is a Person Face", "I"), ], + "visible": False, "letters_filter":"I", } + extra_model_def["set_video_prompt_type"]= "Q" + extra_model_def["no_background_removal"] = True + extra_model_def["v2i_switch_supported"] = True + extra_model_def["control_net_weight_alt_name"] = "Lynx" + + if base_model_type in ["phantom_1.3B", "phantom_14B"]: extra_model_def["image_ref_choices"] = { "choices": [("Reference Image", "I")], @@ -326,8 +352,8 @@ def query_model_def(base_model_type, model_def): @staticmethod def query_supported_types(): - return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B", - "t2v_1.3B", "standin", "lynx_lite", "lynx_full", "t2v", 
"vace_1.3B", "phantom_1.3B", "phantom_14B", + return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B", + "t2v_1.3B", "standin", "lynx_lite", "lynx", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", "recam_1.3B", "animate", "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] @@ -338,11 +364,13 @@ def query_family_maps(): models_eqv_map = { "flf2v_720p" : "i2v", "t2v_1.3B" : "t2v", + "vace_standin_14B" : "vace_14B", + "vace_lynx_14B" : "vace_14B", } models_comp_map = { - "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_full_14B"], - "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin", "lynx_lite", "lynx_full"], + "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_14B"], + "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin", "lynx_lite", "lynx"], "i2v" : [ "fantasy", "multitalk", "flf2v_720p" ], "i2v_2_2" : ["i2v_2_2_multitalk"], "fantasy": ["multitalk"], @@ -523,15 +551,17 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "flow_shift": 7, # 11 for 720p "sliding_window_overlap" : 9, "video_prompt_type": "I", - "remove_background_images_ref" : 1, + "remove_background_images_ref" : 1 , }) - elif base_model_type in ["lynx_lite", "lynx_full"]: + + elif base_model_type in ["lynx_lite", "lynx"]: ui_defaults.update({ "guidance_scale": 5.0, "flow_shift": 7, # 11 for 720p "sliding_window_overlap" : 9, - "video_prompt_type": "QI", - "remove_background_images_ref" : 1, + "video_prompt_type": "I", + "denoising_strength": 0.8, + "remove_background_images_ref" : 0, }) elif base_model_type in ["phantom_1.3B", "phantom_14B"]: diff --git a/requirements.txt b/requirements.txt index 354d465a7..9b2c6eeb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,8 +35,9 @@ rembg[gpu]==2.0.65 onnxruntime-gpu==1.22 decord timm -# insightface==0.7.3 -# facexlib==0.3.0 +insightface @ https://github.com/deepbeepmeep/insightface/raw/refs/heads/master/wheels/insightface-0.7.3-cp310-cp310-win_amd64.whl ; sys_platform == "win32" and python_version == "3.10" +insightface==0.7.3 ; sys_platform == "linux" +facexlib==0.3.0 # Config & orchestration omegaconf diff --git a/shared/utils/utils.py b/shared/utils/utils.py index 00178aad1..a03927e62 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -205,7 +205,7 @@ def get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_d frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100) return frame_height, frame_width -def rgb_bw_to_rgba_mask(img, thresh=127): +def rgb_bw_to_rgba_mask(img, thresh=127): a = img.convert('L').point(lambda p: 255 if p > thresh else 0) # alpha out = Image.new('RGBA', img.size, (255, 255, 255, 0)) # white, transparent out.putalpha(a) # white where alpha=255 diff --git a/wgp.py b/wgp.py index e48070ca3..74223a6c9 100644 --- a/wgp.py +++ b/wgp.py @@ -63,7 +63,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.761" +WanGP_version = "8.9" settings_version = 2.38 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -3492,7 +3492,23 @@ def check(src, cond): 
if video_apg_switch is not None and video_apg_switch != 0: values += ["on"] labels += ["APG"] - + + control_net_weight = "" + if test_vace_module(video_model_type): + video_control_net_weight = configs.get("control_net_weight", 1) + if len(filter_letters(video_video_prompt_type, video_guide_processes))> 1: + video_control_net_weight2 = configs.get("control_net_weight2", 1) + control_net_weight = f"Vace #1={video_control_net_weight}, Vace #2={video_control_net_weight2}" + else: + control_net_weight = f"Vace={video_control_net_weight}" + control_net_weight_alt_name = model_def.get("control_net_weight_alt_name", "") + if len(control_net_weight_alt_name) >0: + if len(control_net_weight): control_net_weight += ", " + control_net_weight += control_net_weight_alt_name + "=" + str(configs.get("control_net_alt", 1)) + if len(control_net_weight) > 0: + values += [control_net_weight] + labels += ["Control Net Weights"] + video_skip_steps_cache_type = configs.get("skip_steps_cache_type", "") video_skip_steps_multiplier = configs.get("skip_steps_multiplier", 0) video_skip_steps_cache_start_step_perc = configs.get("skip_steps_start_step_perc", 0) @@ -3736,7 +3752,7 @@ def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_fr alpha_mask[mask > 127] = 1 frame = frame * alpha_mask frame = Image.fromarray(frame) - face = face_processor.process(frame, resize_to=size, face_crop_scale = 1.2,) + face = face_processor.process(frame, resize_to=size) face_list.append(face) face_processor = None @@ -4513,6 +4529,7 @@ def generate_video( image_mask, control_net_weight, control_net_weight2, + control_net_weight_alt, mask_expand, audio_guide, audio_guide2, @@ -4577,6 +4594,9 @@ def remove_temp_filenames(temp_filenames_list): model_def = get_model_def(model_type) is_image = image_mode > 0 + set_video_prompt_type = model_def.get("set_video_prompt_type", None) + if set_video_prompt_type is not None: + video_prompt_type = add_to_sequence(video_prompt_type, set_video_prompt_type) if is_image: if min_frames_if_references >= 1000: video_length = min_frames_if_references - 1000 @@ -4728,7 +4748,7 @@ def remove_temp_filenames(temp_filenames_list): _, _, latent_size = get_model_min_frames_and_step(model_type) original_image_refs = image_refs - image_refs = None if image_refs is None else [] + image_refs # work on a copy as it is going to be modified + image_refs = None if image_refs is None else ([] + image_refs) # work on a copy as it is going to be modified # image_refs = None # nb_frames_positions= 0 # Output Video Ratio Priorities: @@ -5025,7 +5045,7 @@ def remove_temp_filenames(temp_filenames_list): # Generic Video Preprocessing process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) preprocess_type, preprocess_type2 = "raw", None - for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PEDSLCMU")): + for process_num, process_letter in enumerate( filter_letters(video_prompt_type, video_guide_processes)): if process_num == 0: preprocess_type = process_map_video_guide.get(process_letter, "raw") else: @@ -5224,6 +5244,7 @@ def set_header_text(txt): audio_scale= audio_scale, audio_context_lens= audio_context_lens, context_scale = context_scale, + control_scale_alt = control_net_weight_alt, model_mode = model_mode, causal_block_size = 5, causal_attention = True, @@ -6231,7 +6252,10 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if not vace: pop += ["frames_positions", "control_net_weight", 
"control_net_weight2"] - + + if not len(model_def.get("control_net_weight_alt_name", "")) >0: + pop += ["control_net_alt"] + if model_def.get("video_guide_outpainting", None) is None: pop += ["video_guide_outpainting"] @@ -6679,6 +6703,7 @@ def save_inputs( image_mask, control_net_weight, control_net_weight2, + control_net_weight_alt, mask_expand, audio_guide, audio_guide2, @@ -6992,6 +7017,7 @@ def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_t return video_prompt_type all_guide_processes ="PDESLCMUVBH" +video_guide_processes = "PEDSLCMU" def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): model_type = state["model_type"] @@ -7440,13 +7466,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non phantom = base_model_type in ["phantom_1.3B", "phantom_14B"] fantasy = base_model_type in ["fantasy"] multitalk = model_def.get("multitalk_class", False) - standin = model_def.get("standin_class", False) infinitetalk = base_model_type in ["infinitetalk"] hunyuan_t2v = "hunyuan_video_720" in model_filename hunyuan_i2v = "hunyuan_video_i2v" in model_filename - hunyuan_video_custom = "hunyuan_video_custom" in model_filename - hunyuan_video_custom = base_model_type in ["hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit"] - hunyuan_video_custom_audio = base_model_type in ["hunyuan_custom_audio"] hunyuan_video_custom_edit = base_model_type in ["hunyuan_custom_edit"] hunyuan_video_avatar = "hunyuan_video_avatar" in model_filename flux = base_model_family in ["flux"] @@ -7460,7 +7482,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non image_prompt_type_value = "" video_prompt_type_value = "" any_start_image = any_end_image = any_reference_image = any_image_mask = False - v2i_switch_supported = (vace or t2v or standin) and not image_outputs + v2i_switch_supported = model_def.get("v2i_switch_supported", False) and not image_outputs ti2v_2_2 = base_model_type in ["ti2v_2_2"] gallery_height = 350 def get_image_gallery(label ="", value = None, single_image_mode = False, visible = False ): @@ -7539,7 +7561,6 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl guide_custom_choices = model_def.get("guide_custom_choices", None) image_ref_choices = model_def.get("image_ref_choices", None) - # with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or recammaster or (flux or qwen ) and model_reference_image and image_mode_value >=1) as video_prompt_column: with gr.Column(visible= guide_preprocessing is not None or mask_preprocessing is not None or guide_custom_choices is not None or image_ref_choices is not None) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","") video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False) @@ -7923,10 +7944,11 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl choices= sample_solver_choices, visible= True, label= "Sampler Solver / Scheduler" ) flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale", visible = not image_outputs) - - with gr.Row(visible = vace) as control_net_weights_row: - control_net_weight = gr.Slider(0.0, 2.0, 
value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace) - control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace) + control_net_weight_alt_name = model_def.get("control_net_weight_alt_name", "") + with gr.Row(visible = vace or len(control_net_weight_alt_name) >0 ) as control_net_weights_row: + control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Vace Weight #1", visible=vace) + control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Vace Weight #2", visible=vace) + control_net_weight_alt = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label=control_net_weight_alt_name + " Weight", visible=len(control_net_weight_alt_name) >0) negative_prompt = gr.Textbox(label="Negative Prompt (ignored if no Guidance that is if CFG = 1)", value=ui_defaults.get("negative_prompt", ""), visible = not (hunyuan_t2v or hunyuan_i2v or no_negative_prompt) ) with gr.Column(visible = vace or t2v or test_class_i2v(model_type)) as NAG_col: gr.Markdown("NAG enforces Negative Prompt even if no Guidance is set (CFG = 1), set NAG Scale to > 1 to enable it") @@ -8113,7 +8135,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai with gr.Row(): cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)", visible = any_cfg_zero) - with gr.Column(visible = (vace or t2v or standin) and image_outputs) as min_frames_if_references_col: + with gr.Column(visible = v2i_switch_supported and image_outputs) as min_frames_if_references_col: gr.Markdown("Generating a single Frame alone may not be sufficient to preserve Reference Image Identity / Control Image Information or simply to get a good Image Quality. A workaround is to generate a short Video and keep the First Frame.") min_frames_if_references = gr.Dropdown( choices=[ From 69ce2860b8fb137411a04496d26e4879d30592c3 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 30 Sep 2025 10:59:04 +0200 Subject: [PATCH 058/155] updated lynx defaults --- defaults/lynx.json | 2 +- defaults/vace_lynx_14B.json | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/defaults/lynx.json b/defaults/lynx.json index 4c0a17f1e..528f5ef68 100644 --- a/defaults/lynx.json +++ b/defaults/lynx.json @@ -12,7 +12,7 @@ "description": "The Lynx ControlNet offers State of the Art Identity Preservation. You need to provide a Reference Image which is a close up of a person face to transfer this person in the Video.", "URLs": "t2v", "preload_URLs": [ - "wan2.1_lynx_full_arc_resampler.safetensors" + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_lynx_full_arc_resampler.safetensors" ] } } \ No newline at end of file diff --git a/defaults/vace_lynx_14B.json b/defaults/vace_lynx_14B.json index ba8bfbd7b..523d5b3fc 100644 --- a/defaults/vace_lynx_14B.json +++ b/defaults/vace_lynx_14B.json @@ -4,6 +4,7 @@ "architecture": "vace_lynx_14B", "modules": [ "vace_14B", "lynx"], "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video. 
The Lynx version is specialized in transferring a Person Identity.", - "URLs": "t2v" + "URLs": "t2v", + "preload_URLs": "lynx" } } \ No newline at end of file From 1c27d496ae6246f716b04e4e53df55590cfe7969 Mon Sep 17 00:00:00 2001 From: Gunther Schulz Date: Tue, 30 Sep 2025 17:21:00 +0200 Subject: [PATCH 059/155] Add LCM + LTX sampler for Lightning LoRA compatibility - Implement LCMScheduler with RectifiedFlow (LTX) dynamics - Combine Latent Consistency Model with rectified flow scheduling - Optimize for 2-8 step ultra-fast inference with Lightning LoRAs - Add proper flow matching dynamics with shift parameter support - Update UI to show 'lcm + ltx' option in sampler dropdown --- .gitignore | 4 ++ models/wan/any2video.py | 13 ++++- models/wan/wan_handler.py | 3 +- shared/utils/lcm_scheduler.py | 99 +++++++++++++++++++++++++++++++++++ 4 files changed, 117 insertions(+), 2 deletions(-) create mode 100644 shared/utils/lcm_scheduler.py diff --git a/.gitignore b/.gitignore index a84122c61..8049651d2 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,7 @@ gradio_outputs/ ckpts/ loras/ loras_i2v/ + +settings/ + +wgp_config.json diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 5beea8e74..5cf1e712a 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -32,11 +32,11 @@ from .modules.posemb_layers import get_rotary_pos_embed, get_nd_rotary_pos_embed from shared.utils.vace_preprocessor import VaceVideoProcessor from shared.utils.basic_flowmatch import FlowMatchScheduler +from shared.utils.lcm_scheduler import LCMScheduler from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor, fit_image_into_canvas from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask from shared.utils.audio_video import save_video from mmgp import safetensors2 -from shared.utils.audio_video import save_video def optimized_scale(positive_flat, negative_flat): @@ -413,6 +413,17 @@ def generate(self, sample_scheduler, device=self.device, sigmas=sampling_sigmas) + elif sample_solver == 'lcm': + # LCM + LTX scheduler: Latent Consistency Model with RectifiedFlow + # Optimized for Lightning LoRAs with ultra-fast 2-8 step inference + effective_steps = min(sampling_steps, 8) # LCM works best with few steps + sample_scheduler = LCMScheduler( + num_train_timesteps=self.num_train_timesteps, + num_inference_steps=effective_steps, + shift=shift + ) + sample_scheduler.set_timesteps(effective_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps else: raise NotImplementedError(f"Unsupported Scheduler {sample_solver}") original_timesteps = timesteps diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 95a43ee18..de1212913 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -122,7 +122,8 @@ def query_model_def(base_model_type, model_def): ("unipc", "unipc"), ("euler", "euler"), ("dpm++", "dpm++"), - ("flowmatch causvid", "causvid"), ] + ("flowmatch causvid", "causvid"), + ("lcm + ltx", "lcm"), ] }) diff --git a/shared/utils/lcm_scheduler.py b/shared/utils/lcm_scheduler.py new file mode 100644 index 000000000..0fa1d22a3 --- /dev/null +++ b/shared/utils/lcm_scheduler.py @@ -0,0 +1,99 @@ +""" +LCM + LTX scheduler combining Latent Consistency Model with RectifiedFlow (LTX). +Optimized for Lightning LoRA compatibility and ultra-fast inference. 
+""" + +import torch +import math +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput + + +class LCMScheduler(SchedulerMixin): + """ + LCM + LTX scheduler combining Latent Consistency Model with RectifiedFlow. + - LCM: Enables 2-8 step inference with consistency models + - LTX: Uses RectifiedFlow for better flow matching dynamics + Optimized for Lightning LoRAs and ultra-fast, high-quality generation. + """ + + def __init__(self, num_train_timesteps: int = 1000, num_inference_steps: int = 4, shift: float = 1.0): + self.num_train_timesteps = num_train_timesteps + self.num_inference_steps = num_inference_steps + self.shift = shift + self._step_index = None + + def set_timesteps(self, num_inference_steps: int, device=None, shift: float = None, **kwargs): + """Set timesteps for LCM+LTX inference using RectifiedFlow approach""" + self.num_inference_steps = min(num_inference_steps, 8) # LCM works best with 2-8 steps + + if shift is None: + shift = self.shift + + # RectifiedFlow (LTX) approach: Use rectified flow dynamics for better sampling + # This creates a more optimal path through the probability flow ODE + t = torch.linspace(0, 1, self.num_inference_steps + 1, dtype=torch.float32) + + # Apply rectified flow transformation for better dynamics + # This is the key LTX component - rectified flow scheduling + sigma_max = 1.0 + sigma_min = 0.003 / 1.002 + + # Rectified flow uses a more sophisticated sigma schedule + # that accounts for the flow matching dynamics + sigmas = sigma_min + (sigma_max - sigma_min) * (1 - t) + + # Apply shift for flow matching (similar to other flow-based schedulers) + sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) + + self.sigmas = sigmas + self.timesteps = self.sigmas[:-1] * self.num_train_timesteps + + if device is not None: + self.timesteps = self.timesteps.to(device) + self.sigmas = self.sigmas.to(device) + self._step_index = None + + def step(self, model_output: torch.Tensor, timestep: torch.Tensor, sample: torch.Tensor, **kwargs) -> SchedulerOutput: + """ + Perform LCM + LTX step combining consistency model with rectified flow. 
+ - LCM: Direct consistency model prediction for fast inference + - LTX: RectifiedFlow dynamics for optimal probability flow path + """ + if self._step_index is None: + self._init_step_index(timestep) + + # Get current and next sigma values from RectifiedFlow schedule + sigma = self.sigmas[self._step_index] + if self._step_index + 1 < len(self.sigmas): + sigma_next = self.sigmas[self._step_index + 1] + else: + sigma_next = torch.zeros_like(sigma) + + # LCM + LTX: Combine consistency model approach with rectified flow dynamics + # The model_output represents the velocity field in the rectified flow ODE + # LCM allows us to take larger steps while maintaining consistency + + # RectifiedFlow step: x_{t+1} = x_t + v_θ(x_t, t) * (σ_next - σ) + # This is the core flow matching equation with LTX rectified dynamics + sigma_diff = (sigma_next - sigma) + while len(sigma_diff.shape) < len(sample.shape): + sigma_diff = sigma_diff.unsqueeze(-1) + + # LCM consistency: The model is trained to be consistent across timesteps + # allowing for fewer steps while maintaining quality + prev_sample = sample + model_output * sigma_diff + self._step_index += 1 + + return SchedulerOutput(prev_sample=prev_sample) + + def _init_step_index(self, timestep): + """Initialize step index based on current timestep""" + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + indices = (self.timesteps == timestep).nonzero() + if len(indices) > 0: + self._step_index = indices[0].item() + else: + # Find closest timestep if exact match not found + diffs = torch.abs(self.timesteps - timestep) + self._step_index = torch.argmin(diffs).item() From 7f0abc94d2d5f94019ce28d4a8582647c942c13f Mon Sep 17 00:00:00 2001 From: Gunther Schulz Date: Wed, 1 Oct 2025 20:36:50 +0200 Subject: [PATCH 060/155] Add human-readable time format to generation time display - Added format_generation_time() function to display generation time with human-readable format - Shows raw seconds with minutes/hours in parentheses when over 60 seconds - Format: '335s (5m 35s)' for times over 1 minute, '7384s (2h 3m 4s)' for times over 1 hour - Improves UX by making long generation times easier to understand at a glance --- wgp.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/wgp.py b/wgp.py index 74223a6c9..5ec0583a5 100644 --- a/wgp.py +++ b/wgp.py @@ -148,6 +148,24 @@ def format_time(seconds): else: return f"{seconds:.1f}s" +def format_generation_time(seconds): + """Format generation time showing raw seconds with human-readable time in parentheses when over 60s""" + raw_seconds = f"{int(seconds)}s" + + if seconds < 60: + return raw_seconds + + hours = int(seconds // 3600) + minutes = int((seconds % 3600) // 60) + secs = int(seconds % 60) + + if hours > 0: + human_readable = f"{hours}h {minutes}m {secs}s" + else: + human_readable = f"{minutes}m {secs}s" + + return f"{raw_seconds} ({human_readable})" + def pil_to_base64_uri(pil_image, format="png", quality=75): if pil_image is None: return None @@ -3459,7 +3477,7 @@ def check(src, cond): video_num_inference_steps = configs.get("num_inference_steps", 0) video_creation_date = str(get_file_creation_date(file_name)) if "." 
in video_creation_date: video_creation_date = video_creation_date[:video_creation_date.rfind(".")] - video_generation_time = str(configs.get("generation_time", "0")) + "s" + video_generation_time = format_generation_time(float(configs.get("generation_time", "0"))) video_activated_loras = configs.get("activated_loras", []) video_loras_multipliers = configs.get("loras_multipliers", "") video_loras_multipliers = preparse_loras_multipliers(video_loras_multipliers) From fbf99df32eefed58d66fd810dd944f23738290a8 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 2 Oct 2025 19:29:45 +0200 Subject: [PATCH 061/155] Lora Reloaded --- README.md | 15 + defaults/animate.json | 1 + defaults/hunyuan_t2v_fast.json | 1 + defaults/i2v_fusionix.json | 1 + defaults/lucy_edit.json | 1 + defaults/lucy_edit_fastwan.json | 1 + defaults/ti2v_2_2.json | 1 + defaults/ti2v_2_2_fastwan.json | 1 + defaults/vace_14B_cocktail.json | 1 + defaults/vace_14B_cocktail_2_2.json | 1 + defaults/vace_14B_fusionix.json | 1 + docs/LORAS.md | 64 ++- loras/wan/Text2Video Cocktail - 10 Steps.json | 13 + loras/wan/Text2Video FusioniX - 10 Steps.json | 11 + .../Lightning v1.0 2 Phases - 4 Steps.json | 14 + .../Lightning v1.0 3 Phases - 8 Steps.json | 18 + .../Lightning v2025-09-28 - 4 Steps.json | 14 + .../Text2Video FusioniX - 10 Steps.json | 11 + loras/wan2_2_5B/FastWan 3 Steps.json | 7 + .../wan/Image2Video FusioniX - 10 Steps.json | 10 + loras_qwen/qwen/Lightning v1.0 - 4 Steps.json | 8 + loras_qwen/qwen/Lightning v1.1 - 8 Steps.json | 8 + models/wan/wan_handler.py | 2 +- shared/attention.py | 1 + wgp.py | 385 ++++++++++++------ 25 files changed, 458 insertions(+), 133 deletions(-) create mode 100644 loras/wan/Text2Video Cocktail - 10 Steps.json create mode 100644 loras/wan/Text2Video FusioniX - 10 Steps.json create mode 100644 loras/wan2_2/Lightning v1.0 2 Phases - 4 Steps.json create mode 100644 loras/wan2_2/Lightning v1.0 3 Phases - 8 Steps.json create mode 100644 loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json create mode 100644 loras/wan2_2/Text2Video FusioniX - 10 Steps.json create mode 100644 loras/wan2_2_5B/FastWan 3 Steps.json create mode 100644 loras_i2v/wan/Image2Video FusioniX - 10 Steps.json create mode 100644 loras_qwen/qwen/Lightning v1.0 - 4 Steps.json create mode 100644 loras_qwen/qwen/Lightning v1.1 - 8 Steps.json diff --git a/README.md b/README.md index 6abdaf5c2..637b1640d 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,21 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : +### October 2 2025: WanGP v8.99 - One Last Thing before the Big Unknown ... + +This new version hasn't any new model... + +...but temptation to upgrade will be high as it contains a few Loras related features that may change your Life: +- **Ready to use Loras Accelerators Profiles** per type of model that you can apply on your current *Generation Settings*. Next time I will recommend a *Lora Accelerator*, it will be only one click away. And best of all of the required Loras will be downloaded automatically. When you apply an *Accelerator Profile*, input fields like the *Number of Denoising Steps* *Activated Loras*, *Loras Multipliers* (such as "1;0 0;1" ...) will be automatically filled. However your video specific fields will be preserved, so it will be easy to switch between Profiles to experiment. + +- **Embedded Loras URL** : WanGP will now try to remember every Lora URLs it sees. 
For instance, if someone sends you some settings that contain Loras URLs or you extract the Settings of a Video generated by a friend with Loras URLs, these URLs will be automatically added to *WanGP URL Cache*. Conversely everything you will share (Videos, Settings, Lset files) will contain the download URLs if they are known. You can also directly download a Lora in WanGP by using the *Download Lora* button at the bottom. The Lora will be immediately available and added to WanGP lora URL cache. This will work with *Hugging Face* as a repository. Support for CivitAi will come as soon as someone is nice enough to post a GitHub PR ... + +- **.lset file** supports embedded Loras URLs. It has never been easier to share a Lora with a friend. As a reminder, a .lset file can be created directly from *WanGP Web Interface* and it contains a list of Loras and their multipliers, a Prompt and Instructions on how to use these loras (like the Lora's *Trigger*). So with embedded Loras URLs, you can send an .lset file by email or share it on discord: it is just a tiny 1 KB text file, but with it other people will be able to use Gigabytes of Loras as these will be automatically downloaded. + +I have created the new Discord Channel **share-your-settings** where you can post your *Settings* or *Lset files*. I will be pleased to add new Loras Accelerators in the list of WanGP *Accelerators Profiles* if you post some good ones there. + +Last but not least, the Loras documentation has been updated. + ### September 30 2025: WanGP v8.9 - Combinatorics This new version of WanGP introduces **Wan 2.1 Lynx** the best Control Net so far to transfer *Facial Identity*. You will be amazed to recognize your friends even with a completely different hair style. Congrats to the *Byte Dance team* for this achievement. Lynx works quite with well *Fusionix t2v* 10 steps. diff --git a/defaults/animate.json b/defaults/animate.json index bf45f4a92..d4676721a 100644 --- a/defaults/animate.json +++ b/defaults/animate.json @@ -12,6 +12,7 @@ [ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_relighting_lora.safetensors" ], + "settings_dir": [ "wan" ], "group": "wan2_2" } } \ No newline at end of file diff --git a/defaults/hunyuan_t2v_fast.json b/defaults/hunyuan_t2v_fast.json index acba28e20..30adb17a0 100644 --- a/defaults/hunyuan_t2v_fast.json +++ b/defaults/hunyuan_t2v_fast.json @@ -3,6 +3,7 @@ "name": "Hunyuan Video FastHunyuan 720p 13B", "architecture": "hunyuan", "description": "Fast Hunyuan is an accelerated HunyuanVideo model.
It can sample high quality videos with 6 diffusion steps.", + "settings_dir": [ "" ], "URLs": [ "https://huggingface.co/DeepBeepMeep/HunyuanVideo/resolve/main/fast_hunyuan_video_720_quanto_int8.safetensors" ], diff --git a/defaults/i2v_fusionix.json b/defaults/i2v_fusionix.json index 851d6cc7a..8b0a8af54 100644 --- a/defaults/i2v_fusionix.json +++ b/defaults/i2v_fusionix.json @@ -5,6 +5,7 @@ "architecture" : "i2v", "description": "A powerful merged image-to-video model based on the original WAN 2.1 I2V model, enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.", "URLs": "i2v", + "settings_dir": [ "" ], "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"] } } \ No newline at end of file diff --git a/defaults/lucy_edit.json b/defaults/lucy_edit.json index a8f67ad64..57d3d958a 100644 --- a/defaults/lucy_edit.json +++ b/defaults/lucy_edit.json @@ -8,6 +8,7 @@ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors", "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors" ], + "settings_dir": "ti2v_2_2", "group": "wan2_2" }, "prompt": "change the clothes to red", diff --git a/defaults/lucy_edit_fastwan.json b/defaults/lucy_edit_fastwan.json index d5d47c814..2a52166fc 100644 --- a/defaults/lucy_edit_fastwan.json +++ b/defaults/lucy_edit_fastwan.json @@ -5,6 +5,7 @@ "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.", "URLs": "lucy_edit", "group": "wan2_2", + "settings_dir": [ "" ], "loras": "ti2v_2_2_fastwan" }, "prompt": "change the clothes to red", diff --git a/defaults/ti2v_2_2.json b/defaults/ti2v_2_2.json index ac329fa2d..91e906347 100644 --- a/defaults/ti2v_2_2.json +++ b/defaults/ti2v_2_2.json @@ -7,6 +7,7 @@ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_5B_mbf16.safetensors", "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_5B_quanto_mbf16_int8.safetensors" ], + "settings_dir": [ "wan2_2_5B" ], "group": "wan2_2" }, "video_length": 121, diff --git a/defaults/ti2v_2_2_fastwan.json b/defaults/ti2v_2_2_fastwan.json index fa69f82ec..c142dcbd2 100644 --- a/defaults/ti2v_2_2_fastwan.json +++ b/defaults/ti2v_2_2_fastwan.json @@ -4,6 +4,7 @@ "architecture": "ti2v_2_2", "description": "FastWan2.2-TI2V-5B-Full-Diffusers is built upon Wan-AI/Wan2.2-TI2V-5B-Diffusers. It supports efficient 3-step inference and produces high-quality videos at 121×704×1280 resolution", "URLs": "ti2v_2_2", + "settings_dir": [ "" ], "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"], "group": "wan2_2" }, diff --git a/defaults/vace_14B_cocktail.json b/defaults/vace_14B_cocktail.json index 87f2b78ad..0ab7a625e 100644 --- a/defaults/vace_14B_cocktail.json +++ b/defaults/vace_14B_cocktail.json @@ -7,6 +7,7 @@ ], "description": "This model has been created on the fly using the Wan text 2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. 
Copy the model def in the finetune folder to change the Cocktail composition.", "URLs": "t2v", + "settings_dir": [ "" ], "loras": [ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors", diff --git a/defaults/vace_14B_cocktail_2_2.json b/defaults/vace_14B_cocktail_2_2.json index f821a9e50..6a831d214 100644 --- a/defaults/vace_14B_cocktail_2_2.json +++ b/defaults/vace_14B_cocktail_2_2.json @@ -14,6 +14,7 @@ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors", "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors" ], + "settings_dir": [ "" ], "loras_multipliers": [1, 0.2, 0.5, 0.5], "group": "wan2_2" }, diff --git a/defaults/vace_14B_fusionix.json b/defaults/vace_14B_fusionix.json index 44c048c90..1c6fa7288 100644 --- a/defaults/vace_14B_fusionix.json +++ b/defaults/vace_14B_fusionix.json @@ -6,6 +6,7 @@ "vace_14B" ], "description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.", + "settings_dir": [ "" ], "URLs": "t2v_fusionix" }, "negative_prompt": "", diff --git a/docs/LORAS.md b/docs/LORAS.md index 89b2e5932..dc0caaa56 100644 --- a/docs/LORAS.md +++ b/docs/LORAS.md @@ -47,6 +47,20 @@ python wgp.py --lora-dir-hunyuan /path/to/hunyuan/loras --lora-dir-ltxv /path/to If you store loras in the loras folder once WanGP has been launched, click the *Refresh* button at the top so that it can become selectable. +### Autodownload of Loras +WanGP will try to remember where a Lora was obtained and will store the corresponding Download URL in the Generation settings that are embedded in the Generated Video. This is useful to share this information or to easily recover lost loras after a reinstall. + +This works very well if the Loras are stored in repositories such as *Hugging Face* but won't work for the moment for Loras that require a Login (like *CivitAi*) to be downloaded. + +WanGP will update its internal URL Lora Cache whenever one of these events occurs: +- when applying or importing an *Accelerator Profile*, *Settings* or *Lset* file that contains Loras with full URLs (not just local paths) +- when extracting the settings of a Video that was generated with Loras and that contained the full Loras URLs +- when manually downloading a Lora using the *Download Lora* button at the bottom + +So the more you use WanGP, the more the URL cache file will get updated. The file is *loras_url_cache.json* and is located in the root folder of WanGP. + +You can delete this file without any risk if needed or share it with friends to save them time locating the Loras. You will need to restart WanGP if you manually modify this file or delete it. + ### Lora Multipliers @@ -117,38 +131,49 @@ Loras Multipliers: 0;1;0 0;0;1 Here during the first phase with guidance 3.5, the High model will be used but there won't be any lora at all. Then during phase 2 only the High lora will be used (which requires to set the guidance to 1). At last in phase 3 WanGP will switch to the Low model and then only the Low lora will be used.
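To make the phase syntax above concrete, here is a minimal, illustrative Python sketch of how a phased multiplier string such as `0;1;0 0;0;1` can be decomposed. This is only a simplified stand-in for WanGP's real parsing logic (`parse_loras_multipliers` in `wgp.py`, whose actual signature and error handling differ); the helper name below is hypothetical.

```python
# Illustrative sketch only, not WanGP's actual parser.
# A space separates the entries of the different Loras,
# ";" separates the per-phase values of a single Lora.
def split_phased_multipliers(multipliers: str, nb_phases: int) -> list[list[float]]:
    per_lora = []
    for entry in multipliers.split():
        phases = [float(v) for v in entry.split(";")]
        if len(phases) == 1:
            phases = phases * nb_phases  # a single value applies to every phase
        per_lora.append(phases)
    return per_lora

# "0;1;0 0;0;1" -> High lora active only in phase 2, Low lora only in phase 3
print(split_phased_multipliers("0;1;0 0;0;1", nb_phases=3))
# [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
```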
*Note that the syntax for multipliers can also be used in a Finetune model definition file (except that each multiplier definition is a string in a json list)* -## Lora Presets +## Lora Presets (.Lset file) + +Lora Presets contain all the information needed to use a Lora or a combination of Loras: +- The full download URLs of the Loras +- Default Loras Multipliers +- Sample Prompt to use the Loras with their corresponding *Trigger Words* (usually as comments) +Optionally they may contain advanced prompts with macros to generate the Prompt automatically from keywords. -Lora Presets are combinations of loras with predefined multipliers and prompts. +A Lora Preset is a text file of only a few kilobytes and can be easily shared between users. Don't hesitate to use this format if you have created a Lora. ### Creating Presets 1. Configure your loras and multipliers -2. Write a prompt with comments (lines starting with #) -3. Save as a preset with `.lset` extension +2. Write a prompt with comment lines starting with # that contain instructions +3. Save as a preset with `.lset` extension by clicking the *Save* button at the top, selecting *Save Only Loras & Full Prompt* and finally clicking *Go Ahead Save it!* -### Example Preset +### Example Lora Preset Prompt ``` # Use the keyword "ohnvx" to trigger the lora A ohnvx character is driving a car through the city ``` -### Using Presets -```bash -# Load preset on startup -python wgp.py --lora-preset mypreset.lset ``` +Using a macro (check the doc below), the user will just have to enter two words and the Prompt will be generated for them: ``` +! {Person}="man" : {Object}="car" +This {Person} is cleaning his {Object}. +``` + -### Managing Presets +### Managing Loras Presets (.lset Files) - Edit, save, or delete presets directly from the web interface - Presets include comments with usage instructions -- Share `.lset` files with other users +- Share `.lset` files with other users (make sure the full Loras URLs are in them) + +A *.lset* file may contain only local paths to the Loras if WanGP doesn't know where you got them. You can edit the .lset file with a text editor and replace the local path with its URL. If you store your Lora in Hugging Face, you can easily obtain its URL by selecting the file and clicking *Copy Download Link*. + +To share a *.Lset* file you will need (for the moment) to get it directly from the Lora folder where it is stored. ## Supported Formats -WanGP supports multiple lora formats: +WanGP supports most lora formats: - **Safetensors** (.safetensors) - **Replicate** format -- **Standard PyTorch** (.pt, .pth) +- ... ## Loras Accelerators @@ -166,8 +191,17 @@ https://huggingface.co/DeepBeepMeep/Qwen_image/tree/main/loras_accelerators ### Setup Instructions -1. Download the Lora -2. Place it in your `loras/` directory if it is a t2v lora or in the `loras_i2v/` directory if it isa i2v lora +There are three ways to set up Loras accelerators: +1) **Finetune with Embedded Loras Accelerators** +Some model Finetunes such as *Vace FusioniX* or *Vace Cocktail* have the Loras Accelerators already set up in their own definition, so you won't have to do anything as they will be downloaded with the Finetune. + +2) **Accelerators Profiles** +Predefined *Accelerator Profiles* can be selected using the *Settings* Dropdown box at the top. The choices of Accelerators will depend on the model. No accelerator will be offered if the finetune / model is already accelerated.
Just click *Apply* and the Accelerators Loras will be set up in the Loras tab at the bottom. Any missing Lora will be downloaded automatically the first time you try to generate a Video. Be aware that when applying an *Accelerator Profile*, inputs such as *Activated Loras*, *Number of Inference Steps*, ... will be overwritten whereas other inputs will be preserved so that you can easily switch between profiles. + +3) **Manual Install** +- Download the Lora +- Place it in the Lora Directory of the corresponding model +- Configure the Loras Multipliers and CFG as described in the later sections ## FusioniX (or FusionX) Lora for Wan 2.1 / Wan 2.2 If you need just one Lora accelerator use this one. It is a combination of multiple Loras acelerators (including Causvid below) and style loras. It will not only accelerate the video generation but it will also improve the quality. There are two versions of this lora whether you use it for t2v or i2v diff --git a/loras/wan/Text2Video Cocktail - 10 Steps.json b/loras/wan/Text2Video Cocktail - 10 Steps.json new file mode 100644 index 000000000..087e49c92 --- /dev/null +++ b/loras/wan/Text2Video Cocktail - 10 Steps.json @@ -0,0 +1,13 @@ +{ + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors" + ], + "loras_multipliers": "1 0.5 0.5 0.5", + "num_inference_steps": 10, + "guidance_phases": 1, + "guidance_scale": 1, + "flow_shift": 2 +} diff --git a/loras/wan/Text2Video FusioniX - 10 Steps.json b/loras/wan/Text2Video FusioniX - 10 Steps.json new file mode 100644 index 000000000..9fabf09eb --- /dev/null +++ b/loras/wan/Text2Video FusioniX - 10 Steps.json @@ -0,0 +1,11 @@ +{ + + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_T2V_14B_FusionX_LoRA.safetensors" + ], + "loras_multipliers": "1", + "num_inference_steps": 10, + "guidance_phases": 1, + "guidance_scale": 1, + "flow_shift": 2 +} diff --git a/loras/wan2_2/Lightning v1.0 2 Phases - 4 Steps.json b/loras/wan2_2/Lightning v1.0 2 Phases - 4 Steps.json new file mode 100644 index 000000000..e7aeedad9 --- /dev/null +++ b/loras/wan2_2/Lightning v1.0 2 Phases - 4 Steps.json @@ -0,0 +1,14 @@ +{ + "num_inference_steps": 4, + "guidance_scale": 1, + "guidance2_scale": 1, + "model_switch_phase": 1, + "switch_threshold": 876, + "guidance_phases": 2, + "flow_shift": 3, + "loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_HIGH_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors" + ], + "loras_multipliers": "1;0 0;1" +} \ No newline at end of file diff --git a/loras/wan2_2/Lightning v1.0 3 Phases - 8 Steps.json b/loras/wan2_2/Lightning v1.0 3 Phases - 8 Steps.json new file mode 100644 index 000000000..681cf3077 --- /dev/null +++ b/loras/wan2_2/Lightning v1.0 3 Phases - 8 Steps.json @@ -0,0 +1,18 @@ +{ + "num_inference_steps": 8, + "guidance_scale": 3.5, + "guidance2_scale": 1, + "guidance3_scale": 1, + "switch_threshold":
965, + "switch_threshold2": 800, + "guidance_phases": 3, + "model_switch_phase": 2, + "flow_shift": 3, + "sample_solver": "euler", + "loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_HIGH_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors" + ], + "loras_multipliers": "0;1;0 0;0;1", + "help": "This finetune uses the Lightning 4 steps Loras Accelerator for Wan 2.2 but extend them to 8 steps in order to insert a CFG phase before the 2 accelerated phases with no Guidance. The ultimate goal is reduce the slow motion effect of these Loras Accelerators." +} \ No newline at end of file diff --git a/loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json b/loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json new file mode 100644 index 000000000..56790c89f --- /dev/null +++ b/loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json @@ -0,0 +1,14 @@ +{ + "num_inference_steps": 4, + "guidance_scale": 1, + "guidance2_scale": 1, + "switch_threshold": 876, + "model_switch_phase": 1, + "guidance_phases": 2, + "flow_shift": 3, + "loras_multipliers": "1;0 0;1", + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" + ] +} \ No newline at end of file diff --git a/loras/wan2_2/Text2Video FusioniX - 10 Steps.json b/loras/wan2_2/Text2Video FusioniX - 10 Steps.json new file mode 100644 index 000000000..3695bec63 --- /dev/null +++ b/loras/wan2_2/Text2Video FusioniX - 10 Steps.json @@ -0,0 +1,11 @@ +{ + + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_T2V_14B_FusionX_LoRA.safetensors" + ], + "loras_multipliers": "1", + "num_inference_steps": 10, + "guidance_scale": 1, + "guidance2_scale": 1, + "flow_shift": 2 +} diff --git a/loras/wan2_2_5B/FastWan 3 Steps.json b/loras/wan2_2_5B/FastWan 3 Steps.json new file mode 100644 index 000000000..1e0383c61 --- /dev/null +++ b/loras/wan2_2_5B/FastWan 3 Steps.json @@ -0,0 +1,7 @@ +{ + "activated_loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"], + "guidance_scale": 1, + "flow_shift": 3, + "num_inference_steps": 3, + "loras_multipliers": "1" +} \ No newline at end of file diff --git a/loras_i2v/wan/Image2Video FusioniX - 10 Steps.json b/loras_i2v/wan/Image2Video FusioniX - 10 Steps.json new file mode 100644 index 000000000..1fe4c14e5 --- /dev/null +++ b/loras_i2v/wan/Image2Video FusioniX - 10 Steps.json @@ -0,0 +1,10 @@ +{ + + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_I2V_14B_FusionX_LoRA.safetensors" + ], + "loras_multipliers": "1", + "num_inference_steps": 10, + "guidance_phases": 1, + "guidance_scale": 1 +} diff --git a/loras_qwen/qwen/Lightning v1.0 - 4 Steps.json b/loras_qwen/qwen/Lightning v1.0 - 4 Steps.json new file mode 100644 index 000000000..3dd322006 --- /dev/null +++ b/loras_qwen/qwen/Lightning v1.0 - 4 Steps.json @@ -0,0 +1,8 @@ +{ + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/loras_accelerators/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors" + ], + 
"loras_multipliers": "1", + "num_inference_steps": 4, + "guidance_scale": 1 +} diff --git a/loras_qwen/qwen/Lightning v1.1 - 8 Steps.json b/loras_qwen/qwen/Lightning v1.1 - 8 Steps.json new file mode 100644 index 000000000..42ae20d50 --- /dev/null +++ b/loras_qwen/qwen/Lightning v1.1 - 8 Steps.json @@ -0,0 +1,8 @@ +{ + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/loras_accelerators/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors" + ], + "loras_multipliers": "1", + "num_inference_steps": 4, + "guidance_scale": 1 +} diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index de1212913..c8f845d07 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -493,7 +493,7 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): if settings_version < 2.32: image_prompt_type = ui_defaults.get("image_prompt_type", "") - if test_class_i2v(base_model_type) and len(image_prompt_type) == 0: + if test_class_i2v(base_model_type) and len(image_prompt_type) == 0 and "S" in model_def.get("image_prompt_types_allowed",""): ui_defaults["image_prompt_type"] = "S" diff --git a/shared/attention.py b/shared/attention.py index c31087c24..da860a9b6 100644 --- a/shared/attention.py +++ b/shared/attention.py @@ -68,6 +68,7 @@ def sageattn2_wrapper( return o try: + # from sageattn3 import sageattn3_blackwell as sageattn3 #word0 windows version from sageattn import sageattn_blackwell as sageattn3 except ImportError: sageattn3 = None diff --git a/wgp.py b/wgp.py index 74223a6c9..9c8309be0 100644 --- a/wgp.py +++ b/wgp.py @@ -63,7 +63,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.0" -WanGP_version = "8.9" +WanGP_version = "8.99" settings_version = 2.38 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -360,6 +360,12 @@ def ret(): if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"): gr.Info("Output Resolution Cropping will be not used for this Generation as it is not compatible with Video Outpainting") + if len(activated_loras) > 0: + error = check_loras_exist(model_type, activated_loras) + if len(error) > 0: + gr.Info(error) + return ret() + if len(loras_multipliers) > 0: _, _, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases= guidance_phases) if len(errors) > 0: @@ -2089,10 +2095,10 @@ def get_transformer_dtype(model_family, transformer_dtype_policy): def get_settings_file_name(model_type): return os.path.join(args.settings, model_type + "_settings.json") -def fix_settings(model_type, ui_defaults): +def fix_settings(model_type, ui_defaults, min_settings_version = 0): if model_type is None: return - settings_version = ui_defaults.get("settings_version", 0) + settings_version = max(min_settings_version, ui_defaults.get("settings_version", 0)) model_def = get_model_def(model_type) base_model_type = get_base_model_type(model_type) @@ -2184,7 +2190,7 @@ def fix_settings(model_type, ui_defaults): model_handler = get_model_handler(base_model_type) if hasattr(model_handler, "fix_settings"): - model_handler.fix_settings(base_model_type, settings_version, model_def, ui_defaults) + model_handler.fix_settings(base_model_type, settings_version, model_def, ui_defaults) def get_default_settings(model_type): @@ -2525,6 +2531,27 @@ def download_mmaudio(): } 
process_files_def(**enhancer_def) + +def download_file(url,filename): + if url.startswith("https://huggingface.co/") and "/resolve/main/" in url: + base_dir = os.path.dirname(filename) + url = url[len("https://huggingface.co/"):] + url_parts = url.split("/resolve/main/") + repoId = url_parts[0] + onefile = os.path.basename(url_parts[-1]) + sourceFolder = os.path.dirname(url_parts[-1]) + if len(sourceFolder) == 0: + hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/" if len(base_dir)==0 else base_dir) + else: + target_path = "ckpts/temp/" + sourceFolder + if not os.path.exists(target_path): + os.makedirs(target_path) + hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/temp/", subfolder=sourceFolder) + shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/" if len(base_dir)==0 else base_dir) + shutil.rmtree("ckpts/temp") + else: + urlretrieve(url,filename, create_progress_hook(filename)) + download_shared_done = False def download_models(model_filename = None, model_type= None, module_type = False, submodel_no = 1): def computeList(filename): @@ -2574,26 +2601,6 @@ def computeList(filename): if model_filename is None: return - def download_file(url,filename): - if url.startswith("https://huggingface.co/") and "/resolve/main/" in url: - base_dir = os.path.dirname(filename) - url = url[len("https://huggingface.co/"):] - url_parts = url.split("/resolve/main/") - repoId = url_parts[0] - onefile = os.path.basename(url_parts[-1]) - sourceFolder = os.path.dirname(url_parts[-1]) - if len(sourceFolder) == 0: - hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/" if len(base_dir)==0 else base_dir) - else: - target_path = "ckpts/temp/" + sourceFolder - if not os.path.exists(target_path): - os.makedirs(target_path) - hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/temp/", subfolder=sourceFolder) - shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/" if len(base_dir)==0 else base_dir) - shutil.rmtree("ckpts/temp") - else: - urlretrieve(url,filename, create_progress_hook(filename)) - base_model_type = get_base_model_type(model_type) model_def = get_model_def(model_type) @@ -2653,6 +2660,33 @@ def download_file(url,filename): def sanitize_file_name(file_name, rep =""): return file_name.replace("/",rep).replace("\\",rep).replace(":",rep).replace("|",rep).replace("?",rep).replace("<",rep).replace(">",rep).replace("\"",rep).replace("\n",rep).replace("\r",rep) +def check_loras_exist(model_type, loras_choices_files, download = False, send_cmd = None): + lora_dir = get_lora_dir(model_type) + missing_local_loras = [] + missing_remote_loras = [] + for lora_file in loras_choices_files: + local_path = os.path.join(lora_dir, os.path.basename(lora_file)) + if not os.path.isfile(local_path): + url = loras_url_cache.get(local_path, None) + if url is not None: + if download: + if send_cmd is not None: + send_cmd("status", f'Downloading Lora {os.path.basename(lora_file)}...') + try: + download_file(url, local_path) + except: + missing_remote_loras.append(lora_file) + else: + missing_local_loras.append(lora_file) + + error = "" + if len(missing_local_loras) > 0: + error += f"The following Loras files are missing or invalid: {missing_local_loras}." + if len(missing_remote_loras) > 0: + error += f"The following Loras files could not be downloaded: {missing_remote_loras}." 
+ + return error + def extract_preset(model_type, lset_name, loras): loras_choices = [] loras_choices_files = [] @@ -2669,24 +2703,12 @@ def extract_preset(model_type, lset_name, loras): if not os.path.isfile(lset_name_filename): error = f"Preset '{lset_name}' not found " else: - missing_loras = [] with open(lset_name_filename, "r", encoding="utf-8") as reader: text = reader.read() lset = json.loads(text) - loras_choices_files = lset["loras"] - for lora_file in loras_choices_files: - choice = os.path.join(lora_dir, lora_file) - if choice not in loras: - missing_loras.append(lora_file) - else: - loras_choice_no = loras.index(choice) - loras_choices.append(str(loras_choice_no)) - - if len(missing_loras) > 0: - error = f"Unable to apply Lora preset '{lset_name} because the following Loras files are missing or invalid: {missing_loras}" - + loras_choices = lset["loras"] loras_mult_choices = lset["loras_mult"] prompt = lset.get("prompt", "") full_prompt = lset.get("full_prompt", False) @@ -2695,7 +2717,6 @@ def extract_preset(model_type, lset_name, loras): def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, split_linear_modules_map = None): loras =[] - loras_names = [] default_loras_choices = [] default_loras_multis_str = "" loras_presets = [] @@ -2726,7 +2747,7 @@ def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, spl loras = offload.load_loras_into_model(transformer, loras, activate_all_loras=False, check_only= True, preprocess_sd=get_loras_preprocessor(transformer, model_type), split_linear_modules_map = split_linear_modules_map) #lora_multiplier, if len(loras) > 0: - loras_names = [ Path(lora).stem for lora in loras ] + loras = [ os.path.basename(lora) for lora in loras ] if len(lora_preselected_preset) > 0: if not os.path.isfile(os.path.join(lora_dir, lora_preselected_preset + ".lset")): @@ -2735,7 +2756,7 @@ def setup_loras(model_type, transformer, lora_dir, lora_preselected_preset, spl default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, _ , error = extract_preset(model_type, default_lora_preset, loras) if len(error) > 0: print(error[:200]) - return loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset + return loras, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset def get_transformer_model(model, submodel_no = 1): if submodel_no > 1: @@ -2996,7 +3017,7 @@ def apply_changes( state, return "
Config Locked
",*[gr.update()]*4 if gen_in_progress: return "
Unable to change config when a generation is in progress
",*[gr.update()]*4 - global offloadobj, wan_model, server_config, loras, loras_names, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets + global offloadobj, wan_model, server_config, loras, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets server_config = { "attention_mode" : attention_choice, "transformer_types": transformer_types_choices, @@ -3464,6 +3485,8 @@ def check(src, cond): video_loras_multipliers = configs.get("loras_multipliers", "") video_loras_multipliers = preparse_loras_multipliers(video_loras_multipliers) video_loras_multipliers += [""] * len(video_activated_loras) + + video_activated_loras = [ f"{os.path.basename(lora)}{lora}" for lora in video_activated_loras] video_activated_loras = [ f"{lora}x{multiplier if len(multiplier)>0 else '1'}" for lora, multiplier in zip(video_activated_loras, video_loras_multipliers) ] video_activated_loras_str = "" + "".join(video_activated_loras) + "
" if len(video_activated_loras) > 0 else "" values += misc_values + [video_prompt] @@ -4681,7 +4704,9 @@ def remove_temp_filenames(temp_filenames_list): loras_list_mult_choices_nums, loras_slists, errors = parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases = guidance_phases, merge_slist= loras_slists ) if len(errors) > 0: raise Exception(f"Error parsing Loras: {errors}") lora_dir = get_lora_dir(model_type) - loras_selected += [ os.path.join(lora_dir, lora) for lora in activated_loras] + errors = check_loras_exist(model_type, activated_loras, True, send_cmd) + if len(errors) > 0 : raise gr.Error(errors) + loras_selected += [ os.path.join(lora_dir, os.path.basename(lora)) for lora in activated_loras] if len(loras_selected) > 0: pinnedLora = loaded_profile !=5 # and transformer_loras_filenames == None False # # # @@ -4699,6 +4724,9 @@ def remove_temp_filenames(temp_filenames_list): original_filename = model_filename model_filename = get_model_filename(base_model_type) + _, _, latent_size = get_model_min_frames_and_step(model_type) + video_length = (video_length -1) // latent_size * latent_size + 1 + current_video_length = video_length # VAE Tiling device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 @@ -4740,13 +4768,12 @@ def remove_temp_filenames(temp_filenames_list): if video_source is not None: current_video_length += sliding_window_overlap - 1 sliding_window = current_video_length > sliding_window_size - reuse_frames = min(sliding_window_size - 4, sliding_window_overlap) + reuse_frames = min(sliding_window_size - latent_size, sliding_window_overlap) else: sliding_window = False sliding_window_size = current_video_length reuse_frames = 0 - _, _, latent_size = get_model_min_frames_and_step(model_type) original_image_refs = image_refs image_refs = None if image_refs is None else ([] + image_refs) # work on a copy as it is going to be modified # image_refs = None @@ -5809,8 +5836,24 @@ def get_new_preset_msg(advanced = True): else: return "Choose a Lora Preset or a Settings file in this List" -def compute_lset_choices(loras_presets): - # lset_choices = [ (preset, preset) for preset in loras_presets] +def compute_lset_choices(model_type, loras_presets): + # top_dir = "settings" + global_list = [] + if model_type is not None: + top_dir = get_lora_dir(model_type) + model_def = get_model_def(model_type) + settings_dir = get_model_recursive_prop(model_type, "settings_dir", return_list=False) + if settings_dir is None or len(settings_dir) == 0: settings_dir = [get_model_family(model_type, True)] + for dir in settings_dir: + if len(dir) == "": continue + cur_path = os.path.join(top_dir, dir) + if os.path.isdir(cur_path): + cur_dir_presets = glob.glob( os.path.join(cur_path, "*.json") ) + cur_dir_presets =[ os.path.join(dir, os.path.basename(path)) for path in cur_dir_presets] + global_list += cur_dir_presets + global_list = sorted(global_list, key=lambda n: os.path.basename(n)) + + lset_list = [] settings_list = [] for item in loras_presets: @@ -5822,6 +5865,9 @@ def compute_lset_choices(loras_presets): sep = '\u2500' indent = chr(160) * 4 lset_choices = [] + if len(global_list) > 0: + lset_choices += [( (sep*16) +"Profiles" + (sep*17), ">Profiles")] + lset_choices += [ ( indent + os.path.splitext(os.path.basename(preset))[0], preset) for preset in global_list ] if len(settings_list) > 0: settings_list.sort() lset_choices += [( (sep*16) +"Settings" + (sep*17), ">settings")] @@ -5836,16 +5882,20 @@ def get_lset_name(state, 
lset_name): presets = state["loras_presets"] if len(lset_name) == 0 or lset_name.startswith(">") or lset_name== get_new_preset_msg(True) or lset_name== get_new_preset_msg(False): return "" if lset_name in presets: return lset_name - choices = compute_lset_choices(presets) + model_type = state["model_type"] + choices = compute_lset_choices(model_type, presets) for label, value in choices: if label == lset_name: return value return lset_name def validate_delete_lset(state, lset_name): lset_name = get_lset_name(state, lset_name) - if len(lset_name) == 0: + if len(lset_name) == 0 : gr.Info(f"Choose a Preset to delete") return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False) + elif "/" in lset_name or "\\" in lset_name: + gr.Info(f"You can't Delete a Profile") + return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False) else: return gr.Button(visible= False), gr.Checkbox(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True) @@ -5854,6 +5904,9 @@ def validate_save_lset(state, lset_name): if len(lset_name) == 0: gr.Info("Please enter a name for the preset") return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False) + elif "/" in lset_name or "\\" in lset_name: + gr.Info(f"You can't Edit a Profile") + return gr.Button(visible= True), gr.Checkbox(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Button(visible= False),gr.Checkbox(visible= False) else: return gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= False), gr.Button(visible= True), gr.Button(visible= True),gr.Checkbox(visible= True) @@ -5882,8 +5935,7 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_ lset = collect_current_model_settings(state) extension = ".json" else: - loras_choices_files = [ Path(loras[int(choice_no)]).parts[-1] for choice_no in loras_choices ] - lset = {"loras" : loras_choices_files, "loras_mult" : loras_mult_choices} + lset = {"loras" : loras_choices, "loras_mult" : loras_mult_choices} if save_lset_prompt_cbox!=1: prompts = prompt.replace("\r", "").split("\n") prompts = [prompt for prompt in prompts if len(prompt)> 0 and prompt.startswith("#")] @@ -5900,7 +5952,8 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_ if not old_lset_name in loras_presets: old_lset_name = "" lset_name = lset_name + extension - lora_dir = get_lora_dir(state["model_type"]) + model_type = state["model_type"] + lora_dir = get_lora_dir(model_type) full_lset_name_filename = os.path.join(lora_dir, lset_name ) with open(full_lset_name_filename, "w", encoding="utf-8") as writer: @@ -5923,20 +5976,21 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_ loras_presets.append(lset_name) state["loras_presets"] = loras_presets - lset_choices = compute_lset_choices(loras_presets) + lset_choices = compute_lset_choices(model_type, loras_presets) lset_choices.append( (get_new_preset_msg(), "")) return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= 
True), gr.Button(visible= False), gr.Button(visible= False), gr.Checkbox(visible= False) def delete_lset(state, lset_name): loras_presets = state["loras_presets"] lset_name = get_lset_name(state, lset_name) + model_type = state["model_type"] if len(lset_name) > 0: - lset_name_filename = os.path.join( get_lora_dir(state["model_type"]), sanitize_file_name(lset_name)) + lset_name_filename = os.path.join( get_lora_dir(model_type), sanitize_file_name(lset_name)) if not os.path.isfile(lset_name_filename): gr.Info(f"Preset '{lset_name}' not found ") return [gr.update()]*7 os.remove(lset_name_filename) - lset_choices = compute_lset_choices(loras_presets) + lset_choices = compute_lset_choices(None, loras_presets) pos = next( (i for i, item in enumerate(lset_choices) if item[1]==lset_name ), -1) gr.Info(f"Lora Preset '{lset_name}' has been deleted") loras_presets.remove(lset_name) @@ -5946,30 +6000,33 @@ def delete_lset(state, lset_name): state["loras_presets"] = loras_presets - lset_choices = compute_lset_choices(loras_presets) + lset_choices = compute_lset_choices(model_type, loras_presets) lset_choices.append((get_new_preset_msg(), "")) selected_lset_name = "" if pos < 0 else lset_choices[min(pos, len(lset_choices)-1)][1] return gr.Dropdown(choices=lset_choices, value= selected_lset_name), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= True), gr.Button(visible= False), gr.Checkbox(visible= False) +def get_updated_loras_dropdown(loras, loras_choices): + loras_choices = [os.path.basename(choice) for choice in loras_choices] + loras_choices_dict = { choice : True for choice in loras_choices} + for lora in loras: + loras_choices_dict.pop(lora, False) + new_loras = loras[:] + for choice, _ in loras_choices_dict.items(): + new_loras.append(choice) + + new_loras_dropdown= [ ( os.path.splitext(choice)[0], choice) for choice in new_loras ] + return new_loras, new_loras_dropdown + def refresh_lora_list(state, lset_name, loras_choices): - loras_names = state["loras_names"] - prev_lora_names_selected = [ loras_names[int(i)] for i in loras_choices] model_type= state["model_type"] - loras, loras_names, loras_presets, _, _, _, _ = setup_loras(model_type, None, get_lora_dir(model_type), lora_preselected_preset, None) - state["loras"] = loras - state["loras_names"] = loras_names + loras, loras_presets, _, _, _, _ = setup_loras(model_type, None, get_lora_dir(model_type), lora_preselected_preset, None) state["loras_presets"] = loras_presets - gc.collect() - new_loras_choices = [ (loras_name, str(i)) for i,loras_name in enumerate(loras_names)] - new_loras_dict = { loras_name: str(i) for i,loras_name in enumerate(loras_names) } - lora_names_selected = [] - for lora in prev_lora_names_selected: - lora_id = new_loras_dict.get(lora, None) - if lora_id!= None: - lora_names_selected.append(lora_id) - - lset_choices = compute_lset_choices(loras_presets) + + loras, new_loras_dropdown = get_updated_loras_dropdown(loras, loras_choices) + state["loras"] = loras + model_type = state["model_type"] + lset_choices = compute_lset_choices(model_type, loras_presets) lset_choices.append((get_new_preset_msg( state["advanced"]), "")) if not lset_name in loras_presets: lset_name = "" @@ -5983,7 +6040,7 @@ def refresh_lora_list(state, lset_name, loras_choices): gr.Info("Lora List has been refreshed") - return gr.Dropdown(choices=lset_choices, value= lset_name), gr.Dropdown(choices=new_loras_choices, value= lora_names_selected) + return gr.Dropdown(choices=lset_choices, value= lset_name), 
gr.Dropdown(choices=new_loras_dropdown, value= loras_choices) def update_lset_type(state, lset_name): return 1 if lset_name.endswith(".lset") else 2 @@ -6002,30 +6059,31 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m if lset_name.endswith(".lset"): loras = state["loras"] loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(current_model_type, lset_name, loras) - if len(error) > 0: - gr.Info(error) - else: - if full_prompt: - prompt = preset_prompt - elif len(preset_prompt) > 0: - prompts = prompt.replace("\r", "").split("\n") - prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")] - prompt = "\n".join(prompts) - prompt = preset_prompt + '\n' + prompt - gr.Info(f"Lora Preset '{lset_name}' has been applied") - state["apply_success"] = 1 - wizard_prompt_activated = "on" + loras_choices = update_loras_url_cache(get_lora_dir(current_model_type), loras_choices) + if full_prompt: + prompt = preset_prompt + elif len(preset_prompt) > 0: + prompts = prompt.replace("\r", "").split("\n") + prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")] + prompt = "\n".join(prompts) + prompt = preset_prompt + '\n' + prompt + gr.Info(f"Lora Preset '{lset_name}' has been applied") + state["apply_success"] = 1 + wizard_prompt_activated = "on" return wizard_prompt_activated, loras_choices, loras_mult_choices, prompt, get_unique_id(), gr.update(), gr.update(), gr.update() else: - configs, _ = get_settings_from_file(state, os.path.join(get_lora_dir(current_model_type), lset_name), True, True, True) + # lset_path = os.path.join("settings", lset_name) if len(Path(lset_name).parts)>1 else os.path.join(get_lora_dir(current_model_type), lset_name) + lset_path = os.path.join(get_lora_dir(current_model_type), lset_name) + configs, _ = get_settings_from_file(state,lset_path , True, True, True, min_settings_version=2.38) if configs == None: gr.Info("File not supported") - return [gr.update()] * 7 - + return [gr.update()] * 8 model_type = configs["model_type"] configs["lset_name"] = lset_name gr.Info(f"Settings File '{lset_name}' has been applied") + help = configs.get("help", None) + if help is not None: gr.Info(help) if model_type == current_model_type: set_model_settings(state, current_model_type, configs) @@ -6162,7 +6220,8 @@ def switch_prompt_type(state, wizard_prompt_activated_var, wizard_variables_name def switch_advanced(state, new_advanced, lset_name): state["advanced"] = new_advanced loras_presets = state["loras_presets"] - lset_choices = compute_lset_choices(loras_presets) + model_type = state["model_type"] + lset_choices = compute_lset_choices(model_type, loras_presets) lset_choices.append((get_new_preset_msg(new_advanced), "")) server_config["last_advanced_choice"] = new_advanced with open(server_config_filename, "w", encoding="utf-8") as writer: @@ -6184,9 +6243,12 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if "loras_choices" in inputs: loras_choices = inputs.pop("loras_choices") inputs.pop("model_filename", None) - activated_loras = [Path( loras[int(no)]).parts[-1] for no in loras_choices ] - inputs["activated_loras"] = activated_loras - + else: + loras_choices = inputs["activated_loras"] + if model_type == None: model_type = state["model_type"] + + inputs["activated_loras"] = update_loras_url_cache(get_lora_dir(model_type), loras_choices) + if target == "state": return inputs @@ -6196,7 +6258,6 @@ def prepare_inputs_dict(target, 
inputs, model_type = None, model_filename = None unsaved_params = ["image_start", "image_end", "image_refs", "video_guide", "image_guide", "video_source", "video_mask", "image_mask", "audio_guide", "audio_guide2", "audio_source"] for k in unsaved_params: inputs.pop(k) - if model_type == None: model_type = state["model_type"] inputs["type"] = get_model_record(get_model_name(model_type)) inputs["settings_version"] = settings_version model_def = get_model_def(model_type) @@ -6543,8 +6604,38 @@ def use_video_settings(state, input_file_list, choice): gr.Info(f"No Video is Selected") return gr.update(), gr.update() - -def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, switch_type_if_compatible): +loras_url_cache = None +def update_loras_url_cache(lora_dir, loras_selected): + if loras_selected is None: return None + global loras_url_cache + loras_cache_file = "loras_url_cache.json" + if loras_url_cache is None: + if os.path.isfile(loras_cache_file): + try: + with open(loras_cache_file, 'r', encoding='utf-8') as f: + loras_url_cache = json.load(f) + except: + loras_url_cache = {} + else: + loras_url_cache = {} + new_loras_selected = [] + update = False + for lora in loras_selected: + base_name = os.path.basename(lora) + local_name = os.path.join(lora_dir, base_name) + url = loras_url_cache.get(local_name, base_name) + if lora.startswith("http:") or lora.startswith("https:") and url != lora: + loras_url_cache[local_name]=lora + update = True + new_loras_selected.append(url) + + if update: + with open(loras_cache_file, "w", encoding="utf-8") as writer: + writer.write(json.dumps(loras_url_cache, indent=4)) + + return new_loras_selected + +def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, switch_type_if_compatible, min_settings_version = 0): configs = None any_image_or_video = False if file_path.endswith(".json") and allow_json: @@ -6570,7 +6661,10 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw pass if configs is None: return None, False try: - if not "WanGP" in configs.get("type", ""): configs = None + if isinstance(configs, dict): + if (not merge_with_defaults) and not "WanGP" in configs.get("type", ""): configs = None + else: + configs = None except: configs = None if configs is None: return None, False @@ -6590,7 +6684,6 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw model_type = current_model_type elif not model_type in model_types: model_type = current_model_type - fix_settings(model_type, configs) if switch_type_if_compatible and are_model_types_compatible(model_type,current_model_type): model_type = current_model_type if merge_with_defaults: @@ -6598,6 +6691,12 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw defaults = get_default_settings(model_type) if defaults == None else defaults defaults.update(configs) configs = defaults + + loras_selected =configs.get("activated_loras", None) + if loras_selected is not None: + configs["activated_loras"]= update_loras_url_cache(get_lora_dir(model_type), loras_selected) + + fix_settings(model_type, configs, min_settings_version) configs["model_type"] = model_type return configs, any_image_or_video @@ -6653,6 +6752,13 @@ def load_settings_from_file(state, file_path): set_model_settings(state, model_type, configs) return *generate_dropdown_model_list(model_type), gr.update(), None +def reset_settings(state): + model_type = state["model_type"] + ui_defaults = get_default_settings(model_type) + 
set_model_settings(state, model_type, ui_defaults) + gr.Info(f"Default Settings have been Restored") + return str(time.time()) + def save_inputs( target, image_mask_guide, @@ -7348,6 +7454,25 @@ def get_default_value(choices, current_value, default_value = None): return current_value return default_value +def download_lora(state, lora_url, progress=gr.Progress(track_tqdm=True),): + if lora_url is None or not lora_url.startswith("http"): + gr.Info("Please provide a URL for a Lora to Download") + return gr.update() + model_type = state["model_type"] + lora_short_name = os.path.basename(lora_url) + lora_dir = get_lora_dir(model_type) + local_path = os.path.join(lora_dir, lora_short_name) + try: + download_file(lora_url, local_path) + except Exception as e: + gr.Info(f"Error downloading Lora {lora_short_name}: {e}") + return gr.update() + update_loras_url_cache(lora_dir, [lora_url]) + gr.Info(f"Lora {lora_short_name} has been succesfully downloaded") + return "" + + + def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None): global inputs_names #, advanced @@ -7375,11 +7500,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non model_filename = get_model_filename( base_model_type ) preset_to_load = lora_preselected_preset if lora_preset_model == model_type else "" - loras, loras_names, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(model_type, None, get_lora_dir(model_type), preset_to_load, None) + loras, loras_presets, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset = setup_loras(model_type, None, get_lora_dir(model_type), preset_to_load, None) state_dict["loras"] = loras state_dict["loras_presets"] = loras_presets - state_dict["loras_names"] = loras_names launch_prompt = "" launch_preset = "" @@ -7400,17 +7524,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non launch_prompt = ui_defaults.get("prompt","") if len(launch_loras) == 0: launch_multis_str = ui_defaults.get("loras_multipliers","") - activated_loras = ui_defaults.get("activated_loras",[]) - if len(activated_loras) > 0: - lora_filenames = [os.path.basename(lora_path) for lora_path in loras] - activated_indices = [] - for lora_file in ui_defaults["activated_loras"]: - try: - idx = lora_filenames.index(lora_file) - activated_indices.append(str(idx)) - except ValueError: - print(f"Warning: Lora file {lora_file} from config not found in loras directory") - launch_loras = activated_indices + launch_loras = [os.path.basename(path) for path in ui_defaults.get("activated_loras",[])] with gr.Row(): with gr.Column(): @@ -7421,7 +7535,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non modal_image_display = gr.HTML(label="Full Resolution Image") preview_column_no = gr.Text(visible=False, value=-1, elem_id="preview_column_no") with gr.Row(visible= True): #len(loras)>0) as presets_column: - lset_choices = compute_lset_choices(loras_presets) + [(get_new_preset_msg(advanced_ui), "")] + lset_choices = compute_lset_choices(model_type, loras_presets) + [(get_new_preset_msg(advanced_ui), "")] with gr.Column(scale=6): lset_name = gr.Dropdown(show_label=False, allow_custom_value= True, scale=5, filterable=True, choices= lset_choices, value=launch_preset) with gr.Column(scale=1): @@ -7449,8 +7563,6 @@ def 
generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non delete_lset_btn = gr.Button("Delete", size="sm", min_width= 1, visible = True) cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False) #confirm_save_lset_btn, confirm_delete_lset_btn, save_lset_btn, delete_lset_btn, cancel_lset_btn - if not update_form: - state = gr.State(state_dict) trigger_refresh_input_type = gr.Text(interactive= False, visible= False) t2v = base_model_type in ["t2v"] t2v_1_3B = base_model_type in ["t2v_1.3B"] @@ -7967,13 +8079,14 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl with gr.Tab("Loras"): with gr.Column(visible = True): #as loras_column: gr.Markdown("Loras can be used to create special effects on the video by mentioning a trigger word in the Prompt. You can save Loras combinations in presets.") + loras, loras_choices = get_updated_loras_dropdown(loras, launch_loras) + state_dict["loras"] = loras loras_choices = gr.Dropdown( - choices=[ - (lora_name, str(i) ) for i, lora_name in enumerate(loras_names) - ], + choices=loras_choices, value= launch_loras, multiselect= True, - label="Activated Loras" + label="Activated Loras", + allow_custom_value= True, ) loras_multipliers = gr.Textbox(label="Loras Multipliers (1.0 by default) separated by Space chars or CR, lines that start with # are ignored", value=launch_multis_str) with gr.Tab("Steps Skipping", visible = any_tea_cache or any_mag_cache) as speed_tab: @@ -8259,15 +8372,21 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai with gr.Row(): save_settings_btn = gr.Button("Set Settings as Default", visible = not args.lock_config) export_settings_from_file_btn = gr.Button("Export Settings to File") + reset_settings_btn = gr.Button("Reset Settings") with gr.Row(): settings_file = gr.File(height=41,label="Load Settings From Video / Image / JSON") settings_base64_output = gr.Text(interactive= False, visible=False, value = "") settings_filename = gr.Text(interactive= False, visible=False, value = "") - + with gr.Group(): + with gr.Row(): + lora_url = gr.Text(label ="Lora URL", placeholder= "Enter Lora URL", scale=4, show_label=False, elem_classes="compact_text" ) + download_lora_btn = gr.Button("Download Lora", scale=1, min_width=10) + mode = gr.Text(value="", visible = False) with gr.Column(): if not update_form: + state = gr.State(state_dict) gen_status = gr.Text(interactive= False, label = "Status") status_trigger = gr.Text(interactive= False, visible=False) default_files = [] @@ -8423,6 +8542,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab], show_progress="hidden") preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview], show_progress="hidden") PP_MMAudio_setting.change(fn = lambda value : [gr.update(visible = value == 1), gr.update(visible = value == 0)] , inputs = [PP_MMAudio_setting], outputs = [PP_MMAudio_row, PP_custom_audio_row] ) + download_lora_btn.click(fn=download_lora, inputs = [state, lora_url], outputs = [lora_url]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) def refresh_status_async(state, progress=gr.Progress()): gen = get_gen_info(state) gen["progress"] = progress @@ -8563,6 +8683,17 @@ def 
activate_status(state): ).then(fn=load_settings_from_file, inputs =[state, settings_file] , outputs= [model_family, model_choice, refresh_form_trigger, settings_file]) + reset_settings_btn.click(fn=validate_wizard_prompt, + inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , + outputs= [prompt], + show_progress="hidden", + ).then(fn=save_inputs, + inputs =[target_state] + gen_inputs, + outputs= None + ).then(fn=reset_settings, inputs =[state] , outputs= [refresh_form_trigger]) + + + fill_wizard_prompt_trigger.change( fn = fill_wizard_prompt, inputs = [state, wizard_prompt_activated_var, prompt, wizard_prompt], outputs = [ wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars] ) @@ -9678,6 +9809,26 @@ def create_ui(): } .btn_centered {margin-top:10px; text-wrap-mode: nowrap;} .cbx_centered label {margin-top:8px; text-wrap-mode: nowrap;} + + .copy-swap { display:block; max-width:100%; + } + .copy-swap__trunc { + display: block; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + } + .copy-swap__full { display: none; } + .copy-swap:hover .copy-swap__trunc, + .copy-swap:focus-within .copy-swap__trunc { display: none; } + .copy-swap:hover .copy-swap__full, + .copy-swap:focus-within .copy-swap__full { + display: inline; + white-space: normal; + user-select: text; + } + .compact_text {padding: 1px} + """ UI_theme = server_config.get("UI_theme", "default") UI_theme = args.theme if len(args.theme) > 0 else UI_theme From f58dfd70ba9c1c7861604e68fb7f8a02b6ea6f33 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 2 Oct 2025 21:33:21 +0200 Subject: [PATCH 062/155] fixed typos in profiles --- loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json | 4 ++-- loras_qwen/qwen/Lightning v1.1 - 8 Steps.json | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json b/loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json index 56790c89f..31e4f0d66 100644 --- a/loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json +++ b/loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json @@ -8,7 +8,7 @@ "flow_shift": 3, "loras_multipliers": "1;0 0;1", "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" ] } \ No newline at end of file diff --git a/loras_qwen/qwen/Lightning v1.1 - 8 Steps.json b/loras_qwen/qwen/Lightning v1.1 - 8 Steps.json index 42ae20d50..98909f6d8 100644 --- a/loras_qwen/qwen/Lightning v1.1 - 8 Steps.json +++ b/loras_qwen/qwen/Lightning v1.1 - 8 Steps.json @@ -3,6 +3,6 @@ "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/loras_accelerators/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors" ], "loras_multipliers": "1", - "num_inference_steps": 4, + "num_inference_steps": 8, "guidance_scale": 1 } From c42c6fe9d1fc57312812e4c34162113d572128eb Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 2 Oct 2025 
22:48:14 +0200 Subject: [PATCH 063/155] fix video ui trim --- wgp.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/wgp.py b/wgp.py index 9c8309be0..5bf05582c 100644 --- a/wgp.py +++ b/wgp.py @@ -7654,7 +7654,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl else: image_prompt_type_endcheckbox = gr.Checkbox( value =False, show_label= False, visible= False , scale= 1) image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new Videos in the Generation Queue", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) - video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) + video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None), elem_id="video_input") image_end_row, image_end, image_end_extra = get_image_gallery(label= get_image_end_label(ui_defaults.get("multi_prompts_gen_type", 0)), value = ui_defaults.get("image_end", None), visible= any_letters(image_prompt_type_value, "SVL") and ("E" in image_prompt_type_value) ) if model_mode_choices is None or image_mode_value not in model_modes_visibility: model_mode = gr.Dropdown(value=None, visible=False) @@ -7814,7 +7814,7 @@ def get_image_gallery(label ="", value = None, single_image_mode = False, visibl ) image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None)) - video_guide = gr.Video(label= "Control Video", height = gallery_height, visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None)) + video_guide = gr.Video(label= "Control Video", height = gallery_height, visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None), elem_id="video_input") if image_mode_value >= 1: image_guide_value = ui_defaults.get("image_guide", None) image_mask_value = ui_defaults.get("image_mask", None) @@ -9828,7 +9828,13 @@ def create_ui(): user-select: text; } .compact_text {padding: 1px} - + #video_input .video-container .container { + flex-direction: row !important; + display: flex !important; + } + #video_input .video-container .container .controls { + overflow: visible !important; + } """ UI_theme = server_config.get("UI_theme", "default") UI_theme = args.theme if len(args.theme) > 0 else UI_theme From 3da6af4859e05c5ecbc660cab616b12231303499 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 3 Oct 2025 00:42:07 +0200 Subject: [PATCH 064/155] add t2v Lighx2v profile --- loras/wan/Text2Video FusioniX - 10 Steps.json | 4 ++-- ...eo Lightx2v Cfg Step Distill Rank32 - 4 Steps.json | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) create mode 100644 loras/wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json diff --git a/loras/wan/Text2Video FusioniX - 10 Steps.json b/loras/wan/Text2Video FusioniX - 10 Steps.json index 9fabf09eb..cba907b00 100644 --- a/loras/wan/Text2Video FusioniX - 10 Steps.json +++ b/loras/wan/Text2Video FusioniX - 10 Steps.json @@ -1,10 +1,10 @@ { "activated_loras": [ - 
"https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_T2V_14B_FusionX_LoRA.safetensors" + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors" ], "loras_multipliers": "1", - "num_inference_steps": 10, + "num_inference_steps": 4, "guidance_phases": 1, "guidance_scale": 1, "flow_shift": 2 diff --git a/loras/wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json b/loras/wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json new file mode 100644 index 000000000..8ed2a4db0 --- /dev/null +++ b/loras/wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json @@ -0,0 +1,11 @@ +{ + + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors" + ], + "loras_multipliers": "1", + "num_inference_steps": 4, + "guidance_phases": 1, + "guidance_scale": 1, + "flow_shift": 2 +} \ No newline at end of file From 0793fece6b8fc2057ba1ee87dfa0241e96ff401a Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 3 Oct 2025 01:22:59 +0200 Subject: [PATCH 065/155] fixed Animate bug with boot disabled --- models/wan/any2video.py | 3 ++- models/wan/modules/model.py | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 5cf1e712a..6c4550944 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -929,7 +929,8 @@ def clear(): gen_args = { "x" : [latent_model_input, latent_model_input], "context" : [context, context_null], - "face_pixel_values": [face_pixel_values, None] + # "face_pixel_values": [face_pixel_values, None] + "face_pixel_values": [face_pixel_values, face_pixel_values] # seems to look better this way } elif lynx: gen_args = { diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index 30f6172f0..6c223c833 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -1354,8 +1354,7 @@ def forward( # animate embeddings motion_vec = None if pose_latents is not None: - # x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(face_pixel_values[0]) if one_face_pixel_values is None else one_face_pixel_values) - x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values[0]) # if one_face_pixel_values is None else one_face_pixel_values) + x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(face_pixel_values[0]) if one_face_pixel_values is None else one_face_pixel_values) motion_vec_list.append(motion_vec) if chipmunk: x = x.unsqueeze(-1) From 45017cfa94f2081c3d49e2c0c5154e29e97dc0ce Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 3 Oct 2025 12:31:05 +0200 Subject: [PATCH 066/155] fixed accelerator profiles mixed up --- loras/wan/Text2Video FusioniX - 10 Steps.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/loras/wan/Text2Video FusioniX - 10 Steps.json b/loras/wan/Text2Video FusioniX - 10 Steps.json index cba907b00..9fabf09eb 100644 --- a/loras/wan/Text2Video FusioniX - 10 Steps.json +++ b/loras/wan/Text2Video FusioniX - 10 Steps.json @@ -1,10 +1,10 @@ { "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors" + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_T2V_14B_FusionX_LoRA.safetensors" ], "loras_multipliers": "1", - 
"num_inference_steps": 4, + "num_inference_steps": 10, "guidance_phases": 1, "guidance_scale": 1, "flow_shift": 2 From bb70ca7d1494523e0e3f0a68edf5989288c0a6bd Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 3 Oct 2025 19:21:15 +0200 Subject: [PATCH 067/155] Vace Lynx Unleashed --- README.md | 3 +- defaults/vace_lynx_14B.json | 2 +- defaults/vace_standin_14B.json | 2 +- models/wan/any2video.py | 110 +++++++++++++++++---------------- models/wan/modules/model.py | 39 +++++++++--- models/wan/wan_handler.py | 16 +++-- shared/utils/utils.py | 9 ++- wgp.py | 3 +- 8 files changed, 109 insertions(+), 75 deletions(-) diff --git a/README.md b/README.md index 637b1640d..db25ed6d0 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : -### October 2 2025: WanGP v8.99 - One Last Thing before the Big Unknown ... +### October 2 2025: WanGP v8.991 - One Last Thing before the Big Unknown ... This new version hasn't any new model... @@ -35,6 +35,7 @@ I have created the new Discord Channel **share-your-settings** where you can pos Last but not least the Lora's documentation has been updated. +*update 8.991*: full power of *Vace Lynx* unleashed with new combinations such as Landscape + Face / Clothes + Face / Injectd Frame (Start/End frames/...) + Face ### September 30 2025: WanGP v8.9 - Combinatorics This new version of WanGP introduces **Wan 2.1 Lynx** the best Control Net so far to transfer *Facial Identity*. You will be amazed to recognize your friends even with a completely different hair style. Congrats to the *Byte Dance team* for this achievement. Lynx works quite with well *Fusionix t2v* 10 steps. diff --git a/defaults/vace_lynx_14B.json b/defaults/vace_lynx_14B.json index 523d5b3fc..3584aa99f 100644 --- a/defaults/vace_lynx_14B.json +++ b/defaults/vace_lynx_14B.json @@ -3,7 +3,7 @@ "name": "Vace Lynx 14B", "architecture": "vace_lynx_14B", "modules": [ "vace_14B", "lynx"], - "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video. The Lynx version is specialized in transferring a Person Identity.", + "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video. 
The Lynx version is specialized in identity transfer, so the last Image Ref should always contain a close up of the Face of a Person to transfer.", "URLs": "t2v", "preload_URLs": "lynx" } diff --git a/defaults/vace_standin_14B.json b/defaults/vace_standin_14B.json index 6fc8a6868..b6f6af099 100644 --- a/defaults/vace_standin_14B.json +++ b/defaults/vace_standin_14B.json @@ -3,7 +3,7 @@ "name": "Vace Standin 14B", "architecture": "vace_standin_14B", "modules": [ "vace_14B", "standin"], - "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video.", + "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video. The Standin version is specialized in identity transfer, so the last Image Ref should always contain a close up of the Face of a Person to transfer.", "URLs": "t2v" } } \ No newline at end of file diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 6c4550944..243dddb4b 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -686,13 +686,68 @@ def generate(self, if extended_input_dim > 0: extended_latents[:, :, :source_latents.shape[2]] = source_latents + # Lynx + if lynx : + if original_input_ref_images is None or len(original_input_ref_images) == 0: + lynx = False + elif "K" in video_prompt_type and len(input_ref_images) <= 1: + print("Warning: Missing Lynx Ref Image, make sure 'Inject only People / Objets' is selected or if there is 'Landscape and then People or Objects' there are at least two ref images (one Landscape image followed by face).") + lynx = False + else: + from .lynx.resampler import Resampler + from accelerate import init_empty_weights + lynx_lite = model_type in ["lynx_lite", "vace_lynx_lite_14B"] + ip_hidden_states = ip_hidden_states_uncond = None + if True: + with init_empty_weights(): + arc_resampler = Resampler( depth=4, dim=1280, dim_head=64, embedding_dim=512, ff_mult=4, heads=20, num_queries=16, output_dim=2048 if lynx_lite else 5120 ) + offload.load_model_data(arc_resampler, os.path.join("ckpts", "wan2.1_lynx_lite_arc_resampler.safetensors" if lynx_lite else "wan2.1_lynx_full_arc_resampler.safetensors")) + arc_resampler.to(self.device) + arcface_embed = face_arc_embeds[None,None,:].to(device=self.device, dtype=torch.float) + ip_hidden_states = arc_resampler(arcface_embed).to(self.dtype) + ip_hidden_states_uncond = arc_resampler(torch.zeros_like(arcface_embed)).to(self.dtype) + arc_resampler = None + if not lynx_lite: + standin_ref_pos = -1 + image_ref = original_input_ref_images[standin_ref_pos] + from preprocessing.face_preprocessor import FaceProcessor + face_processor = FaceProcessor() + lynx_ref = face_processor.process(image_ref, resize_to = 256 ) + lynx_ref_buffer, lynx_ref_buffer_uncond = self.encode_reference_images([lynx_ref], tile_size=VAE_tile_size, any_guidance= any_guidance_at_all) + lynx_ref = None + gc.collect() + torch.cuda.empty_cache() + vace_lynx = model_type in ["vace_lynx_14B"] + kwargs["lynx_ip_scale"] = control_scale_alt + kwargs["lynx_ref_scale"] = control_scale_alt + + #Standin + if standin: + from preprocessing.face_preprocessor import FaceProcessor + standin_ref_pos = 1 if "K" in video_prompt_type else 0 + if len(original_input_ref_images) < standin_ref_pos + 1: + if "I" in 
video_prompt_type and model_type in ["vace_standin_14B"]: + print("Warning: Missing Standin ref image, make sure 'Inject only People / Objets' is selected or if there is 'Landscape and then People or Objects' there are at least two ref images.") + else: + standin_ref_pos = -1 + image_ref = original_input_ref_images[standin_ref_pos] + face_processor = FaceProcessor() + standin_ref = face_processor.process(image_ref, remove_bg = model_type in ["vace_standin_14B"]) + face_processor = None + gc.collect() + torch.cuda.empty_cache() + standin_freqs = get_nd_rotary_pos_embed((-1, int(target_shape[-2]/2), int(target_shape[-1]/2) ), (-1, int(target_shape[-2]/2 + standin_ref.height/16), int(target_shape[-1]/2 + standin_ref.width/16) )) + standin_ref = self.vae.encode([ convert_image_to_tensor(standin_ref).unsqueeze(1) ], VAE_tile_size)[0].unsqueeze(0) + kwargs.update({ "standin_freqs": standin_freqs, "standin_ref": standin_ref, }) + + # Vace if vace : # vace context encode input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)]) input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)]) - if model_type in ["vace_lynx_14B"]: - input_ref_images = input_ref_masks = None + if model_type in ["vace_lynx_14B"] and input_ref_images is not None: + input_ref_images,input_ref_masks = input_ref_images[:-1], input_ref_masks[:-1] input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images] input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks] ref_images_before = True @@ -744,57 +799,6 @@ def generate(self, kwargs["freqs"] = freqs - # Lynx - if lynx : - if original_input_ref_images is None or len(original_input_ref_images) == 0: - lynx = False - else: - from .lynx.resampler import Resampler - from accelerate import init_empty_weights - lynx_lite = model_type in ["lynx_lite", "vace_lynx_lite_14B"] - ip_hidden_states = ip_hidden_states_uncond = None - if True: - with init_empty_weights(): - arc_resampler = Resampler( depth=4, dim=1280, dim_head=64, embedding_dim=512, ff_mult=4, heads=20, num_queries=16, output_dim=2048 if lynx_lite else 5120 ) - offload.load_model_data(arc_resampler, os.path.join("ckpts", "wan2.1_lynx_lite_arc_resampler.safetensors" if lynx_lite else "wan2.1_lynx_full_arc_resampler.safetensors")) - arc_resampler.to(self.device) - arcface_embed = face_arc_embeds[None,None,:].to(device=self.device, dtype=torch.float) - ip_hidden_states = arc_resampler(arcface_embed).to(self.dtype) - ip_hidden_states_uncond = arc_resampler(torch.zeros_like(arcface_embed)).to(self.dtype) - arc_resampler = None - if not lynx_lite: - standin_ref_pos = -1 - image_ref = original_input_ref_images[standin_ref_pos] - from preprocessing.face_preprocessor import FaceProcessor - face_processor = FaceProcessor() - lynx_ref = face_processor.process(image_ref, resize_to = 256 ) - lynx_ref_buffer, lynx_ref_buffer_uncond = self.encode_reference_images([lynx_ref], tile_size=VAE_tile_size, any_guidance= any_guidance_at_all) - lynx_ref = None - gc.collect() - torch.cuda.empty_cache() - vace_lynx = model_type in ["vace_lynx_14B"] - kwargs["lynx_ip_scale"] = control_scale_alt - kwargs["lynx_ref_scale"] = control_scale_alt - - #Standin - if standin: - from preprocessing.face_preprocessor import FaceProcessor - standin_ref_pos = 1 if "K" in video_prompt_type else 0 - if len(original_input_ref_images) < standin_ref_pos + 1: 
- if "I" in video_prompt_type and model_type in ["vace_standin_14B"]: - print("Warning: Missing Standin ref image, make sure 'Inject only People / Objets' is selected or if there is 'Landscape and then People or Objects' there are at least two ref images.") - else: - standin_ref_pos = -1 - image_ref = original_input_ref_images[standin_ref_pos] - face_processor = FaceProcessor() - standin_ref = face_processor.process(image_ref, remove_bg = model_type in ["vace_standin_14B"]) - face_processor = None - gc.collect() - torch.cuda.empty_cache() - standin_freqs = get_nd_rotary_pos_embed((-1, int(target_shape[-2]/2), int(target_shape[-1]/2) ), (-1, int(target_shape[-2]/2 + standin_ref.height/16), int(target_shape[-1]/2 + standin_ref.width/16) )) - standin_ref = self.vae.encode([ convert_image_to_tensor(standin_ref).unsqueeze(1) ], VAE_tile_size)[0].unsqueeze(0) - kwargs.update({ "standin_freqs": standin_freqs, "standin_ref": standin_ref, }) - # Steps Skipping skip_steps_cache = self.model.cache diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index 6c223c833..2185b9cf4 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -97,6 +97,26 @@ def relative_l1_distance(last_tensor, current_tensor): relative_l1_distance = l1_distance / norm return relative_l1_distance.to(torch.float32) +def trim_image_ref(y, ref_images_count, grid_sizes): + y_shape = y.shape + y = y.reshape(y_shape[0], grid_sizes[0], -1) + y = y[:, ref_images_count:] + y = y.reshape(y_shape[0], -1, y_shape[-1]) + grid_sizes_alt = [grid_sizes[0]-ref_images_count, *grid_sizes[1:]] + return y, grid_sizes_alt + +def fuse_with_image_ref(x, y, ref_images_count, grid_sizes, alpha = 1): + y_shape = x.shape + y = y.reshape(y_shape[0], grid_sizes[0]-ref_images_count, -1) + x = x.reshape(y_shape[0], grid_sizes[0], -1) + if alpha == 1: + x[:, ref_images_count:] += y + else: + x[:, ref_images_count:].add_(y, alpha= alpha) + + x = x.reshape(*y_shape) + return x + class LoRALinearLayer(nn.Module): def __init__( self, @@ -295,6 +315,8 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks if not lynx_ref_buffer is None: lynx_ref_features = lynx_ref_buffer[self.block_no] if self.norm_q is not None: ref_query = self.norm_q(q, in_place = False) + if ref_images_count > 0: + ref_query, _ = trim_image_ref(ref_query, ref_images_count, grid_sizes) ref_key = self.to_k_ref(lynx_ref_features) ref_value = self.to_v_ref(lynx_ref_features) if self.norm_k is not None: ref_key = self.norm_k(ref_key) @@ -365,7 +387,11 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks del q,k,v if ref_hidden_states is not None: - x.add_(ref_hidden_states, alpha= lynx_ref_scale) + + if ref_images_count > 0: + x = fuse_with_image_ref(x, ref_hidden_states, ref_images_count, grid_sizes, alpha = lynx_ref_scale) + else: + x.add_(ref_hidden_states, alpha= lynx_ref_scale) x = x.flatten(2) x = self.o(x) @@ -631,18 +657,11 @@ def forward( del y x += self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map) else: - y_shape = y.shape - y = y.reshape(y_shape[0], grid_sizes[0], -1) - y = y[:, ref_images_count:] - y = y.reshape(y_shape[0], -1, y_shape[-1]) - grid_sizes_alt = [grid_sizes[0]-ref_images_count, *grid_sizes[1:]] + y, grid_sizes_alt = trim_image_ref(y, ref_images_count, grid_sizes) ylist= [y] y = None y = self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map) - y = 
y.reshape(y_shape[0], grid_sizes[0]-ref_images_count, -1) - x = x.reshape(y_shape[0], grid_sizes[0], -1) - x[:, ref_images_count:] += y - x = x.reshape(y_shape[0], -1, y_shape[-1]) + x = fuse_with_image_ref(x, y, ref_images_count, grid_sizes) del y y = self.norm2(x) diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index c8f845d07..060398c98 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -254,9 +254,13 @@ def query_model_def(base_model_type, model_def): if base_model_type in ["vace_lynx_14B"]: extra_model_def["set_video_prompt_type"]="Q" extra_model_def["control_net_weight_alt_name"] = "Lynx" - extra_model_def["image_ref_choices"] = { "choices": [("None", ""),("Person Face", "I")], "letters_filter": "I"} - extra_model_def["no_background_removal"] = True - extra_model_def["fit_into_canvas_image_refs"] = 0 + extra_model_def["image_ref_choices"]["choices"] = [("None", ""), + ("People / Objects", "I"), + ("Landscape followed by People / Objects (if any)", "KI"), + ("Positioned Frames followed by People / Objects (if any)", "FI"), + ] + extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape, Lynx Face or Positioned Frames" + extra_model_def["no_processing_on_last_images_refs"] = 1 @@ -617,9 +621,9 @@ def validate_generative_settings(base_model_type, model_def, inputs): inputs["video_prompt_type"] = video_prompt_type - if base_model_type in ["vace_standin_14B"]: + if base_model_type in ["vace_standin_14B", "vace_lynx_14B"]: image_refs = inputs["image_refs"] video_prompt_type = inputs["video_prompt_type"] - if image_refs is not None and len(image_refs) == 1 and "K" in video_prompt_type: - gr.Info("Warning, Ref Image for Standin Missing: if 'Landscape and then People or Objects' is selected beside the Landscape Image Ref there should be another Image Ref that contains a Face.") + if image_refs is not None and len(image_refs) == 1 and "K" in video_prompt_type: + gr.Info("Warning, Ref Image that contains the Face to transfer is Missing: if 'Landscape and then People or Objects' is selected beside the Landscape Image Ref there should be another Image Ref that contains a Face.") diff --git a/shared/utils/utils.py b/shared/utils/utils.py index a03927e62..77e4410c4 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -273,13 +273,13 @@ def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fi image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) return image, new_height, new_width -def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ): +def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False, ignore_last_refs = 0 ): if rm_background: session = new_session() output_list =[] output_mask_list =[] - for i, img in enumerate(img_list): + for i, img in enumerate(img_list if ignore_last_refs == 0 else img_list[:-ignore_last_refs]): width, height = img.size resized_mask = None if any_background_ref == 1 and i==0 or any_background_ref == 2: @@ -312,6 +312,11 @@ def resize_and_remove_background(img_list, budget_width, budget_height, rm_backg else: 
output_list.append(resized_image) output_mask_list.append(resized_mask) + if ignore_last_refs: + for img in img_list[-ignore_last_refs:]: + output_list.append(convert_image_to_tensor(img).unsqueeze(1) if return_tensor else img) + output_mask_list.append(None) + return output_list, output_mask_list def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False): diff --git a/wgp.py b/wgp.py index 5bf05582c..9dc14b818 100644 --- a/wgp.py +++ b/wgp.py @@ -5136,7 +5136,8 @@ def remove_temp_filenames(temp_filenames_list): block_size=block_size, outpainting_dims =outpainting_dims, background_ref_outpainted = model_def.get("background_ref_outpainted", True), - return_tensor= model_def.get("return_image_refs_tensor", False) ) + return_tensor= model_def.get("return_image_refs_tensor", False), + ignore_last_refs =model_def.get("no_processing_on_last_images_refs",0)) frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False): From e5ccf40404cc8ce88182cd4ff7195278f9dc3ce0 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sat, 4 Oct 2025 01:36:47 +0200 Subject: [PATCH 068/155] loras optim --- README.md | 4 +++- models/wan/any2video.py | 5 ++--- models/wan/wan_handler.py | 27 ++++++++++++++++++++++----- requirements.txt | 2 +- wgp.py | 8 ++++---- 5 files changed, 32 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index db25ed6d0..a022d3fe3 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : -### October 2 2025: WanGP v8.991 - One Last Thing before the Big Unknown ... +### October 3 2025: WanGP v8.992 - One Last Thing before the Big Unknown ... This new version hasn't any new model... @@ -36,6 +36,8 @@ I have created the new Discord Channel **share-your-settings** where you can pos Last but not least the Lora's documentation has been updated. *update 8.991*: full power of *Vace Lynx* unleashed with new combinations such as Landscape + Face / Clothes + Face / Injectd Frame (Start/End frames/...) + Face +*update 8.992*: optimized gen with Lora, should be 10% if many loras + ### September 30 2025: WanGP v8.9 - Combinatorics This new version of WanGP introduces **Wan 2.1 Lynx** the best Control Net so far to transfer *Facial Identity*. You will be amazed to recognize your friends even with a completely different hair style. Congrats to the *Byte Dance team* for this achievement. Lynx works quite with well *Fusionix t2v* 10 steps. 
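The `ignore_last_refs` parameter added to `resize_and_remove_background()` in the hunk above is what lets the Vace Lynx flow keep its last reference image (the face close-up) untouched: only the leading reference images go through background removal and resizing, while the final `no_processing_on_last_images_refs` images are passed through as-is. A minimal sketch of that split, assuming a plain Pillow image list; `split_refs_for_processing` is an illustrative helper, not code from the patch:

```python
# Minimal sketch (not the repository's code) of the ignore_last_refs behaviour:
# the last N reference images (e.g. the Lynx face close-up) bypass background
# removal and resizing and are returned unchanged.
from PIL import Image

def split_refs_for_processing(img_list, ignore_last_refs=0):
    """Return (refs to preprocess, refs to pass through unchanged)."""
    if ignore_last_refs == 0:
        return img_list, []
    return img_list[:-ignore_last_refs], img_list[-ignore_last_refs:]

# Example: a model definition with no_processing_on_last_images_refs = 1,
# so only the trailing face image is preserved untouched.
refs = [Image.new("RGB", (512, 512)) for _ in range(3)]
to_process, passthrough = split_refs_for_processing(refs, ignore_last_refs=1)
assert len(to_process) == 2 and len(passthrough) == 1
```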
diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 243dddb4b..0ad61e629 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -708,8 +708,7 @@ def generate(self, ip_hidden_states_uncond = arc_resampler(torch.zeros_like(arcface_embed)).to(self.dtype) arc_resampler = None if not lynx_lite: - standin_ref_pos = -1 - image_ref = original_input_ref_images[standin_ref_pos] + image_ref = original_input_ref_images[-1] from preprocessing.face_preprocessor import FaceProcessor face_processor = FaceProcessor() lynx_ref = face_processor.process(image_ref, resize_to = 256 ) @@ -736,7 +735,7 @@ def generate(self, face_processor = None gc.collect() torch.cuda.empty_cache() - standin_freqs = get_nd_rotary_pos_embed((-1, int(target_shape[-2]/2), int(target_shape[-1]/2) ), (-1, int(target_shape[-2]/2 + standin_ref.height/16), int(target_shape[-1]/2 + standin_ref.width/16) )) + standin_freqs = get_nd_rotary_pos_embed((-1, int(height/16), int(width/16) ), (-1, int(height/16 + standin_ref.height/16), int(width/16 + standin_ref.width/16) )) standin_ref = self.vae.encode([ convert_image_to_tensor(standin_ref).unsqueeze(1) ], VAE_tile_size)[0].unsqueeze(0) kwargs.update({ "standin_freqs": standin_freqs, "standin_ref": standin_ref, }) diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index 060398c98..a5d1e3d98 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -255,10 +255,9 @@ def query_model_def(base_model_type, model_def): extra_model_def["set_video_prompt_type"]="Q" extra_model_def["control_net_weight_alt_name"] = "Lynx" extra_model_def["image_ref_choices"]["choices"] = [("None", ""), - ("People / Objects", "I"), - ("Landscape followed by People / Objects (if any)", "KI"), - ("Positioned Frames followed by People / Objects (if any)", "FI"), - ] + ("People / Objects (if any) then a Face", "I"), + ("Landscape followed by People / Objects (if any) then a Face", "KI"), + ("Positioned Frames followed by People / Objects (if any) then a Face", "FI")] extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape, Lynx Face or Positioned Frames" extra_model_def["no_processing_on_last_images_refs"] = 1 @@ -266,6 +265,15 @@ def query_model_def(base_model_type, model_def): if base_model_type in ["standin"]: extra_model_def["v2i_switch_supported"] = True + extra_model_def["image_ref_choices"] = { + "choices": [ + ("No Reference Image", ""), + ("Reference Image is a Person Face", "I"), + ], + "visible": False, + "letters_filter":"I", + } + extra_model_def["one_image_ref_needed"] = True if base_model_type in ["lynx_lite", "lynx"]: extra_model_def["fit_into_canvas_image_refs"] = 0 @@ -292,7 +300,7 @@ def query_model_def(base_model_type, model_def): "visible": False, "letters_filter":"I", } - + extra_model_def["one_image_ref_needed"] = True extra_model_def["set_video_prompt_type"]= "Q" extra_model_def["no_background_removal"] = True extra_model_def["v2i_switch_supported"] = True @@ -515,6 +523,13 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): video_prompt_type = video_prompt_type.replace("Q", "0") ui_defaults["video_prompt_type"] = video_prompt_type + if settings_version < 2.39: + if base_model_type in ["fantasy"]: + audio_prompt_type = ui_defaults.get("audio_prompt_type", "") + if not "A" in audio_prompt_type: + audio_prompt_type += "A" + ui_defaults["audio_prompt_type"] = audio_prompt_type + @staticmethod def update_default_settings(base_model_type, model_def, 
ui_defaults): ui_defaults.update({ @@ -527,6 +542,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ "audio_guidance_scale": 5.0, "sliding_window_overlap" : 1, + "audio_prompt_type": "A", }) elif base_model_type in ["multitalk"]: @@ -535,6 +551,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "flow_shift": 7, # 11 for 720p "sliding_window_discard_last_frames" : 4, "sample_solver" : "euler", + "audio_prompt_type": "A", "adaptive_switch" : 1, }) diff --git a/requirements.txt b/requirements.txt index 9b2c6eeb6..a17efd108 100644 --- a/requirements.txt +++ b/requirements.txt @@ -48,7 +48,7 @@ pydantic==2.10.6 # Math & modeling torchdiffeq>=0.2.5 tensordict>=0.6.1 -mmgp==3.6.0 +mmgp==3.6.1 peft==0.15.0 matplotlib diff --git a/wgp.py b/wgp.py index 9dc14b818..090bcceb2 100644 --- a/wgp.py +++ b/wgp.py @@ -62,9 +62,9 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.6.0" -WanGP_version = "8.99" -settings_version = 2.38 +target_mmgp_version = "3.6.1" +WanGP_version = "8.992" +settings_version = 2.39 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -5011,7 +5011,7 @@ def remove_temp_filenames(temp_filenames_list): aligned_guide_start_frame = guide_start_frame - alignment_shift aligned_guide_end_frame = guide_end_frame - alignment_shift aligned_window_start_frame = window_start_frame - alignment_shift - if fantasy: + if fantasy and audio_guide is not None: audio_proj_split , audio_context_lens = parse_audio(audio_guide, start_frame = aligned_window_start_frame, num_frames= current_video_length, fps= fps, device= processing_device ) if multitalk: from models.wan.multitalk.multitalk import get_window_audio_embeddings From 798ad52e3f92a7cdfde020503f96a25c6bf762b3 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sun, 5 Oct 2025 17:54:55 +0200 Subject: [PATCH 069/155] added scaled fp8 support, palangenesis finetunes, merged accelerators / non accelerators loras --- README.md | 11 +- defaults/vace_14B_cocktail_2_2.json | 2 +- defaults/vace_14B_fusionix.json | 2 +- defaults/vace_fun_14B_cocktail_2_2.json | 1 + docs/LORAS.md | 4 +- .../Text2Video Cocktail - 10 Steps.json | 0 .../Text2Video FusioniX - 10 Steps.json | 0 ...x2v Cfg Step Distill Rank32 - 4 Steps.json | 0 ...un Inp HPS2 Reward 2 Phases - 4 Steps.json | 14 + ...Fun Inp MPS Reward 2 Phases - 4 Steps.json | 14 + .../Lightning v1.0 2 Phases - 4 Steps.json | 2 +- .../Lightning v1.0 3 Phases - 8 Steps.json | 2 +- .../Lightning v2025-09-28 - 4 Steps.json | 0 ...htning v2025-09-28 2 Phases - 4 Steps.json | 14 + ...htning v2025-09-28 3 Phases - 8 Steps.json | 18 ++ .../Text2Video FusioniX - 10 Steps.json | 1 + .../FastWan 3 Steps.json | 0 .../Image2Video FusioniX - 10 Steps.json | 0 .../Lightning v1.0 - 4 Steps.json | 0 .../Lightning v1.1 - 8 Steps.json | 0 models/wan/any2video.py | 17 +- models/wan/diffusion_forcing.py | 2 +- requirements.txt | 2 +- shared/utils/loras_mutipliers.py | 294 +++++++++++++++++- wgp.py | 43 ++- 25 files changed, 412 insertions(+), 31 deletions(-) rename loras/{wan => profiles_wan}/Text2Video Cocktail - 10 Steps.json (100%) rename loras/{wan => profiles_wan}/Text2Video FusioniX - 10 Steps.json (100%) rename loras/{wan => profiles_wan}/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json (100%) create mode 100644 loras/profiles_wan2_2/Fun Inp HPS2 Reward 2 Phases - 4 Steps.json 
create mode 100644 loras/profiles_wan2_2/Fun Inp MPS Reward 2 Phases - 4 Steps.json rename loras/{wan2_2 => profiles_wan2_2}/Lightning v1.0 2 Phases - 4 Steps.json (95%) rename loras/{wan2_2 => profiles_wan2_2}/Lightning v1.0 3 Phases - 8 Steps.json (97%) rename loras/{wan2_2 => profiles_wan2_2}/Lightning v2025-09-28 - 4 Steps.json (100%) create mode 100644 loras/profiles_wan2_2/Lightning v2025-09-28 2 Phases - 4 Steps.json create mode 100644 loras/profiles_wan2_2/Lightning v2025-09-28 3 Phases - 8 Steps.json rename loras/{wan2_2 => profiles_wan2_2}/Text2Video FusioniX - 10 Steps.json (91%) rename loras/{wan2_2_5B => profiles_wan2_2_5B}/FastWan 3 Steps.json (100%) rename loras_i2v/{wan => profiles_wan}/Image2Video FusioniX - 10 Steps.json (100%) rename loras_qwen/{qwen => profiles_qwen}/Lightning v1.0 - 4 Steps.json (100%) rename loras_qwen/{qwen => profiles_qwen}/Lightning v1.1 - 8 Steps.json (100%) diff --git a/README.md b/README.md index a022d3fe3..c4ea1f942 100644 --- a/README.md +++ b/README.md @@ -20,12 +20,12 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : -### October 3 2025: WanGP v8.992 - One Last Thing before the Big Unknown ... +### October 5 2025: WanGP v8.993 - One Last Thing before the Big Unknown ... This new version hasn't any new model... ...but temptation to upgrade will be high as it contains a few Loras related features that may change your Life: -- **Ready to use Loras Accelerators Profiles** per type of model that you can apply on your current *Generation Settings*. Next time I will recommend a *Lora Accelerator*, it will be only one click away. And best of all of the required Loras will be downloaded automatically. When you apply an *Accelerator Profile*, input fields like the *Number of Denoising Steps* *Activated Loras*, *Loras Multipliers* (such as "1;0 0;1" ...) will be automatically filled. However your video specific fields will be preserved, so it will be easy to switch between Profiles to experiment. +- **Ready to use Loras Accelerators Profiles** per type of model that you can apply on your current *Generation Settings*. Next time I will recommend a *Lora Accelerator*, it will be only one click away. And best of all of the required Loras will be downloaded automatically. When you apply an *Accelerator Profile*, input fields like the *Number of Denoising Steps* *Activated Loras*, *Loras Multipliers* (such as "1;0 0;1" ...) will be automatically filled. However your video specific fields will be preserved, so it will be easy to switch between Profiles to experiment. With *WanGP 8.993*, the *Accelerator Loras* are now merged with *Non Accelerator Loras". Things are getting too easy... - **Embedded Loras URL** : WanGP will now try to remember every Lora URLs it sees. For instance if someone sends you some settings that contain Loras URLs or you extract the Settings of Video generated by a friend with Loras URLs, these URLs will be automatically added to *WanGP URL Cache*. Conversely everything you will share (Videos, Settings, Lset files) will contain the download URLs if they are known. You can also download directly a Lora in WanGP by using the *Download Lora* button a the bottom. The Lora will be immediatly available and added to WanGP lora URL cache. This will work with *Hugging Face* as a repository. Support for CivitAi will come as soon as someone will nice enough to post a GitHub PR ... 
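The embedded Lora URL feature described above boils down to a small JSON cache that maps each local Lora file to the URL it was downloaded from, so exported settings and `.lset` files can carry the download link. A minimal sketch of that idea, assuming the same `loras_url_cache.json` filename the patch uses; `remember_lora_url` is an illustrative stand-in for `update_loras_url_cache()` in `wgp.py`, not the actual function:

```python
# Minimal sketch of the Lora URL cache: local lora path -> source URL,
# persisted as a small JSON file next to the app.
import json
import os

CACHE_FILE = "loras_url_cache.json"  # same filename used by the patch

def remember_lora_url(lora_dir, lora_url):
    cache = {}
    if os.path.isfile(CACHE_FILE):
        with open(CACHE_FILE, "r", encoding="utf-8") as f:
            cache = json.load(f)
    local_path = os.path.join(lora_dir, os.path.basename(lora_url))
    cache[local_path] = lora_url
    with open(CACHE_FILE, "w", encoding="utf-8") as f:
        json.dump(cache, f, indent=4)

remember_lora_url(
    "loras",
    "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_T2V_14B_FusionX_LoRA.safetensors",
)
```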
@@ -33,10 +33,15 @@ This new version hasn't any new model... I have created the new Discord Channel **share-your-settings** where you can post your *Settings* or *Lset files*. I will be pleased to add new Loras Accelerators in the list of WanGP *Accelerators Profiles if you post some good ones there. +*With the 8.993 update*, I have added out support for **Scaled FP8 format**. As a sample case, I have created finetunes for the **Wan 2.2 PalinGenesis** Finetune which is quite popular recently. You will find it in 3 flavors : *t2v*, *i2v* and *Lightning Accelerated for t2v*. + +The *Scaled FP8 format* is very popular as it the format used by ... *ComfyUI*. So I except a flood of Finetunes in the *share-your-finetune* channel. If not it means this feature was useless and I will remove it 😈😈😈 + Last but not least the Lora's documentation has been updated. *update 8.991*: full power of *Vace Lynx* unleashed with new combinations such as Landscape + Face / Clothes + Face / Injectd Frame (Start/End frames/...) + Face -*update 8.992*: optimized gen with Lora, should be 10% if many loras +*update 8.992*: optimized gen with Lora, should be 10% faster if many loras +*update 8.993*: Support for *Scaled FP8* format and samples *Paligenesis* finetunes, merged Loras Accelerators and Non Accelerators ### September 30 2025: WanGP v8.9 - Combinatorics diff --git a/defaults/vace_14B_cocktail_2_2.json b/defaults/vace_14B_cocktail_2_2.json index 6a831d214..57839227e 100644 --- a/defaults/vace_14B_cocktail_2_2.json +++ b/defaults/vace_14B_cocktail_2_2.json @@ -14,7 +14,7 @@ "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors", "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors" ], - "settings_dir": [ "" ], + "profiles_dir": [ "" ], "loras_multipliers": [1, 0.2, 0.5, 0.5], "group": "wan2_2" }, diff --git a/defaults/vace_14B_fusionix.json b/defaults/vace_14B_fusionix.json index 1c6fa7288..95639f404 100644 --- a/defaults/vace_14B_fusionix.json +++ b/defaults/vace_14B_fusionix.json @@ -6,7 +6,7 @@ "vace_14B" ], "description": "Vace control model enhanced using multiple open-source components and LoRAs to boost motion realism, temporal consistency, and expressive detail.", - "settings_dir": [ "" ], + "profiles_dir": [ "" ], "URLs": "t2v_fusionix" }, "negative_prompt": "", diff --git a/defaults/vace_fun_14B_cocktail_2_2.json b/defaults/vace_fun_14B_cocktail_2_2.json index c587abdd4..0a734380f 100644 --- a/defaults/vace_fun_14B_cocktail_2_2.json +++ b/defaults/vace_fun_14B_cocktail_2_2.json @@ -17,6 +17,7 @@ 0.5, 0.5 ], + "profiles_dir": [""], "group": "wan2_2" }, "guidance_phases": 2, diff --git a/docs/LORAS.md b/docs/LORAS.md index dc0caaa56..f17b51f0b 100644 --- a/docs/LORAS.md +++ b/docs/LORAS.md @@ -196,7 +196,9 @@ There are three ways to setup Loras accelerators: Some models Finetunes such as *Vace FusioniX* or *Vace Coctail* have the Loras Accelerators already set up in their own definition and you won't have to do anything as they will be downloaded with the Finetune. 2) **Accelerators Profiles** -Predefined *Accelerator Profiles* can be selected using the *Settings* Dropdown box at the top. The choices of Accelerators will depend on the models. No accelerator will be offered if the finetune / model is already accelerated. Just click *Apply* and the Accelerators Loras will be setup in the Loras tab at the bottom. 
Any missing Lora will be downloaded automatically the first time you try to generate a Video. Be aware that when applying an *Accelerator Profile*, inputs such as *Activated Loras*, *Number of Inference Steps*, ... will bo overwritten wheras other inputs will be preserved so that you can easily switch between profiles. +Predefined *Accelerator Profiles* can be selected using the *Settings* Dropdown box at the top. The choices of Accelerators will depend on the models. No accelerator will be offered if the finetune / model is already accelerated. Just click *Apply* and the Accelerators Loras will be setup in the Loras tab at the bottom. Any missing Lora will be downloaded automatically the first time you try to generate a Video. Be aware that when applying an *Accelerator Profile*, inputs such as *Activated Loras*, *Number of Inference Steps*, ... will be updated. However if you have already existing Loras set up (that are non Loras Accelerators) they will be preserved so that you can easily switch between Accelerators Profiles. + +You will see the "|" character at the end of the Multipliers text input associated to Loras Accelerators. It plays the same role than the Space character to separate Multipliers except it tells WanGP where the Loras Accelerators multipliers end so that it can merge Loras Accelerators with Non Loras Accelerators. 3) **Manual Install** - Download the Lora diff --git a/loras/wan/Text2Video Cocktail - 10 Steps.json b/loras/profiles_wan/Text2Video Cocktail - 10 Steps.json similarity index 100% rename from loras/wan/Text2Video Cocktail - 10 Steps.json rename to loras/profiles_wan/Text2Video Cocktail - 10 Steps.json diff --git a/loras/wan/Text2Video FusioniX - 10 Steps.json b/loras/profiles_wan/Text2Video FusioniX - 10 Steps.json similarity index 100% rename from loras/wan/Text2Video FusioniX - 10 Steps.json rename to loras/profiles_wan/Text2Video FusioniX - 10 Steps.json diff --git a/loras/wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json b/loras/profiles_wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json similarity index 100% rename from loras/wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json rename to loras/profiles_wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json diff --git a/loras/profiles_wan2_2/Fun Inp HPS2 Reward 2 Phases - 4 Steps.json b/loras/profiles_wan2_2/Fun Inp HPS2 Reward 2 Phases - 4 Steps.json new file mode 100644 index 000000000..1b58ec73a --- /dev/null +++ b/loras/profiles_wan2_2/Fun Inp HPS2 Reward 2 Phases - 4 Steps.json @@ -0,0 +1,14 @@ +{ + "num_inference_steps": 8, + "guidance_scale": 1, + "guidance2_scale": 1, + "model_switch_phase": 1, + "switch_threshold": 876, + "guidance_phases": 2, + "flow_shift": 3, + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Fun-A14B-InP-high-noise-HPS2.1.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Fun-A14B-InP-low-noise-HPS2.1.safetensors" + ], + "loras_multipliers": "1;0 0;1" +} \ No newline at end of file diff --git a/loras/profiles_wan2_2/Fun Inp MPS Reward 2 Phases - 4 Steps.json b/loras/profiles_wan2_2/Fun Inp MPS Reward 2 Phases - 4 Steps.json new file mode 100644 index 000000000..cc5742248 --- /dev/null +++ b/loras/profiles_wan2_2/Fun Inp MPS Reward 2 Phases - 4 Steps.json @@ -0,0 +1,14 @@ +{ + "num_inference_steps": 4, + "guidance_scale": 1, + "guidance2_scale": 1, + "model_switch_phase": 1, + "switch_threshold": 876, + "guidance_phases": 2, + 
"flow_shift": 3, + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Fun-A14B-InP-high-noise-MPS.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Fun-A14B-InP-low-noise-MPS.safetensors" + ], + "loras_multipliers": "1;0 0;1" +} \ No newline at end of file diff --git a/loras/wan2_2/Lightning v1.0 2 Phases - 4 Steps.json b/loras/profiles_wan2_2/Lightning v1.0 2 Phases - 4 Steps.json similarity index 95% rename from loras/wan2_2/Lightning v1.0 2 Phases - 4 Steps.json rename to loras/profiles_wan2_2/Lightning v1.0 2 Phases - 4 Steps.json index e7aeedad9..51c903b63 100644 --- a/loras/wan2_2/Lightning v1.0 2 Phases - 4 Steps.json +++ b/loras/profiles_wan2_2/Lightning v1.0 2 Phases - 4 Steps.json @@ -6,7 +6,7 @@ "switch_threshold": 876, "guidance_phases": 2, "flow_shift": 3, - "loras": [ + "activated_loras": [ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_HIGH_fp16.safetensors", "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors" ], diff --git a/loras/wan2_2/Lightning v1.0 3 Phases - 8 Steps.json b/loras/profiles_wan2_2/Lightning v1.0 3 Phases - 8 Steps.json similarity index 97% rename from loras/wan2_2/Lightning v1.0 3 Phases - 8 Steps.json rename to loras/profiles_wan2_2/Lightning v1.0 3 Phases - 8 Steps.json index 681cf3077..a30834482 100644 --- a/loras/wan2_2/Lightning v1.0 3 Phases - 8 Steps.json +++ b/loras/profiles_wan2_2/Lightning v1.0 3 Phases - 8 Steps.json @@ -9,7 +9,7 @@ "model_switch_phase": 2, "flow_shift": 3, "sample_solver": "euler", - "loras": [ + "activated_loras": [ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_HIGH_fp16.safetensors", "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors" ], diff --git a/loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json b/loras/profiles_wan2_2/Lightning v2025-09-28 - 4 Steps.json similarity index 100% rename from loras/wan2_2/Lightning v2025-09-28 - 4 Steps.json rename to loras/profiles_wan2_2/Lightning v2025-09-28 - 4 Steps.json diff --git a/loras/profiles_wan2_2/Lightning v2025-09-28 2 Phases - 4 Steps.json b/loras/profiles_wan2_2/Lightning v2025-09-28 2 Phases - 4 Steps.json new file mode 100644 index 000000000..31e4f0d66 --- /dev/null +++ b/loras/profiles_wan2_2/Lightning v2025-09-28 2 Phases - 4 Steps.json @@ -0,0 +1,14 @@ +{ + "num_inference_steps": 4, + "guidance_scale": 1, + "guidance2_scale": 1, + "switch_threshold": 876, + "model_switch_phase": 1, + "guidance_phases": 2, + "flow_shift": 3, + "loras_multipliers": "1;0 0;1", + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" + ] +} \ No newline at end of file diff --git a/loras/profiles_wan2_2/Lightning v2025-09-28 3 Phases - 8 Steps.json b/loras/profiles_wan2_2/Lightning v2025-09-28 3 Phases - 8 Steps.json new file mode 100644 index 000000000..839e0564c --- /dev/null +++ b/loras/profiles_wan2_2/Lightning v2025-09-28 3 Phases - 8 Steps.json @@ -0,0 +1,18 @@ +{ + "num_inference_steps": 8, + 
"guidance_scale": 3.5, + "guidance2_scale": 1, + "guidance3_scale": 1, + "switch_threshold": 965, + "switch_threshold2": 800, + "guidance_phases": 3, + "model_switch_phase": 2, + "flow_shift": 3, + "sample_solver": "euler", + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" + ], + "loras_multipliers": "0;1;0 0;0;1", + "help": "This finetune uses the Lightning 250928 4 steps Loras Accelerator for Wan 2.2 but extend them to 8 steps in order to insert a CFG phase before the 2 accelerated phases with no Guidance. The ultimate goal is reduce the slow motion effect of these Loras Accelerators." +} \ No newline at end of file diff --git a/loras/wan2_2/Text2Video FusioniX - 10 Steps.json b/loras/profiles_wan2_2/Text2Video FusioniX - 10 Steps.json similarity index 91% rename from loras/wan2_2/Text2Video FusioniX - 10 Steps.json rename to loras/profiles_wan2_2/Text2Video FusioniX - 10 Steps.json index 3695bec63..1ad921b63 100644 --- a/loras/wan2_2/Text2Video FusioniX - 10 Steps.json +++ b/loras/profiles_wan2_2/Text2Video FusioniX - 10 Steps.json @@ -7,5 +7,6 @@ "num_inference_steps": 10, "guidance_scale": 1, "guidance2_scale": 1, + "guidance_phases": 2, "flow_shift": 2 } diff --git a/loras/wan2_2_5B/FastWan 3 Steps.json b/loras/profiles_wan2_2_5B/FastWan 3 Steps.json similarity index 100% rename from loras/wan2_2_5B/FastWan 3 Steps.json rename to loras/profiles_wan2_2_5B/FastWan 3 Steps.json diff --git a/loras_i2v/wan/Image2Video FusioniX - 10 Steps.json b/loras_i2v/profiles_wan/Image2Video FusioniX - 10 Steps.json similarity index 100% rename from loras_i2v/wan/Image2Video FusioniX - 10 Steps.json rename to loras_i2v/profiles_wan/Image2Video FusioniX - 10 Steps.json diff --git a/loras_qwen/qwen/Lightning v1.0 - 4 Steps.json b/loras_qwen/profiles_qwen/Lightning v1.0 - 4 Steps.json similarity index 100% rename from loras_qwen/qwen/Lightning v1.0 - 4 Steps.json rename to loras_qwen/profiles_qwen/Lightning v1.0 - 4 Steps.json diff --git a/loras_qwen/qwen/Lightning v1.1 - 8 Steps.json b/loras_qwen/profiles_qwen/Lightning v1.1 - 8 Steps.json similarity index 100% rename from loras_qwen/qwen/Lightning v1.1 - 8 Steps.json rename to loras_qwen/profiles_qwen/Lightning v1.1 - 8 Steps.json diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 0ad61e629..01c2cd832 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -93,7 +93,6 @@ def __init__( checkpoint_path=text_encoder_filename, tokenizer_path=os.path.join(checkpoint_dir, "umt5-xxl"), shard_fn= None) - # base_model_type = "i2v2_2" if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"] or base_model_type in ["animate"]: self.clip = CLIPModel( @@ -103,6 +102,7 @@ def __init__( config.clip_checkpoint), tokenizer_path=os.path.join(checkpoint_dir , "xlm-roberta-large")) + ignore_unused_weights = model_def.get("ignore_unused_weights", False) if base_model_type in ["ti2v_2_2", "lucy_edit"]: self.vae_stride = (4, 16, 16) @@ -134,10 +134,11 @@ def __init__( source2 = model_def.get("source2", None) module_source = model_def.get("module_source", None) module_source2 = model_def.get("module_source2", None) + kwargs= { "ignore_unused_weights": ignore_unused_weights, "writable_tensors": False, "default_dtype": dtype } if 
module_source is not None: - self.model = offload.fast_load_transformers_model(model_filename[:1] + [module_source], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + self.model = offload.fast_load_transformers_model(model_filename[:1] + [module_source], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, **kwargs) if module_source2 is not None: - self.model2 = offload.fast_load_transformers_model(model_filename[1:2] + [module_source2], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2] + [module_source2], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, **kwargs) if source is not None: self.model = offload.fast_load_transformers_model(source, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file) if source2 is not None: @@ -153,17 +154,17 @@ def __init__( if 0 in submodel_no_list[2:]: shared_modules= {} - self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules) - self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules, **kwargs) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, **kwargs) shared_modules = None else: modules_for_1 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==1 ] modules_for_2 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==2 ] - self.model = offload.fast_load_transformers_model(model_filename[:1], modules = modules_for_1, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) - self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = modules_for_2, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + self.model = offload.fast_load_transformers_model(model_filename[:1], modules = modules_for_1, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , 
forcedConfigPath= forcedConfigPath, **kwargs) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = modules_for_2, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, **kwargs) else: - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, **kwargs) if self.model is not None: diff --git a/models/wan/diffusion_forcing.py b/models/wan/diffusion_forcing.py index 5d5e42ffd..09d73150a 100644 --- a/models/wan/diffusion_forcing.py +++ b/models/wan/diffusion_forcing.py @@ -365,7 +365,7 @@ def generate( timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition kwrags.update({ "t" : timestep, - "current_step" : i, + "current_step_no" : i, }) # with torch.autocast(device_type="cuda"): diff --git a/requirements.txt b/requirements.txt index a17efd108..cfe61c57f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -48,7 +48,7 @@ pydantic==2.10.6 # Math & modeling torchdiffeq>=0.2.5 tensordict>=0.6.1 -mmgp==3.6.1 +mmgp==3.6.2 peft==0.15.0 matplotlib diff --git a/shared/utils/loras_mutipliers.py b/shared/utils/loras_mutipliers.py index e0cec6a45..8ba456d4d 100644 --- a/shared/utils/loras_mutipliers.py +++ b/shared/utils/loras_mutipliers.py @@ -1,3 +1,6 @@ +from typing import List, Tuple, Dict, Callable + + def preparse_loras_multipliers(loras_multipliers): if isinstance(loras_multipliers, list): return [multi.strip(" \r\n") if isinstance(multi, str) else multi for multi in loras_multipliers] @@ -6,7 +9,7 @@ def preparse_loras_multipliers(loras_multipliers): loras_mult_choices_list = loras_multipliers.replace("\r", "").split("\n") loras_mult_choices_list = [multi.strip() for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")] loras_multipliers = " ".join(loras_mult_choices_list) - return loras_multipliers.split(" ") + return loras_multipliers.replace("|"," ").split(" ") def expand_slist(slists_dict, mult_no, num_inference_steps, model_switch_step, model_switch_step2 ): def expand_one(slist, num_inference_steps): @@ -33,6 +36,10 @@ def expand_one(slist, num_inference_steps): return expand_one(phase1, model_switch_step) + expand_one(phase2, model_switch_step2 - model_switch_step) + expand_one(phase3, num_inference_steps - model_switch_step2) def parse_loras_multipliers(loras_multipliers, nb_loras, num_inference_steps, merge_slist = None, nb_phases = 2, model_switch_step = None, model_switch_step2 = None): + if "|" in loras_multipliers: + pos = loras_multipliers.find("|") + if "|" in loras_multipliers[pos+1:]: return "", "", "There can be only one '|' character in Loras Multipliers Sequence" + if model_switch_step is None: model_switch_step = num_inference_steps if model_switch_step2 is None: @@ -125,3 +132,288 @@ def get_model_switch_steps(timesteps, guide_phases, model_switch_phase, switch_t if guide_phases > 2 and model_switch_step2 < total_num_steps: phases_description += f", Phase 3 = {model_switch_step2 +1}:{ total_num_steps}" return model_switch_step, model_switch_step2, phases_description + + + +from typing import 
List, Tuple, Dict, Callable + +_ALWD = set(":;,.0123456789") + +# ---------------- core parsing helpers ---------------- + +def _find_bar(s: str) -> int: + com = False + for i, ch in enumerate(s): + if ch in ('\n', '\r'): + com = False + elif ch == '#': + com = True + elif ch == '|' and not com: + return i + return -1 + +def _spans(text: str) -> List[Tuple[int, int]]: + res, com, in_tok, st = [], False, False, 0 + for i, ch in enumerate(text): + if ch in ('\n', '\r'): + if in_tok: res.append((st, i)); in_tok = False + com = False + elif ch == '#': + if in_tok: res.append((st, i)); in_tok = False + com = True + elif not com: + if ch in _ALWD: + if not in_tok: in_tok, st = True, i + else: + if in_tok: res.append((st, i)); in_tok = False + if in_tok: res.append((st, len(text))) + return res + +def _choose_sep(text: str, spans: List[Tuple[int, int]]) -> str: + if len(spans) >= 2: + a, b = spans[-2][1], spans[-1][0] + return '\n' if ('\n' in text[a:b] or '\r' in text[a:b]) else ' ' + return '\n' if ('\n' in text or '\r' in text) else ' ' + +def _ends_in_comment_line(text: str) -> bool: + ln = text.rfind('\n') + seg = text[ln + 1:] if ln != -1 else text + return '#' in seg + +def _append_tokens(text: str, k: int, sep: str) -> str: + if k <= 0: return text + t = text + if _ends_in_comment_line(t) and (not t.endswith('\n')): t += '\n' + parts = [] + if t and not t[-1].isspace(): parts.append(sep) + parts.append('1') + for _ in range(k - 1): + parts.append(sep); parts.append('1') + return t + ''.join(parts) + +def _erase_span_and_one_sep(text: str, st: int, en: int) -> str: + n = len(text) + r = en + while r < n and text[r] in (' ', '\t'): r += 1 + if r > en: return text[:st] + text[r:] + l = st + while l > 0 and text[l-1] in (' ', '\t'): l -= 1 + if l < st: return text[:l] + text[en:] + return text[:st] + text[en:] + +def _trim_last_tokens(text: str, spans: List[Tuple[int, int]], drop: int) -> str: + if drop <= 0: return text + new_text = text + for st, en in reversed(spans[-drop:]): + new_text = _erase_span_and_one_sep(new_text, st, en) + while new_text and new_text[-1] in (' ', '\t'): + new_text = new_text[:-1] + return new_text + +def _enforce_count(text: str, target: int) -> str: + sp = _spans(text); cur = len(sp) + if cur == target: return text + if cur > target: return _trim_last_tokens(text, sp, cur - target) + sep = _choose_sep(text, sp) + return _append_tokens(text, target - cur, sep) + +def _strip_bars_outside_comments(s: str) -> str: + com, out = False, [] + for ch in s: + if ch in ('\n', '\r'): com = False; out.append(ch) + elif ch == '#': com = True; out.append(ch) + elif ch == '|' and not com: continue + else: out.append(ch) + return ''.join(out) + +def _replace_tokens(text: str, repl: Dict[int, str]) -> str: + if not repl: return text + sp = _spans(text) + for idx in sorted(repl.keys(), reverse=True): + if 0 <= idx < len(sp): + st, en = sp[idx] + text = text[:st] + repl[idx] + text[en:] + return text + +def _drop_tokens_by_indices(text: str, idxs: List[int]) -> str: + if not idxs: return text + out = text + for idx in sorted(set(idxs), reverse=True): + sp = _spans(out) # recompute spans after each deletion + if 0 <= idx < len(sp): + st, en = sp[idx] + out = _erase_span_and_one_sep(out, st, en) + return out + +# ---------------- identity for dedupe ---------------- + +def _default_path_key(p: str) -> str: + s = p.strip().replace('\\', '/') + while '//' in s: s = s.replace('//', '/') + if len(s) > 1 and s.endswith('/'): s = s[:-1] + return s + +# ---------------- new-set splitter 
(FIX) ---------------- + +def _select_new_side( + loras_new: List[str], + mult_new: str, + mode: str, # "merge before" | "merge after" +) -> Tuple[List[str], str]: + """ + Split mult_new on '|' (outside comments) and split loras_new accordingly. + Return ONLY the side relevant to `mode`. Extras loras (if any) are appended to the selected side. + """ + bi = _find_bar(mult_new) + if bi == -1: + return loras_new, _strip_bars_outside_comments(mult_new) + + left, right = mult_new[:bi], mult_new[bi + 1:] + nL, nR = len(_spans(left)), len(_spans(right)) + L = len(loras_new) + + # Primary allocation by token counts + b_count = min(nL, L) + rem = max(0, L - b_count) + a_count = min(nR, rem) + extras = max(0, L - (b_count + a_count)) + + if mode == "merge before": + # take BEFORE loras + extras + l_sel = loras_new[:b_count] + (loras_new[b_count + a_count : b_count + a_count + extras] if extras else []) + m_sel = left + else: + # take AFTER loras + extras + start_after = b_count + l_sel = loras_new[start_after:start_after + a_count] + (loras_new[start_after + a_count : start_after + a_count + extras] if extras else []) + m_sel = right + + return l_sel, _strip_bars_outside_comments(m_sel) + +# ---------------- public API ---------------- + +def merge_loras_settings( + loras_old: List[str], + mult_old: str, + loras_new: List[str], + mult_new: str, + mode: str = "merge before", + path_key: Callable[[str], str] = _default_path_key, +) -> Tuple[List[str], str]: + """ + Merge settings with full formatting/comment preservation and correct handling of `mult_new` with '|'. + Dedup rule: when merging AFTER (resp. BEFORE), if a new lora already exists in preserved BEFORE (resp. AFTER), + update that preserved multiplier and drop the duplicate from the replaced side. + """ + assert mode in ("merge before", "merge after") + + # Old split & alignment + bi_old = _find_bar(mult_old) + before_old, after_old = (mult_old[:bi_old], mult_old[bi_old + 1:]) if bi_old != -1 else ("", mult_old) + orig_had_bar = (bi_old != -1) + + sp_b_old, sp_a_old = _spans(before_old), _spans(after_old) + n_b_old = len(sp_b_old) + total_old = len(loras_old) + + if n_b_old <= total_old: + keep_b = n_b_old + keep_a = total_old - keep_b + before_old_aligned = before_old + after_old_aligned = _enforce_count(after_old, keep_a) + else: + keep_b = total_old + keep_a = 0 + before_old_aligned = _enforce_count(before_old, keep_b) + after_old_aligned = _enforce_count(after_old, 0) + + # NEW: choose the relevant side of the *new* set (fix for '|' in mult_new) + loras_new_sel, mult_new_sel = _select_new_side(loras_new, mult_new, mode) + mult_new_aligned = _enforce_count(mult_new_sel, len(loras_new_sel)) + sp_new = _spans(mult_new_aligned) + new_tokens = [mult_new_aligned[st:en] for st, en in sp_new] + + if mode == "merge after": + # Preserve BEFORE; replace AFTER (with dedupe/update) + preserved_loras = loras_old[:keep_b] + preserved_text = before_old_aligned + preserved_spans = _spans(preserved_text) + pos_by_key: Dict[str, int] = {} + for i, lp in enumerate(preserved_loras): + k = path_key(lp) + if k not in pos_by_key: pos_by_key[k] = i + + repl_map: Dict[int, str] = {} + drop_idxs: List[int] = [] + for i, lp in enumerate(loras_new_sel): + j = pos_by_key.get(path_key(lp)) + if j is not None and j < len(preserved_spans): + repl_map[j] = new_tokens[i] if i < len(new_tokens) else "1" + drop_idxs.append(i) + + before_text = _replace_tokens(preserved_text, repl_map) + after_text = _drop_tokens_by_indices(mult_new_aligned, drop_idxs) + loras_keep = [lp for 
i, lp in enumerate(loras_new_sel) if i not in set(drop_idxs)] + loras_out = preserved_loras + loras_keep + + else: + # Preserve AFTER; replace BEFORE (with dedupe/update) + preserved_loras = loras_old[keep_b:] + preserved_text = after_old_aligned + preserved_spans = _spans(preserved_text) + pos_by_key: Dict[str, int] = {} + for i, lp in enumerate(preserved_loras): + k = path_key(lp) + if k not in pos_by_key: pos_by_key[k] = i + + repl_map: Dict[int, str] = {} + drop_idxs: List[int] = [] + for i, lp in enumerate(loras_new_sel): + j = pos_by_key.get(path_key(lp)) + if j is not None and j < len(preserved_spans): + repl_map[j] = new_tokens[i] if i < len(new_tokens) else "1" + drop_idxs.append(i) + + after_text = _replace_tokens(preserved_text, repl_map) + before_text = _drop_tokens_by_indices(mult_new_aligned, drop_idxs) + loras_keep = [lp for i, lp in enumerate(loras_new_sel) if i not in set(drop_idxs)] + loras_out = loras_keep + preserved_loras + + # Compose, preserving explicit "before-only" bar when appropriate + has_before = len(_spans(before_text)) > 0 + has_after = len(_spans(after_text)) > 0 + if has_before and has_after: + mult_out = f"{before_text}|{after_text}" + elif has_before: + mult_out = before_text + ('|' if (mode == 'merge before' or orig_had_bar) else '') + else: + mult_out = after_text + + return loras_out, mult_out + +# ---------------- extractor ---------------- + +def extract_loras_side( + loras: List[str], + mult: str, + which: str = "before", +) -> Tuple[List[str], str]: + assert which in ("before", "after") + bi = _find_bar(mult) + before_txt, after_txt = (mult[:bi], mult[bi + 1:]) if bi != -1 else ("", mult) + + sp_b = _spans(before_txt) + n_b = len(sp_b) + total = len(loras) + + if n_b <= total: + keep_b = n_b + keep_a = total - keep_b + else: + keep_b = total + keep_a = 0 + + if which == "before": + return loras[:keep_b], _enforce_count(before_txt, keep_b) + else: + return loras[keep_b:keep_b + keep_a], _enforce_count(after_txt, keep_a) diff --git a/wgp.py b/wgp.py index 090bcceb2..ff82489cf 100644 --- a/wgp.py +++ b/wgp.py @@ -62,8 +62,8 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.6.1" -WanGP_version = "8.992" +target_mmgp_version = "3.6.2" +WanGP_version = "8.993" settings_version = 2.39 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -5843,8 +5843,8 @@ def compute_lset_choices(model_type, loras_presets): if model_type is not None: top_dir = get_lora_dir(model_type) model_def = get_model_def(model_type) - settings_dir = get_model_recursive_prop(model_type, "settings_dir", return_list=False) - if settings_dir is None or len(settings_dir) == 0: settings_dir = [get_model_family(model_type, True)] + settings_dir = get_model_recursive_prop(model_type, "profiles_dir", return_list=False) + if settings_dir is None or len(settings_dir) == 0: settings_dir = ["profiles_" + get_model_family(model_type, True)] for dir in settings_dir: if len(dir) == "": continue cur_path = os.path.join(top_dir, dir) @@ -5936,6 +5936,8 @@ def save_lset(state, lset_name, loras_choices, loras_mult_choices, prompt, save_ lset = collect_current_model_settings(state) extension = ".json" else: + from shared.utils.loras_mutipliers import extract_loras_side + loras_choices, loras_mult_choices = extract_loras_side(loras_choices, loras_mult_choices, "after") lset = {"loras" : loras_choices, "loras_mult" : 
loras_mult_choices} if save_lset_prompt_cbox!=1: prompts = prompt.replace("\r", "").split("\n") @@ -6046,6 +6048,7 @@ def refresh_lora_list(state, lset_name, loras_choices): def update_lset_type(state, lset_name): return 1 if lset_name.endswith(".lset") else 2 +from shared.utils.loras_mutipliers import merge_loras_settings def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_mult_choices, prompt): @@ -6057,10 +6060,11 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m return wizard_prompt_activated, loras_choices, loras_mult_choices, prompt, gr.update(), gr.update(), gr.update(), gr.update() else: current_model_type = state["model_type"] + ui_settings = get_current_model_settings(state) + old_activated_loras, old_loras_multipliers = ui_settings.get("activated_loras", []), ui_settings.get("loras_multipliers", ""), if lset_name.endswith(".lset"): loras = state["loras"] loras_choices, loras_mult_choices, preset_prompt, full_prompt, error = extract_preset(current_model_type, lset_name, loras) - loras_choices = update_loras_url_cache(get_lora_dir(current_model_type), loras_choices) if full_prompt: prompt = preset_prompt elif len(preset_prompt) > 0: @@ -6068,6 +6072,9 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m prompts = [prompt for prompt in prompts if len(prompt)>0 and not prompt.startswith("#")] prompt = "\n".join(prompts) prompt = preset_prompt + '\n' + prompt + loras_choices, loras_mult_choices = merge_loras_settings(old_activated_loras, old_loras_multipliers, loras_choices, loras_mult_choices, "merge after") + loras_choices = update_loras_url_cache(get_lora_dir(current_model_type), loras_choices) + loras_choices = [os.path.basename(lora) for lora in loras_choices] gr.Info(f"Lora Preset '{lset_name}' has been applied") state["apply_success"] = 1 wizard_prompt_activated = "on" @@ -6076,7 +6083,8 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m else: # lset_path = os.path.join("settings", lset_name) if len(Path(lset_name).parts)>1 else os.path.join(get_lora_dir(current_model_type), lset_name) lset_path = os.path.join(get_lora_dir(current_model_type), lset_name) - configs, _ = get_settings_from_file(state,lset_path , True, True, True, min_settings_version=2.38) + configs, _ = get_settings_from_file(state,lset_path , True, True, True, min_settings_version=2.38, merge_loras = "merge after" if len(Path(lset_name).parts)<=1 else "merge before" ) + if configs == None: gr.Info("File not supported") return [gr.update()] * 8 @@ -6085,7 +6093,6 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m gr.Info(f"Settings File '{lset_name}' has been applied") help = configs.get("help", None) if help is not None: gr.Info(help) - if model_type == current_model_type: set_model_settings(state, current_model_type, configs) return *[gr.update()] * 4, gr.update(), gr.update(), gr.update(), get_unique_id() @@ -6625,7 +6632,7 @@ def update_loras_url_cache(lora_dir, loras_selected): base_name = os.path.basename(lora) local_name = os.path.join(lora_dir, base_name) url = loras_url_cache.get(local_name, base_name) - if lora.startswith("http:") or lora.startswith("https:") and url != lora: + if (lora.startswith("http:") or lora.startswith("https:")) and url != lora: loras_url_cache[local_name]=lora update = True new_loras_selected.append(url) @@ -6636,7 +6643,7 @@ def update_loras_url_cache(lora_dir, loras_selected): return new_loras_selected -def 
get_settings_from_file(state, file_path, allow_json, merge_with_defaults, switch_type_if_compatible, min_settings_version = 0): +def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, switch_type_if_compatible, min_settings_version = 0, merge_loras = None): configs = None any_image_or_video = False if file_path.endswith(".json") and allow_json: @@ -6687,16 +6694,28 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw model_type = current_model_type if switch_type_if_compatible and are_model_types_compatible(model_type,current_model_type): model_type = current_model_type + old_loras_selected = old_loras_multipliers = None if merge_with_defaults: defaults = get_model_settings(state, model_type) defaults = get_default_settings(model_type) if defaults == None else defaults + if merge_loras is not None and model_type == current_model_type: + old_loras_selected, old_loras_multipliers = defaults.get("activated_loras", []), defaults.get("loras_multipliers", ""), defaults.update(configs) configs = defaults - loras_selected =configs.get("activated_loras", None) - if loras_selected is not None: - configs["activated_loras"]= update_loras_url_cache(get_lora_dir(model_type), loras_selected) + loras_selected =configs.get("activated_loras", []) + loras_multipliers = configs.get("loras_multipliers", "") + if loras_selected is not None and len(loras_selected) > 0: + loras_selected = update_loras_url_cache(get_lora_dir(model_type), loras_selected) + if old_loras_selected is not None: + if len(old_loras_selected) == 0 and "|" in loras_multipliers: + pass + else: + old_loras_selected = update_loras_url_cache(get_lora_dir(model_type), old_loras_selected) + loras_selected, loras_multipliers = merge_loras_settings(old_loras_selected, old_loras_multipliers, loras_selected, loras_multipliers, merge_loras ) + configs["activated_loras"]= loras_selected or [] + configs["loras_multipliers"] = loras_multipliers fix_settings(model_type, configs, min_settings_version) configs["model_type"] = model_type From 07e2935128fd9d09c303b9f0f3e2cfd145b16174 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Sun, 5 Oct 2025 18:22:31 +0200 Subject: [PATCH 070/155] fixed lynx weight not properly displayed --- README.md | 4 ++-- wgp.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index c4ea1f942..e7e9400a8 100644 --- a/README.md +++ b/README.md @@ -33,9 +33,9 @@ This new version hasn't any new model... I have created the new Discord Channel **share-your-settings** where you can post your *Settings* or *Lset files*. I will be pleased to add new Loras Accelerators in the list of WanGP *Accelerators Profiles if you post some good ones there. -*With the 8.993 update*, I have added out support for **Scaled FP8 format**. As a sample case, I have created finetunes for the **Wan 2.2 PalinGenesis** Finetune which is quite popular recently. You will find it in 3 flavors : *t2v*, *i2v* and *Lightning Accelerated for t2v*. +*With the 8.993 update*, I have added support for **Scaled FP8 format**. As a sample case, I have created finetunes for the **Wan 2.2 PalinGenesis** Finetune which is quite popular recently. You will find it in 3 flavors : *t2v*, *i2v* and *Lightning Accelerated for t2v*. -The *Scaled FP8 format* is very popular as it the format used by ... *ComfyUI*. So I except a flood of Finetunes in the *share-your-finetune* channel. 
If not it means this feature was useless and I will remove it 😈😈😈 +The *Scaled FP8 format* is widely used as it the format used by ... *ComfyUI*. So I except a flood of Finetunes in the *share-your-finetune* channel. If not it means this feature was useless and I will remove it 😈😈😈 Last but not least the Lora's documentation has been updated. diff --git a/wgp.py b/wgp.py index ff82489cf..49d91d788 100644 --- a/wgp.py +++ b/wgp.py @@ -3527,7 +3527,7 @@ def check(src, cond): control_net_weight_alt_name = model_def.get("control_net_weight_alt_name", "") if len(control_net_weight_alt_name) >0: if len(control_net_weight): control_net_weight += ", " - control_net_weight += control_net_weight_alt_name + "=" + str(configs.get("control_net_alt", 1)) + control_net_weight += control_net_weight_alt_name + "=" + str(configs.get("control_net_weight_alt", 1)) if len(control_net_weight) > 0: values += [control_net_weight] labels += ["Control Net Weights"] @@ -6323,7 +6323,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None pop += ["frames_positions", "control_net_weight", "control_net_weight2"] if not len(model_def.get("control_net_weight_alt_name", "")) >0: - pop += ["control_net_alt"] + pop += ["control_net_weight_alt"] if model_def.get("video_guide_outpainting", None) is None: pop += ["video_guide_outpainting"] From 46759786e7e0f9c001c9244a574c3222e6712dbd Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Mon, 6 Oct 2025 11:46:35 +0200 Subject: [PATCH 071/155] fix lora profiles for animate --- defaults/animate.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/defaults/animate.json b/defaults/animate.json index d4676721a..3289197da 100644 --- a/defaults/animate.json +++ b/defaults/animate.json @@ -12,7 +12,7 @@ [ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_relighting_lora.safetensors" ], - "settings_dir": [ "wan" ], + "profiles_dir": [ "profiles_wan" ], "group": "wan2_2" } } \ No newline at end of file From 543923fac673d85067e964e0a42602955f2bdd00 Mon Sep 17 00:00:00 2001 From: Redtash1 Date: Mon, 6 Oct 2025 13:44:15 -0400 Subject: [PATCH 072/155] Added AMD Installation Instructions Link to Readme.md Added AMD Installation Instructions Link to Readme.md --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index e7e9400a8..21830e429 100644 --- a/README.md +++ b/README.md @@ -214,9 +214,14 @@ This automated script will: ## 📦 Installation +### Nvidia For detailed installation instructions for different GPU generations: - **[Installation Guide](docs/INSTALLATION.md)** - Complete setup instructions for RTX 10XX to RTX 50XX +### AMD +For detailed installation instructions for different GPU generations: +- **[Installation Guide](docs/AMD-INSTALLATION.md)** - Complete setup instructions for Radeon RX 76XX, 77XX, 78XX & 79XX + ## 🎯 Usage ### Basic Usage From 1ce37d667f322be929c9315774829d413bad2f6c Mon Sep 17 00:00:00 2001 From: Redtash1 Date: Mon, 6 Oct 2025 13:49:42 -0400 Subject: [PATCH 073/155] Added AMD support to Readme.md --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 21830e429..92a881f78 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,8 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models with: - Low VRAM requirements (as low as 6 GB of VRAM is sufficient for certain models) -- Support for old GPUs (RTX 10XX, 20xx, ...) +- Support for old Nvidia GPUs (RTX 10XX, 20xx, ...) 
+- Support for AMD GPUs Radeon RX 76XX, 77XX, 78XX & 79XX, instructions in the Installation Section Below. - Very Fast on the latest GPUs - Easy to use Full Web based interface - Auto download of the required model adapted to your specific architecture From 61817be4517b0abf37e4b3dd93c435dd74b50a72 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Mon, 6 Oct 2025 20:43:41 +0200 Subject: [PATCH 074/155] Multi checkpoints folders support --- README.md | 5 +- defaults/animate.json | 1 - defaults/flux.json | 3 +- defaults/flux_chroma.json | 5 +- defaults/flux_dev_kontext.json | 5 +- defaults/flux_dev_umo.json | 3 +- defaults/flux_dev_uso.json | 5 +- defaults/flux_krea.json | 3 +- defaults/flux_schnell.json | 5 +- defaults/flux_srpo.json | 3 +- defaults/flux_srpo_uso.json | 5 +- defaults/vace_14B_2_2.json | 18 ++ defaults/vace_14B_cocktail_2_2.json | 2 +- defaults/vace_14B_lightning_3p_2_2.json | 2 +- defaults/vace_fun_14B_2_2.json | 2 +- .../Text2Video Cocktail - 10 Steps.json | 13 -- .../Text2Video FusioniX - 10 Steps.json | 11 -- ...x2v Cfg Step Distill Rank32 - 4 Steps.json | 11 -- ...un Inp HPS2 Reward 2 Phases - 4 Steps.json | 14 -- ...Fun Inp MPS Reward 2 Phases - 4 Steps.json | 14 -- .../Lightning v1.0 2 Phases - 4 Steps.json | 14 -- .../Lightning v1.0 3 Phases - 8 Steps.json | 18 -- .../Lightning v2025-09-28 - 4 Steps.json | 14 -- ...htning v2025-09-28 2 Phases - 4 Steps.json | 14 -- ...htning v2025-09-28 3 Phases - 8 Steps.json | 18 -- .../Text2Video FusioniX - 10 Steps.json | 12 -- loras/profiles_wan2_2_5B/FastWan 3 Steps.json | 7 - .../Image2Video FusioniX - 10 Steps.json | 10 -- models/flux/flux_handler.py | 34 ++-- models/flux/flux_main.py | 3 +- models/flux/util.py | 5 +- models/hyvideo/hunyuan.py | 16 +- models/hyvideo/hunyuan_handler.py | 7 +- models/ltx_video/ltxv.py | 25 ++- models/ltx_video/ltxv_handler.py | 5 +- .../qwen/configs}/qwen_image_20B.json | 0 models/qwen/qwen_handler.py | 6 +- models/qwen/qwen_main.py | 5 +- models/wan/any2video.py | 18 +- .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 950 bytes .../__pycache__/shared_config.cpython-310.pyc | Bin 0 -> 821 bytes .../__pycache__/wan_i2v_14B.cpython-310.pyc | Bin 0 -> 969 bytes .../__pycache__/wan_t2v_14B.cpython-310.pyc | Bin 0 -> 718 bytes .../__pycache__/wan_t2v_1_3B.cpython-310.pyc | Bin 0 -> 726 bytes {configs => models/wan/configs}/animate.json | 0 {configs => models/wan/configs}/fantasy.json | 0 .../wan/configs}/flf2v_720p.json | 0 {configs => models/wan/configs}/i2v.json | 0 {configs => models/wan/configs}/i2v_2_2.json | 0 .../wan/configs}/i2v_2_2_multitalk.json | 0 .../wan/configs}/infinitetalk.json | 0 .../wan/configs}/lucy_edit.json | 0 {configs => models/wan/configs}/lynx.json | 0 models/wan/configs/lynx_lite.json | 15 ++ .../wan/configs}/multitalk.json | 0 .../wan/configs}/phantom_1.3B.json | 0 .../wan/configs}/phantom_14B.json | 0 .../wan/configs}/sky_df_1.3.json | 0 .../wan/configs}/sky_df_14B.json | 0 {configs => models/wan/configs}/standin.json | 0 {configs => models/wan/configs}/t2v.json | 0 {configs => models/wan/configs}/t2v_1.3B.json | 0 models/wan/configs/t2v_2_2.json | 14 ++ {configs => models/wan/configs}/ti2v_2_2.json | 0 .../wan/configs}/vace_1.3B.json | 0 {configs => models/wan/configs}/vace_14B.json | 0 models/wan/configs/vace_14B_2_2.json | 16 ++ .../wan/configs}/vace_lynx_14B.json | 0 .../wan/configs}/vace_multitalk_14B.json | 0 .../wan/configs}/vace_standin_14B.json | 0 models/wan/fantasytalking/infer.py | 5 +- models/wan/multitalk/multitalk.py | 3 +- models/wan/wan_handler.py | 98 
++++++----- postprocessing/mmaudio/eval_utils.py | 7 +- .../mmaudio/model/utils/features_utils.py | 7 +- preprocessing/face_preprocessor.py | 3 +- preprocessing/matanyone/app.py | 3 +- .../matanyone/tools/base_segmenter.py | 3 +- wgp.py | 155 +++++++++--------- 79 files changed, 313 insertions(+), 377 deletions(-) create mode 100644 defaults/vace_14B_2_2.json delete mode 100644 loras/profiles_wan/Text2Video Cocktail - 10 Steps.json delete mode 100644 loras/profiles_wan/Text2Video FusioniX - 10 Steps.json delete mode 100644 loras/profiles_wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json delete mode 100644 loras/profiles_wan2_2/Fun Inp HPS2 Reward 2 Phases - 4 Steps.json delete mode 100644 loras/profiles_wan2_2/Fun Inp MPS Reward 2 Phases - 4 Steps.json delete mode 100644 loras/profiles_wan2_2/Lightning v1.0 2 Phases - 4 Steps.json delete mode 100644 loras/profiles_wan2_2/Lightning v1.0 3 Phases - 8 Steps.json delete mode 100644 loras/profiles_wan2_2/Lightning v2025-09-28 - 4 Steps.json delete mode 100644 loras/profiles_wan2_2/Lightning v2025-09-28 2 Phases - 4 Steps.json delete mode 100644 loras/profiles_wan2_2/Lightning v2025-09-28 3 Phases - 8 Steps.json delete mode 100644 loras/profiles_wan2_2/Text2Video FusioniX - 10 Steps.json delete mode 100644 loras/profiles_wan2_2_5B/FastWan 3 Steps.json delete mode 100644 loras_i2v/profiles_wan/Image2Video FusioniX - 10 Steps.json rename {configs => models/qwen/configs}/qwen_image_20B.json (100%) create mode 100644 models/wan/configs/__pycache__/__init__.cpython-310.pyc create mode 100644 models/wan/configs/__pycache__/shared_config.cpython-310.pyc create mode 100644 models/wan/configs/__pycache__/wan_i2v_14B.cpython-310.pyc create mode 100644 models/wan/configs/__pycache__/wan_t2v_14B.cpython-310.pyc create mode 100644 models/wan/configs/__pycache__/wan_t2v_1_3B.cpython-310.pyc rename {configs => models/wan/configs}/animate.json (100%) rename {configs => models/wan/configs}/fantasy.json (100%) rename {configs => models/wan/configs}/flf2v_720p.json (100%) rename {configs => models/wan/configs}/i2v.json (100%) rename {configs => models/wan/configs}/i2v_2_2.json (100%) rename {configs => models/wan/configs}/i2v_2_2_multitalk.json (100%) rename {configs => models/wan/configs}/infinitetalk.json (100%) rename {configs => models/wan/configs}/lucy_edit.json (100%) rename {configs => models/wan/configs}/lynx.json (100%) create mode 100644 models/wan/configs/lynx_lite.json rename {configs => models/wan/configs}/multitalk.json (100%) rename {configs => models/wan/configs}/phantom_1.3B.json (100%) rename {configs => models/wan/configs}/phantom_14B.json (100%) rename {configs => models/wan/configs}/sky_df_1.3.json (100%) rename {configs => models/wan/configs}/sky_df_14B.json (100%) rename {configs => models/wan/configs}/standin.json (100%) rename {configs => models/wan/configs}/t2v.json (100%) rename {configs => models/wan/configs}/t2v_1.3B.json (100%) create mode 100644 models/wan/configs/t2v_2_2.json rename {configs => models/wan/configs}/ti2v_2_2.json (100%) rename {configs => models/wan/configs}/vace_1.3B.json (100%) rename {configs => models/wan/configs}/vace_14B.json (100%) create mode 100644 models/wan/configs/vace_14B_2_2.json rename {configs => models/wan/configs}/vace_lynx_14B.json (100%) rename {configs => models/wan/configs}/vace_multitalk_14B.json (100%) rename {configs => models/wan/configs}/vace_standin_14B.json (100%) diff --git a/README.md b/README.md index e7e9400a8..126e97d3c 100644 --- a/README.md +++ b/README.md @@ -20,7 
+20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : -### October 5 2025: WanGP v8.993 - One Last Thing before the Big Unknown ... +### October 6 2025: WanGP v8.993 - A few last things before the Big Unknown ... This new version hasn't any new model... @@ -37,11 +37,14 @@ I have created the new Discord Channel **share-your-settings** where you can pos The *Scaled FP8 format* is widely used as it the format used by ... *ComfyUI*. So I except a flood of Finetunes in the *share-your-finetune* channel. If not it means this feature was useless and I will remove it 😈😈😈 +Not enough Space left on your SSD to download more models ? Would like to reuse Scaled FP8 files in your ComfyUI Folder without duplicating them ? Here comes *WanGP 8.994* **Multiple Checkpoints Folders** : you just need to move the files into different folders / hard drives or reuse existing folders and let know WanGP about it in the *Config Tab* and WanGP will be able to put all the parts together. + Last but not least the Lora's documentation has been updated. *update 8.991*: full power of *Vace Lynx* unleashed with new combinations such as Landscape + Face / Clothes + Face / Injectd Frame (Start/End frames/...) + Face *update 8.992*: optimized gen with Lora, should be 10% faster if many loras *update 8.993*: Support for *Scaled FP8* format and samples *Paligenesis* finetunes, merged Loras Accelerators and Non Accelerators +*update 8.994*: Added custom checkpoints folders ### September 30 2025: WanGP v8.9 - Combinatorics diff --git a/defaults/animate.json b/defaults/animate.json index 3289197da..bf45f4a92 100644 --- a/defaults/animate.json +++ b/defaults/animate.json @@ -12,7 +12,6 @@ [ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_relighting_lora.safetensors" ], - "profiles_dir": [ "profiles_wan" ], "group": "wan2_2" } } \ No newline at end of file diff --git a/defaults/flux.json b/defaults/flux.json index 87bab0fe3..724ec1abb 100644 --- a/defaults/flux.json +++ b/defaults/flux.json @@ -7,8 +7,7 @@ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev_quanto_bf16_int8.safetensors" ], - "image_outputs": true, - "flux-model": "flux-dev" + "image_outputs": true }, "prompt": "draw a hat", "resolution": "1280x720", diff --git a/defaults/flux_chroma.json b/defaults/flux_chroma.json index 7ffc42654..ebb8076be 100644 --- a/defaults/flux_chroma.json +++ b/defaults/flux_chroma.json @@ -1,14 +1,13 @@ { "model": { "name": "Flux 1 Chroma 1 HD 8.9B", - "architecture": "flux", + "architecture": "flux_chroma", "description": "FLUX.1 Chroma is a 8.9 billion parameters model. As a base model, Chroma1 is intentionally designed to be an excellent starting point for finetuning. 
It provides a strong, neutral foundation for developers, researchers, and artists to create specialized models..", "URLs": [ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-chroma_hd_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-chroma_hd_quanto_bf16_int8.safetensors" ], - "image_outputs": true, - "flux-model": "flux-chroma" + "image_outputs": true }, "prompt": "draw a hat", "resolution": "1280x720", diff --git a/defaults/flux_dev_kontext.json b/defaults/flux_dev_kontext.json index 20b6bc4d3..cf3631a1f 100644 --- a/defaults/flux_dev_kontext.json +++ b/defaults/flux_dev_kontext.json @@ -1,13 +1,12 @@ { "model": { "name": "Flux 1 Dev Kontext 12B", - "architecture": "flux", + "architecture": "flux_dev_kontext", "description": "FLUX.1 Kontext is a 12 billion parameter rectified flow transformer capable of editing images based on instructions stored in the Prompt. Please be aware that Flux Kontext is picky on the resolution of the input image and the output dimensions may not match the dimensions of the input image.", "URLs": [ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors" - ], - "flux-model": "flux-dev-kontext" + ] }, "prompt": "add a hat", "resolution": "1280x720", diff --git a/defaults/flux_dev_umo.json b/defaults/flux_dev_umo.json index 57164bb4c..ea7368053 100644 --- a/defaults/flux_dev_umo.json +++ b/defaults/flux_dev_umo.json @@ -1,10 +1,9 @@ { "model": { "name": "Flux 1 Dev UMO 12B", - "architecture": "flux", + "architecture": "flux_dev_umo", "description": "FLUX.1 Dev UMO is a model that can Edit Images with a specialization in combining multiple image references (resized internally at 512x512 max) to produce an Image output. 
Best Image preservation at 768x768 Resolution Output.", "URLs": "flux", - "flux-model": "flux-dev-umo", "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-UMO_dit_lora_bf16.safetensors"], "resolutions": [ ["1024x1024 (1:1)", "1024x1024"], ["768x1024 (3:4)", "768x1024"], diff --git a/defaults/flux_dev_uso.json b/defaults/flux_dev_uso.json index 806dd7e02..fee8f25f9 100644 --- a/defaults/flux_dev_uso.json +++ b/defaults/flux_dev_uso.json @@ -1,12 +1,11 @@ { "model": { "name": "Flux 1 Dev USO 12B", - "architecture": "flux", + "architecture": "flux_dev_uso", "description": "FLUX.1 Dev USO is a model that can Edit Images with a specialization in Style Transfers (up to two).", "modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]], "URLs": "flux", - "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"], - "flux-model": "flux-dev-uso" + "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"] }, "prompt": "the man is wearing a hat", "embedded_guidance_scale": 4, diff --git a/defaults/flux_krea.json b/defaults/flux_krea.json index 3caba1ab4..4fa0b3fb1 100644 --- a/defaults/flux_krea.json +++ b/defaults/flux_krea.json @@ -7,8 +7,7 @@ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-krea-dev_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-krea-dev_quanto_bf16_int8.safetensors" ], - "image_outputs": true, - "flux-model": "flux-dev" + "image_outputs": true }, "prompt": "draw a hat", "resolution": "1280x720", diff --git a/defaults/flux_schnell.json b/defaults/flux_schnell.json index d7abcde2c..1645a86ab 100644 --- a/defaults/flux_schnell.json +++ b/defaults/flux_schnell.json @@ -1,14 +1,13 @@ { "model": { "name": "Flux 1 Schnell 12B", - "architecture": "flux", + "architecture": "flux_schnell", "description": "FLUX.1 Schnell is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions. As a distilled model it requires fewer denoising steps.", "URLs": [ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-schnell_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-schnell_quanto_bf16_int8.safetensors" ], - "image_outputs": true, - "flux-model": "flux-schnell" + "image_outputs": true }, "prompt": "draw a hat", "resolution": "1280x720", diff --git a/defaults/flux_srpo.json b/defaults/flux_srpo.json index 59f07c617..188c162e4 100644 --- a/defaults/flux_srpo.json +++ b/defaults/flux_srpo.json @@ -6,8 +6,7 @@ "URLs": [ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_quanto_bf16_int8.safetensors" - ], - "flux-model": "flux-dev" + ] }, "prompt": "draw a hat", "resolution": "1024x1024", diff --git a/defaults/flux_srpo_uso.json b/defaults/flux_srpo_uso.json index ddfe50d63..e43371daf 100644 --- a/defaults/flux_srpo_uso.json +++ b/defaults/flux_srpo_uso.json @@ -1,12 +1,11 @@ { "model": { "name": "Flux 1 SRPO USO 12B", - "architecture": "flux", + "architecture": "flux_dev_uso", "description": "FLUX.1 SRPO USO is a model that can Edit Images with a specialization in Style Transfers (up to two). 
It leverages the improved Image quality brought by the SRPO process", "modules": [ "flux_dev_uso"], "URLs": "flux_srpo", - "loras": "flux_dev_uso", - "flux-model": "flux-dev-uso" + "loras": "flux_dev_uso" }, "prompt": "the man is wearing a hat", "embedded_guidance_scale": 4, diff --git a/defaults/vace_14B_2_2.json b/defaults/vace_14B_2_2.json new file mode 100644 index 000000000..6aa8cf8fb --- /dev/null +++ b/defaults/vace_14B_2_2.json @@ -0,0 +1,18 @@ +{ + "model": { + "name": "Wan2.2 Vace Experimental 14B", + "architecture": "vace_14B_2_2", + "modules": [ + "vace_14B" + ], + "description": "There is so far only PARTIAL support of Vace 2.1 which is currently used.", + "URLs": "t2v_2_2", + "URLs2": "t2v_2_2", + "visible": false + }, + "guidance_phases": 2, + "guidance_scale": 1, + "guidance2_scale": 1, + "flow_shift": 2, + "switch_threshold" : 875 +} \ No newline at end of file diff --git a/defaults/vace_14B_cocktail_2_2.json b/defaults/vace_14B_cocktail_2_2.json index 57839227e..a70424118 100644 --- a/defaults/vace_14B_cocktail_2_2.json +++ b/defaults/vace_14B_cocktail_2_2.json @@ -1,7 +1,7 @@ { "model": { "name": "Wan2.2 Vace Experimental Cocktail 14B", - "architecture": "vace_14B", + "architecture": "vace_14B_2_2", "modules": [ "vace_14B" ], diff --git a/defaults/vace_14B_lightning_3p_2_2.json b/defaults/vace_14B_lightning_3p_2_2.json index fca6cb816..00bca667a 100644 --- a/defaults/vace_14B_lightning_3p_2_2.json +++ b/defaults/vace_14B_lightning_3p_2_2.json @@ -1,7 +1,7 @@ { "model": { "name": "Wan2.2 Vace Lightning 3 Phases 14B", - "architecture": "vace_14B", + "architecture": "vace_14B_2_2", "modules": [ "vace_14B" ], diff --git a/defaults/vace_fun_14B_2_2.json b/defaults/vace_fun_14B_2_2.json index 8f22d3456..febf5ea81 100644 --- a/defaults/vace_fun_14B_2_2.json +++ b/defaults/vace_fun_14B_2_2.json @@ -1,7 +1,7 @@ { "model": { "name": "Wan2.2 Vace Fun 14B", - "architecture": "vace_14B", + "architecture": "vace_14B_2_2", "description": "This is the Fun Vace 2.2 version, that is not the official Vace 2.2", "URLs": [ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_mbf16.safetensors", diff --git a/loras/profiles_wan/Text2Video Cocktail - 10 Steps.json b/loras/profiles_wan/Text2Video Cocktail - 10 Steps.json deleted file mode 100644 index 087e49c92..000000000 --- a/loras/profiles_wan/Text2Video Cocktail - 10 Steps.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors" - ], - "loras_multipliers": "1 0.5 0.5 0.5", - "num_inference_steps": 10, - "guidance_phases": 1, - "guidance_scale": 1, - "flow_shift": 2 -} diff --git a/loras/profiles_wan/Text2Video FusioniX - 10 Steps.json b/loras/profiles_wan/Text2Video FusioniX - 10 Steps.json deleted file mode 100644 index 9fabf09eb..000000000 --- a/loras/profiles_wan/Text2Video FusioniX - 10 Steps.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_T2V_14B_FusionX_LoRA.safetensors" - ], - "loras_multipliers": "1", - 
"num_inference_steps": 10, - "guidance_phases": 1, - "guidance_scale": 1, - "flow_shift": 2 -} diff --git a/loras/profiles_wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json b/loras/profiles_wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json deleted file mode 100644 index 8ed2a4db0..000000000 --- a/loras/profiles_wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json +++ /dev/null @@ -1,11 +0,0 @@ -{ - - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors" - ], - "loras_multipliers": "1", - "num_inference_steps": 4, - "guidance_phases": 1, - "guidance_scale": 1, - "flow_shift": 2 -} \ No newline at end of file diff --git a/loras/profiles_wan2_2/Fun Inp HPS2 Reward 2 Phases - 4 Steps.json b/loras/profiles_wan2_2/Fun Inp HPS2 Reward 2 Phases - 4 Steps.json deleted file mode 100644 index 1b58ec73a..000000000 --- a/loras/profiles_wan2_2/Fun Inp HPS2 Reward 2 Phases - 4 Steps.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "num_inference_steps": 8, - "guidance_scale": 1, - "guidance2_scale": 1, - "model_switch_phase": 1, - "switch_threshold": 876, - "guidance_phases": 2, - "flow_shift": 3, - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Fun-A14B-InP-high-noise-HPS2.1.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Fun-A14B-InP-low-noise-HPS2.1.safetensors" - ], - "loras_multipliers": "1;0 0;1" -} \ No newline at end of file diff --git a/loras/profiles_wan2_2/Fun Inp MPS Reward 2 Phases - 4 Steps.json b/loras/profiles_wan2_2/Fun Inp MPS Reward 2 Phases - 4 Steps.json deleted file mode 100644 index cc5742248..000000000 --- a/loras/profiles_wan2_2/Fun Inp MPS Reward 2 Phases - 4 Steps.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "num_inference_steps": 4, - "guidance_scale": 1, - "guidance2_scale": 1, - "model_switch_phase": 1, - "switch_threshold": 876, - "guidance_phases": 2, - "flow_shift": 3, - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Fun-A14B-InP-high-noise-MPS.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Fun-A14B-InP-low-noise-MPS.safetensors" - ], - "loras_multipliers": "1;0 0;1" -} \ No newline at end of file diff --git a/loras/profiles_wan2_2/Lightning v1.0 2 Phases - 4 Steps.json b/loras/profiles_wan2_2/Lightning v1.0 2 Phases - 4 Steps.json deleted file mode 100644 index 51c903b63..000000000 --- a/loras/profiles_wan2_2/Lightning v1.0 2 Phases - 4 Steps.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "num_inference_steps": 4, - "guidance_scale": 1, - "guidance2_scale": 1, - "model_switch_phase": 1, - "switch_threshold": 876, - "guidance_phases": 2, - "flow_shift": 3, - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_HIGH_fp16.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors" - ], - "loras_multipliers": "1;0 0;1" -} \ No newline at end of file diff --git a/loras/profiles_wan2_2/Lightning v1.0 3 Phases - 8 Steps.json b/loras/profiles_wan2_2/Lightning v1.0 3 Phases - 8 Steps.json deleted file mode 100644 index a30834482..000000000 --- a/loras/profiles_wan2_2/Lightning v1.0 3 Phases - 8 Steps.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - 
"num_inference_steps": 8, - "guidance_scale": 3.5, - "guidance2_scale": 1, - "guidance3_scale": 1, - "switch_threshold": 965, - "switch_threshold2": 800, - "guidance_phases": 3, - "model_switch_phase": 2, - "flow_shift": 3, - "sample_solver": "euler", - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_HIGH_fp16.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors" - ], - "loras_multipliers": "0;1;0 0;0;1", - "help": "This finetune uses the Lightning 4 steps Loras Accelerator for Wan 2.2 but extend them to 8 steps in order to insert a CFG phase before the 2 accelerated phases with no Guidance. The ultimate goal is reduce the slow motion effect of these Loras Accelerators." -} \ No newline at end of file diff --git a/loras/profiles_wan2_2/Lightning v2025-09-28 - 4 Steps.json b/loras/profiles_wan2_2/Lightning v2025-09-28 - 4 Steps.json deleted file mode 100644 index 31e4f0d66..000000000 --- a/loras/profiles_wan2_2/Lightning v2025-09-28 - 4 Steps.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "num_inference_steps": 4, - "guidance_scale": 1, - "guidance2_scale": 1, - "switch_threshold": 876, - "model_switch_phase": 1, - "guidance_phases": 2, - "flow_shift": 3, - "loras_multipliers": "1;0 0;1", - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" - ] -} \ No newline at end of file diff --git a/loras/profiles_wan2_2/Lightning v2025-09-28 2 Phases - 4 Steps.json b/loras/profiles_wan2_2/Lightning v2025-09-28 2 Phases - 4 Steps.json deleted file mode 100644 index 31e4f0d66..000000000 --- a/loras/profiles_wan2_2/Lightning v2025-09-28 2 Phases - 4 Steps.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "num_inference_steps": 4, - "guidance_scale": 1, - "guidance2_scale": 1, - "switch_threshold": 876, - "model_switch_phase": 1, - "guidance_phases": 2, - "flow_shift": 3, - "loras_multipliers": "1;0 0;1", - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" - ] -} \ No newline at end of file diff --git a/loras/profiles_wan2_2/Lightning v2025-09-28 3 Phases - 8 Steps.json b/loras/profiles_wan2_2/Lightning v2025-09-28 3 Phases - 8 Steps.json deleted file mode 100644 index 839e0564c..000000000 --- a/loras/profiles_wan2_2/Lightning v2025-09-28 3 Phases - 8 Steps.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "num_inference_steps": 8, - "guidance_scale": 3.5, - "guidance2_scale": 1, - "guidance3_scale": 1, - "switch_threshold": 965, - "switch_threshold2": 800, - "guidance_phases": 3, - "model_switch_phase": 2, - "flow_shift": 3, - "sample_solver": "euler", - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" - ], - 
"loras_multipliers": "0;1;0 0;0;1", - "help": "This finetune uses the Lightning 250928 4 steps Loras Accelerator for Wan 2.2 but extend them to 8 steps in order to insert a CFG phase before the 2 accelerated phases with no Guidance. The ultimate goal is reduce the slow motion effect of these Loras Accelerators." -} \ No newline at end of file diff --git a/loras/profiles_wan2_2/Text2Video FusioniX - 10 Steps.json b/loras/profiles_wan2_2/Text2Video FusioniX - 10 Steps.json deleted file mode 100644 index 1ad921b63..000000000 --- a/loras/profiles_wan2_2/Text2Video FusioniX - 10 Steps.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_T2V_14B_FusionX_LoRA.safetensors" - ], - "loras_multipliers": "1", - "num_inference_steps": 10, - "guidance_scale": 1, - "guidance2_scale": 1, - "guidance_phases": 2, - "flow_shift": 2 -} diff --git a/loras/profiles_wan2_2_5B/FastWan 3 Steps.json b/loras/profiles_wan2_2_5B/FastWan 3 Steps.json deleted file mode 100644 index 1e0383c61..000000000 --- a/loras/profiles_wan2_2_5B/FastWan 3 Steps.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "activated_loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"], - "guidance_scale": 1, - "flow_shift": 3, - "num_inference_steps": 3, - "loras_multipliers": "1" -} \ No newline at end of file diff --git a/loras_i2v/profiles_wan/Image2Video FusioniX - 10 Steps.json b/loras_i2v/profiles_wan/Image2Video FusioniX - 10 Steps.json deleted file mode 100644 index 1fe4c14e5..000000000 --- a/loras_i2v/profiles_wan/Image2Video FusioniX - 10 Steps.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_I2V_14B_FusionX_LoRA.safetensors" - ], - "loras_multipliers": "1", - "num_inference_steps": 10, - "guidance_phases": 1, - "guidance_scale": 1 -} diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py index 471c339d4..cff1af4cf 100644 --- a/models/flux/flux_handler.py +++ b/models/flux/flux_handler.py @@ -1,15 +1,35 @@ import torch +from shared.utils import files_locator as fl def get_ltxv_text_encoder_filename(text_encoder_quantization): - text_encoder_filename = "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors" + text_encoder_filename = "T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors" if text_encoder_quantization =="int8": text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8") - return text_encoder_filename + return fl.locate_file(text_encoder_filename, True) class family_handler(): @staticmethod + def query_supported_types(): + return ["flux", "flux_chroma", "flux_dev_kontext", "flux_dev_umo", "flux_dev_uso", "flux_schnell" ] + + @staticmethod + def query_family_maps(): + + models_eqv_map = { + "flux_dev_kontext" : "flux", + "flux_dev_umo" : "flux", + "flux_dev_uso" : "flux", + "flux_schnell" : "flux", + "flux_chroma" : "flux", + } + + models_comp_map = { + "flux": ["flux_chroma", "flux_dev_kontext", "flux_dev_umo", "flux_dev_uso", "flux_schnell" ] + } + return models_eqv_map, models_comp_map + @staticmethod def query_model_def(base_model_type, model_def): - flux_model = model_def.get("flux-model", "flux-dev") + flux_model = "flux-dev" if base_model_type == "flux" else base_model_type.replace("_", "-") flux_schnell = flux_model == "flux-schnell" flux_chroma = flux_model == "flux-chroma" flux_uso = flux_model == "flux-dev-uso" @@ 
-19,6 +39,7 @@ def query_model_def(base_model_type, model_def): extra_model_def = { "image_outputs" : True, "no_negative_prompt" : not flux_chroma, + "flux-model": flux_model, } if flux_chroma: extra_model_def["guidance_max_phases"] = 1 @@ -60,13 +81,6 @@ def query_model_def(base_model_type, model_def): return extra_model_def - @staticmethod - def query_supported_types(): - return ["flux"] - - @staticmethod - def query_family_maps(): - return {}, {} @staticmethod def get_rgb_factors(base_model_type ): diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py index 6746a2375..7f18d9d36 100644 --- a/models/flux/flux_main.py +++ b/models/flux/flux_main.py @@ -12,6 +12,7 @@ import torchvision.transforms.functional as TVF import math from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image +from shared.utils import files_locator as fl from .util import ( aspect_ratio_to_height_width, @@ -101,7 +102,7 @@ def __init__( siglip_processor = siglip_model = feature_embedder = None if self.name == 'flux-dev-uso': - siglip_path = "ckpts/siglip-so400m-patch14-384" + siglip_path = fl.locate_folder("siglip-so400m-patch14-384") siglip_processor = SiglipImageProcessor.from_pretrained(siglip_path) siglip_model = SiglipVisionModel.from_pretrained(siglip_path) siglip_model.eval().to("cpu") diff --git a/models/flux/util.py b/models/flux/util.py index af75f6269..909d6b4ea 100644 --- a/models/flux/util.py +++ b/models/flux/util.py @@ -14,6 +14,7 @@ from .model import Flux, FluxLoraWrapper, FluxParams from .modules.autoencoder import AutoEncoder, AutoEncoderParams from .modules.conditioner import HFEmbedder +from shared.utils import files_locator as fl CHECKPOINTS_DIR = Path("checkpoints") @@ -756,12 +757,12 @@ def load_t5(device: str | torch.device = "cuda", text_encoder_filename = None, m def load_clip(device: str | torch.device = "cuda") -> HFEmbedder: - return HFEmbedder("ckpts/clip_vit_large_patch14", "", max_length=77, torch_dtype=torch.bfloat16, is_clip =True).to(device) + return HFEmbedder( fl.locate_folder("clip_vit_large_patch14"), "", max_length=77, torch_dtype=torch.bfloat16, is_clip =True).to(device) def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder: config = configs[name] - ckpt_path = str(get_checkpoint_path(config.repo_id, config.repo_ae, "FLUX_AE")) + ckpt_path = str(get_checkpoint_path(config.repo_id, fl.locate_file("flux_vae.safetensors"), "FLUX_AE")) # Loading the autoencoder with torch.device("meta"): diff --git a/models/hyvideo/hunyuan.py b/models/hyvideo/hunyuan.py index aa6c3b3e6..91aaf7fdd 100644 --- a/models/hyvideo/hunyuan.py +++ b/models/hyvideo/hunyuan.py @@ -3,7 +3,7 @@ import random import functools from typing import List, Optional, Tuple, Union - +from shared.utils import files_locator as fl from pathlib import Path from einops import rearrange import torch @@ -405,14 +405,14 @@ def from_pretrained(cls, model_filepath, model_type, base_model_type, text_encod # ============================= Build extra models ======================== # VAE if custom or avatar: - vae_configpath = "ckpts/hunyuan_video_custom_VAE_config.json" - vae_filepath = "ckpts/hunyuan_video_custom_VAE_fp32.safetensors" + vae_configpath = fl.locate_file("hunyuan_video_custom_VAE_config.json") + vae_filepath = fl.locate_file("hunyuan_video_custom_VAE_fp32.safetensors") # elif avatar: # vae_configpath = "ckpts/config_vae_avatar.json" # vae_filepath = "ckpts/vae_avatar.pt" else: - vae_configpath = "ckpts/hunyuan_video_VAE_config.json" - vae_filepath = 
"ckpts/hunyuan_video_VAE_fp32.safetensors" + vae_configpath = fl.locate_file("hunyuan_video_VAE_config.json") + vae_filepath = fl.locate_file("hunyuan_video_VAE_fp32.safetensors") # config = AutoencoderKLCausal3D.load_config("ckpts/hunyuan_video_VAE_config.json") # config = AutoencoderKLCausal3D.load_config("c:/temp/hvae/config_vae.json") @@ -486,12 +486,12 @@ def from_pretrained(cls, model_filepath, model_type, base_model_type, text_encod align_instance = None if avatar or custom_audio: - feature_extractor = AutoFeatureExtractor.from_pretrained("ckpts/whisper-tiny/") - wav2vec = WhisperModel.from_pretrained("ckpts/whisper-tiny/").to(device="cpu", dtype=torch.float32) + feature_extractor = AutoFeatureExtractor.from_pretrained(fl.locate_folder("whisper-tiny")) + wav2vec = WhisperModel.from_pretrained(fl.locate_folder("whisper-tiny")).to(device="cpu", dtype=torch.float32) wav2vec._model_dtype = torch.float32 wav2vec.requires_grad_(False) if avatar: - align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt") + align_instance = AlignImage("cuda", det_path= fl.locate_file("det_align/detface.pt")) align_instance.facedet.model.to("cpu") adapt_model(model, "audio_adapter_blocks") elif custom_audio: diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py index ebe07d26d..4a0c035d5 100644 --- a/models/hyvideo/hunyuan_handler.py +++ b/models/hyvideo/hunyuan_handler.py @@ -1,12 +1,13 @@ import torch +from shared.utils import files_locator as fl def get_hunyuan_text_encoder_filename(text_encoder_quantization): if text_encoder_quantization =="int8": - text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_quanto_int8.safetensors" + text_encoder_filename = "llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_quanto_int8.safetensors" else: - text_encoder_filename = "ckpts/llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors" + text_encoder_filename = "llava-llama-3-8b/llava-llama-3-8b-v1_1_vlm_fp16.safetensors" - return text_encoder_filename + return fl.locate_file(text_encoder_filename, True) class family_handler(): diff --git a/models/ltx_video/ltxv.py b/models/ltx_video/ltxv.py index 080860c14..7959c67cb 100644 --- a/models/ltx_video/ltxv.py +++ b/models/ltx_video/ltxv.py @@ -38,6 +38,7 @@ from .models.autoencoders.latent_upsampler import LatentUpsampler from .pipelines import crf_compressor import cv2 +from shared.utils import files_locator as fl MAX_HEIGHT = 720 MAX_WIDTH = 1280 @@ -177,7 +178,7 @@ def __init__( # offload.save_model(transformer, "ckpts/ltxv_0.9.8_13B_quanto_bf16_int8.safetensors", do_quantize= True, config_file_path= "c:/temp/ltxv_config.json") # vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) - vae = offload.fast_load_transformers_model("ckpts/ltxv_0.9.7_VAE.safetensors", modelClass=CausalVideoAutoencoder) + vae = offload.fast_load_transformers_model( fl.locate_file("ltxv_0.9.7_VAE.safetensors"), modelClass=CausalVideoAutoencoder) # vae = offload.fast_load_transformers_model("ckpts/ltxv_0.9.8_VAE.safetensors", modelClass=CausalVideoAutoencoder) # if VAE_dtype == torch.float16: VAE_dtype = torch.bfloat16 @@ -196,11 +197,11 @@ def __init__( transformer._lock_dtype = torch.float - scheduler = RectifiedFlowScheduler.from_pretrained("ckpts/ltxv_scheduler.json") + scheduler = RectifiedFlowScheduler.from_pretrained(fl.locate_file("ltxv_scheduler.json")) # transformer = offload.fast_load_transformers_model("ltx_13B_quanto_bf16_int8.safetensors", modelClass=Transformer3DModel, modelPrefix= "model.diffusion_model", 
forcedConfigPath="config_transformer.json") # offload.save_model(transformer, "ltx_13B_quanto_bf16_int8.safetensors", do_quantize= True, config_file_path="config_transformer.json") - latent_upsampler = LatentUpsampler.from_pretrained("ckpts/ltxv_0.9.7_spatial_upscaler.safetensors").to("cpu").eval() + latent_upsampler = LatentUpsampler.from_pretrained(fl.locate_file("ltxv_0.9.7_spatial_upscaler.safetensors")).to("cpu").eval() # latent_upsampler = LatentUpsampler.from_pretrained("ckpts/ltxv_0.9.8_spatial_upscaler.safetensors").to("cpu").eval() latent_upsampler.to(VAE_dtype) latent_upsampler._model_dtype = VAE_dtype @@ -216,19 +217,13 @@ def __init__( text_encoder = offload.fast_load_transformers_model(text_encoder_filepath) patchifier = SymmetricPatchifier(patch_size=1) - tokenizer = T5Tokenizer.from_pretrained( "ckpts/T5_xxl_1.1") + tokenizer = T5Tokenizer.from_pretrained(fl.locate_folder("T5_xxl_1.1")) enhance_prompt = False - if enhance_prompt: - prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( "ckpts/Florence2", trust_remote_code=True) - prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( "ckpts/Florence2", trust_remote_code=True) - prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/Llama3_2_quanto_bf16_int8.safetensors") - prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained("ckpts/Llama3_2") - else: - prompt_enhancer_image_caption_model = None - prompt_enhancer_image_caption_processor = None - prompt_enhancer_llm_model = None - prompt_enhancer_llm_tokenizer = None + prompt_enhancer_image_caption_model = None + prompt_enhancer_image_caption_processor = None + prompt_enhancer_llm_model = None + prompt_enhancer_llm_tokenizer = None # if prompt_enhancer_image_caption_model != None: # pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model @@ -489,7 +484,7 @@ def get_loras_transformer(self, get_model_recursive_prop, model_type, video_prom if letter in video_prompt_type: for file_name in preloadURLs: if signature in file_name: - loras += [ os.path.join("ckpts", os.path.basename(file_name))] + loras += [ fl.locate_file(os.path.basename(file_name))] break loras_mult = [1.] 
* len(loras) return loras, loras_mult diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py index 8c322e153..6386f68ce 100644 --- a/models/ltx_video/ltxv_handler.py +++ b/models/ltx_video/ltxv_handler.py @@ -1,11 +1,12 @@ import torch +from shared.utils import files_locator as fl def get_ltxv_text_encoder_filename(text_encoder_quantization): - text_encoder_filename = "ckpts/T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors" + text_encoder_filename = "T5_xxl_1.1/T5_xxl_1.1_enc_bf16.safetensors" if text_encoder_quantization =="int8": text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8") - return text_encoder_filename + return fl.locate_file(text_encoder_filename, True) class family_handler(): @staticmethod diff --git a/configs/qwen_image_20B.json b/models/qwen/configs/qwen_image_20B.json similarity index 100% rename from configs/qwen_image_20B.json rename to models/qwen/configs/qwen_image_20B.json diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py index 99864a5a7..2a3c8487a 100644 --- a/models/qwen/qwen_handler.py +++ b/models/qwen/qwen_handler.py @@ -1,12 +1,13 @@ import torch import gradio as gr +from shared.utils import files_locator as fl def get_qwen_text_encoder_filename(text_encoder_quantization): - text_encoder_filename = "ckpts/Qwen2.5-VL-7B-Instruct/Qwen2.5-VL-7B-Instruct_bf16.safetensors" + text_encoder_filename = "Qwen2.5-VL-7B-Instruct/Qwen2.5-VL-7B-Instruct_bf16.safetensors" if text_encoder_quantization =="int8": text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_bf16_int8") - return text_encoder_filename + return fl.locate_file(text_encoder_filename, True) class family_handler(): @staticmethod @@ -18,6 +19,7 @@ def query_model_def(base_model_type, model_def): ("Lightning", "lightning")], "guidance_max_phases" : 1, "fit_into_canvas_image_refs": 0, + "profiles_dir": ["qwen"], } if base_model_type in ["qwen_image_edit_20B", "qwen_image_edit_plus_20B"]: diff --git a/models/qwen/qwen_main.py b/models/qwen/qwen_main.py index 8c4514fa2..bf2733851 100644 --- a/models/qwen/qwen_main.py +++ b/models/qwen/qwen_main.py @@ -18,6 +18,7 @@ from .pipeline_qwenimage import QwenImagePipeline from PIL import Image from shared.utils.utils import calculate_new_dimensions, convert_tensor_to_image +from shared.utils import files_locator as fl def stitch_images(img1, img2): # Resize img2 to match img1's height @@ -56,7 +57,7 @@ def __init__( tokenizer = AutoTokenizer.from_pretrained(os.path.join(checkpoint_dir,"Qwen2.5-VL-7B-Instruct")) self.base_model_type = base_model_type - base_config_file = "configs/qwen_image_20B.json" + base_config_file = "models/qwen/configs/qwen_image_20B.json" with open(base_config_file, 'r', encoding='utf-8') as f: transformer_config = json.load(f) transformer_config.pop("_diffusers_version") @@ -223,6 +224,6 @@ def get_loras_transformer(self, get_model_recursive_prop, model_type, model_mode if model_mode == 0: return [], [] preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") if len(preloadURLs) == 0: return [], [] - return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1] + return [ fl.locate_file(os.path.basename(preloadURLs[0]))] , [1] diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 01c2cd832..b2ee944e7 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -37,6 +37,7 @@ from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask from 
shared.utils.audio_video import save_video from mmgp import safetensors2 +from shared.utils import files_locator as fl def optimized_scale(positive_flat, negative_flat): @@ -91,16 +92,15 @@ def __init__( dtype=config.t5_dtype, device=torch.device('cpu'), checkpoint_path=text_encoder_filename, - tokenizer_path=os.path.join(checkpoint_dir, "umt5-xxl"), + tokenizer_path=fl.locate_folder("umt5-xxl"), shard_fn= None) # base_model_type = "i2v2_2" if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"] or base_model_type in ["animate"]: self.clip = CLIPModel( dtype=config.clip_dtype, device=self.device, - checkpoint_path=os.path.join(checkpoint_dir , - config.clip_checkpoint), - tokenizer_path=os.path.join(checkpoint_dir , "xlm-roberta-large")) + checkpoint_path=fl.locate_file(config.clip_checkpoint), + tokenizer_path=fl.locate_folder("xlm-roberta-large")) ignore_unused_weights = model_def.get("ignore_unused_weights", False) @@ -114,9 +114,7 @@ def __init__( vae = WanVAE self.patch_size = config.patch_size - self.vae = vae( - vae_pth=os.path.join(checkpoint_dir, vae_checkpoint), dtype= VAE_dtype, - device="cpu") + self.vae = vae( vae_pth=fl.locate_file(vae_checkpoint), dtype= VAE_dtype, device="cpu") self.vae.device = self.device # config_filename= "configs/t2v_1.3B.json" @@ -125,7 +123,7 @@ def __init__( # config = json.load(f) # sd = safetensors2.torch_load_file(xmodel_filename) # model_filename = "c:/temp/wan2.2i2v/low/diffusion_pytorch_model-00001-of-00006.safetensors" - base_config_file = f"configs/{base_model_type}.json" + base_config_file = f"models/wan/configs/{base_model_type}.json" forcedConfigPath = base_config_file if len(model_filename) > 1 else None # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json" # model_filename[1] = xmodel_filename @@ -702,7 +700,7 @@ def generate(self, if True: with init_empty_weights(): arc_resampler = Resampler( depth=4, dim=1280, dim_head=64, embedding_dim=512, ff_mult=4, heads=20, num_queries=16, output_dim=2048 if lynx_lite else 5120 ) - offload.load_model_data(arc_resampler, os.path.join("ckpts", "wan2.1_lynx_lite_arc_resampler.safetensors" if lynx_lite else "wan2.1_lynx_full_arc_resampler.safetensors")) + offload.load_model_data(arc_resampler, fl.locate_file("wan2.1_lynx_lite_arc_resampler.safetensors" if lynx_lite else "wan2.1_lynx_full_arc_resampler.safetensors")) arc_resampler.to(self.device) arcface_embed = face_arc_embeds[None,None,:].to(device=self.device, dtype=torch.float) ip_hidden_states = arc_resampler(arcface_embed).to(self.dtype) @@ -1125,6 +1123,6 @@ def get_loras_transformer(self, get_model_recursive_prop, base_model_type, model if "#" in video_prompt_type and "1" in video_prompt_type: preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") if len(preloadURLs) > 0: - return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1] + return [fl.locate_file(os.path.basename(preloadURLs[0]))] , [1] return [], [] diff --git a/models/wan/configs/__pycache__/__init__.cpython-310.pyc b/models/wan/configs/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fdfaf60617346093216ab70837ec41ba57162a2 GIT binary patch literal 950 zcmZ8f&rjPh6t>+oantk%Eg+6u+X+NNr6GhS0_#wkwb`1YKJb2z{H(8_W@{ZGWTCGN_>F!T3h$=WX zJo}v2>N^1Gse*i6VI#ryopE`06dZf6HrH}2kYR%qQn^I+~ z;7*M?fQ@x~YJjPNGtSi3?0Quh4;AysLrXfUXkktj%_~z-rl?Aidsw@}ciM60yg!i1 zX(|h`HW-~;#AFo3h78WVa1^lY`+mT08WzH-5VI6>+qY0Xo|)*TeW3wNOofyusv1+R 
zQQeLy=RtocIcM1;r_FZ{+RvJMF0)*?sDEB@-+9HIZo|1{;`Hg+i7#XAoxF%x!F|!` z9CQyGPq|8S+1yKKw~^r61h$Y+|JPyjGv&klPbgf1?IvdWNN)D_TRpgSd#}?Qj)e51 z9vrqeNZv!#ZcS>t^x0~HR^S(DblNlb4A{bl0f22N# z5&FrB6+v-1ZL^Tx-d*W*T1)(mpJE@7n-e zdy`Z;`D-VD*~Yb#px$6PP->R}uX^a(6|%DGPwg|hSs>qA^spJstpIiH@Cj0Dj1W~b?B9Rvb%avZ2qs) z8LOvPEn+u%mG`Dd>!?@RZon__&sXVm1G;DK9V5E6MBdB;?H-TWEi_@{#j(Bmw0rag z6k1&LYI~6BS|zW(O=r&lcSZO3n4HeT5kMh@w|>1qzwM6r3G-Zp<6(|)QSL7d C<&vZT literal 0 HcmV?d00001 diff --git a/models/wan/configs/__pycache__/wan_i2v_14B.cpython-310.pyc b/models/wan/configs/__pycache__/wan_i2v_14B.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8b61e52cf614e0fa5157edef10f1cf3f8332edd GIT binary patch literal 969 zcmZXS%W4!s6o$LH=hEr9<*v!bES*J6Gr>R#`m1yPK6R?=uTC2c$H4L7*Y}`OGK?QOxj8IM z9^oTDA;16<8J?cah{UF65(6wKMOIw$O2ptMl%aA-yfQCeRe0$m!S_{Ord6xdxhbs399rM>e5>Y(W#b2rc9iEFza-3AqBx$kl>tu!8YAtRn9eybEg>Zvc5? zwAZKjt;VZ1$?YR1$4^6FO$jO}G|yQ=Wxxan>h}kel_zP^4+r~O?^v?+eEWE-yK|`5 zQJS!rQ#!d9r@$jg)h->yYS+EEh+Lld?n%G9*U3~cv9D!s!;iv@UPLh!=?NE#(X&v| zfOdB#$Edn<)s2`K@CWn0#hNtZiK`>`R-KDI++Z}|Zg*$4sbqbwcp_6FCrh{LCrv!b zcBe~^A0O%OYg<#iR%2=-L{*z$34NziQwv%Rja7|3)C_cKdS4fr)4w=G4db_i%b#Db zet-LN(B90;Diwa<*&KZZ^i}9KUgO%)!c3tz%&Xz{86N#GuZ%KaihB*T&=0sj%+fGX zUL8}F4tWwzxbU39hr*M*+K*DEx_fgG)Hvf_t0=g!zF7$WJ2cN2zuDi>xl|$qo;w*+ zen2H|m0J+TxJ5sqIwQCXk-b)>co4G)YCAw|tRG zNtsf~Ot5$+KO15bPhQ|phwt}ZzUb-8c=ozCyX+qRd_BMLNa~o*E$U?B+69E&F)O5E*#wF6{F!BQI^-WbA|X=% literal 0 HcmV?d00001 diff --git a/models/wan/configs/__pycache__/wan_t2v_14B.cpython-310.pyc b/models/wan/configs/__pycache__/wan_t2v_14B.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97ace1355070f36e8bcb71a0acf1c06e4ec947af GIT binary patch literal 718 zcmYk3%Wl&^6ozMFH+GyjX`8Mg3l?ONNQr0{6{@&ZYycrq6)B1|GCt$jif^59nkH+u zeFapO?AY)MyaLi}S$V+@2?-&=8HYm0@~8itxt;$^=DH5S@$UPV@K=qHuR6Fn3=AIN zFW(}70F6mPCzKMXff3gdV`5OkDVVTzNhc;ZueP}HUeDLL$?9MsEwGU`G>{E&kPbAF zO=uxo;38dUBij`_&_TTo+sK`YyRd`03%jpKzguD@q)(N7?8()W$S+E|Nwk+<%4FyX z4$Kc`>+a($4Wij$@3oiqPR6Ib(Y;4{A7iPPaK_fRlMHw)S+UO+NwGh;xQGWl^#^Z) z(ZMh;!nJ*!Ubpd#@o>aWA0F$OzEPTZX=x*brLQ-HZdHjrBtmB+vZrx?IfC_-clDxL z`5T95qCS~je*1LwaSQ9N@N>uyr_Y~F zbrYYyoNn4a#pjy(l|!jD`)hERuatcs*DPXwq@PyaL9}SA=9mtx8}^tYycWU1h&&Z<;1& zu6+emj@-EL3cLc6ublS44GD=0j1vVqmVX-m$M($Fnd{mF?ZeM+p=%KGT`%q$YUD9a z`5plTXiO41qLe@djJT2*BZCr7!Gx_FIx@L=yTy%>zhm+=X5A_h9c0>9$L(gmkI0&OEtz7WqX<*FtmdrA&sN z;K2M~vTQ%i(jb}~ciwtw=X`k48T1cM^+Co`FX4TF`1lsv7rEBUoVR=;%E} zw^IH=6IwW*-2C`*`|IoHlULogvN(1G*pbk;D0kD#stuv@DXSP>u>m_cRn=JzUcr@% z#rq-mr+F5og>rCNWK*6-ODt`TA__a>~+7B&y|$OiJbzMY|^08{C>!f z$1h)u^;157JzhWl7=LZ-uUeGaXPXZD`9fKbaMLX2C;E%!1B64Jif!7oYE=JEivGW0 F)4v0&&vF0& literal 0 HcmV?d00001 diff --git a/configs/animate.json b/models/wan/configs/animate.json similarity index 100% rename from configs/animate.json rename to models/wan/configs/animate.json diff --git a/configs/fantasy.json b/models/wan/configs/fantasy.json similarity index 100% rename from configs/fantasy.json rename to models/wan/configs/fantasy.json diff --git a/configs/flf2v_720p.json b/models/wan/configs/flf2v_720p.json similarity index 100% rename from configs/flf2v_720p.json rename to models/wan/configs/flf2v_720p.json diff --git a/configs/i2v.json b/models/wan/configs/i2v.json similarity index 100% rename from configs/i2v.json rename to models/wan/configs/i2v.json diff --git a/configs/i2v_2_2.json b/models/wan/configs/i2v_2_2.json similarity index 100% rename from configs/i2v_2_2.json rename to models/wan/configs/i2v_2_2.json diff --git a/configs/i2v_2_2_multitalk.json 
b/models/wan/configs/i2v_2_2_multitalk.json similarity index 100% rename from configs/i2v_2_2_multitalk.json rename to models/wan/configs/i2v_2_2_multitalk.json diff --git a/configs/infinitetalk.json b/models/wan/configs/infinitetalk.json similarity index 100% rename from configs/infinitetalk.json rename to models/wan/configs/infinitetalk.json diff --git a/configs/lucy_edit.json b/models/wan/configs/lucy_edit.json similarity index 100% rename from configs/lucy_edit.json rename to models/wan/configs/lucy_edit.json diff --git a/configs/lynx.json b/models/wan/configs/lynx.json similarity index 100% rename from configs/lynx.json rename to models/wan/configs/lynx.json diff --git a/models/wan/configs/lynx_lite.json b/models/wan/configs/lynx_lite.json new file mode 100644 index 000000000..11618e768 --- /dev/null +++ b/models/wan/configs/lynx_lite.json @@ -0,0 +1,15 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 16, + "model_type": "t2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "lynx": "lite" +} diff --git a/configs/multitalk.json b/models/wan/configs/multitalk.json similarity index 100% rename from configs/multitalk.json rename to models/wan/configs/multitalk.json diff --git a/configs/phantom_1.3B.json b/models/wan/configs/phantom_1.3B.json similarity index 100% rename from configs/phantom_1.3B.json rename to models/wan/configs/phantom_1.3B.json diff --git a/configs/phantom_14B.json b/models/wan/configs/phantom_14B.json similarity index 100% rename from configs/phantom_14B.json rename to models/wan/configs/phantom_14B.json diff --git a/configs/sky_df_1.3.json b/models/wan/configs/sky_df_1.3.json similarity index 100% rename from configs/sky_df_1.3.json rename to models/wan/configs/sky_df_1.3.json diff --git a/configs/sky_df_14B.json b/models/wan/configs/sky_df_14B.json similarity index 100% rename from configs/sky_df_14B.json rename to models/wan/configs/sky_df_14B.json diff --git a/configs/standin.json b/models/wan/configs/standin.json similarity index 100% rename from configs/standin.json rename to models/wan/configs/standin.json diff --git a/configs/t2v.json b/models/wan/configs/t2v.json similarity index 100% rename from configs/t2v.json rename to models/wan/configs/t2v.json diff --git a/configs/t2v_1.3B.json b/models/wan/configs/t2v_1.3B.json similarity index 100% rename from configs/t2v_1.3B.json rename to models/wan/configs/t2v_1.3B.json diff --git a/models/wan/configs/t2v_2_2.json b/models/wan/configs/t2v_2_2.json new file mode 100644 index 000000000..c554e7302 --- /dev/null +++ b/models/wan/configs/t2v_2_2.json @@ -0,0 +1,14 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 16, + "model_type": "t2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512 +} diff --git a/configs/ti2v_2_2.json b/models/wan/configs/ti2v_2_2.json similarity index 100% rename from configs/ti2v_2_2.json rename to models/wan/configs/ti2v_2_2.json diff --git a/configs/vace_1.3B.json b/models/wan/configs/vace_1.3B.json similarity index 100% rename from configs/vace_1.3B.json rename to models/wan/configs/vace_1.3B.json diff --git a/configs/vace_14B.json b/models/wan/configs/vace_14B.json similarity index 100% rename from configs/vace_14B.json rename to models/wan/configs/vace_14B.json diff --git a/models/wan/configs/vace_14B_2_2.json 
b/models/wan/configs/vace_14B_2_2.json new file mode 100644 index 000000000..e48a816a3 --- /dev/null +++ b/models/wan/configs/vace_14B_2_2.json @@ -0,0 +1,16 @@ +{ + "_class_name": "VaceWanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 16, + "model_type": "t2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35], + "vace_in_dim": 96 +} diff --git a/configs/vace_lynx_14B.json b/models/wan/configs/vace_lynx_14B.json similarity index 100% rename from configs/vace_lynx_14B.json rename to models/wan/configs/vace_lynx_14B.json diff --git a/configs/vace_multitalk_14B.json b/models/wan/configs/vace_multitalk_14B.json similarity index 100% rename from configs/vace_multitalk_14B.json rename to models/wan/configs/vace_multitalk_14B.json diff --git a/configs/vace_standin_14B.json b/models/wan/configs/vace_standin_14B.json similarity index 100% rename from configs/vace_standin_14B.json rename to models/wan/configs/vace_standin_14B.json diff --git a/models/wan/fantasytalking/infer.py b/models/wan/fantasytalking/infer.py index d96bea037..a75641789 100644 --- a/models/wan/fantasytalking/infer.py +++ b/models/wan/fantasytalking/infer.py @@ -5,6 +5,7 @@ from .model import FantasyTalkingAudioConditionModel from .utils import get_audio_features import gc, torch +from shared.utils import files_locator as fl def parse_audio(audio_path, start_frame, num_frames, fps = 23, device = "cuda"): fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device) @@ -16,10 +17,10 @@ def parse_audio(audio_path, start_frame, num_frames, fps = 23, device = "cuda"): with init_empty_weights(): proj_model = AudioProjModel( 768, 2048) - offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors") + offload.load_model_data(proj_model, fl.locate_file("fantasy_proj_model.safetensors")) proj_model.to("cpu").eval().requires_grad_(False) - wav2vec_model_dir = "ckpts/wav2vec" + wav2vec_model_dir = fl.locate_folder("wav2vec") wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir) wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir, device_map="cpu").eval().requires_grad_(False) wav2vec.to(device) diff --git a/models/wan/multitalk/multitalk.py b/models/wan/multitalk/multitalk.py index 52d2dd995..845e39b99 100644 --- a/models/wan/multitalk/multitalk.py +++ b/models/wan/multitalk/multitalk.py @@ -19,6 +19,7 @@ import soundfile as sf import re import math +from shared.utils import files_locator as fl def custom_init(device, wav2vec): audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device) @@ -215,7 +216,7 @@ def process_tts_multi(text, save_dir, voice1, voice2): def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0, return_sum_only = False): - wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base") + wav2vec_feature_extractor, audio_encoder= custom_init('cpu', fl.locate_folder("chinese-wav2vec2-base")) # wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec") pad = int(padded_frames_for_embeddings/ fps * sr) new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad, min_audio_duration = 
min_audio_duration ) diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index a5d1e3d98..dfc416e6b 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -1,6 +1,7 @@ import torch import numpy as np import gradio as gr +from shared.utils import files_locator as fl def test_class_i2v(base_model_type): return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "animate" ] @@ -22,7 +23,44 @@ def test_lynx(base_model_type): def test_wan_5B(base_model_type): return base_model_type in ["ti2v_2_2", "lucy_edit"] + class family_handler(): + @staticmethod + def query_supported_types(): + return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_14B_2_2", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B", + "t2v_1.3B", "standin", "lynx_lite", "lynx", "t2v", "t2v_2_2", "vace_1.3B", "phantom_1.3B", "phantom_14B", + "recam_1.3B", "animate", + "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] + + + @staticmethod + def query_family_maps(): + + models_eqv_map = { + "flf2v_720p" : "i2v", + "t2v_1.3B" : "t2v", + "t2v_2_2" : "t2v", + "vace_standin_14B" : "vace_14B", + "vace_lynx_14B" : "vace_14B", + "vace_14B_2_2": "vace_14B", + } + + models_comp_map = { + "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_14B", "vace_14B_2_2"], + "t2v" : [ "vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_14B", "vace_14B_2_2", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin", "lynx_lite", "lynx"], + "i2v" : [ "fantasy", "multitalk", "flf2v_720p" ], + "i2v_2_2" : ["i2v_2_2_multitalk"], + "fantasy": ["multitalk"], + } + return models_eqv_map, models_comp_map + + @staticmethod + def query_model_family(): + return "wan" + + @staticmethod + def query_family_infos(): + return {"wan":(0, "Wan2.1"), "wan2_2":(1, "Wan2.2") } @staticmethod def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_steps_cache): @@ -71,10 +109,10 @@ def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_st @staticmethod def get_wan_text_encoder_filename(text_encoder_quantization): - text_encoder_filename = "ckpts/umt5-xxl/models_t5_umt5-xxl-enc-bf16.safetensors" + text_encoder_filename = "umt5-xxl/models_t5_umt5-xxl-enc-bf16.safetensors" if text_encoder_quantization =="int8": text_encoder_filename = text_encoder_filename.replace("bf16", "quanto_int8") - return text_encoder_filename + return fl.locate_file(text_encoder_filename, True) @@ -86,8 +124,24 @@ def query_model_def(base_model_type, model_def): extra_model_def["multitalk_class"] = test_multitalk(base_model_type) extra_model_def["standin_class"] = test_standin(base_model_type) extra_model_def["lynx_class"] = test_lynx(base_model_type) - vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B"] + vace_class = base_model_type in ["vace_14B", "vace_14B_2_2", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B"] extra_model_def["vace_class"] = vace_class + group = "wan" + if base_model_type in ["t2v_2_2", "i2v_2_2", "vace_14B_2_2"]: + profiles_dir = "wan_2_2" + group = "wan_2_2" + elif i2v: + profiles_dir = "wan_i2v" + elif test_wan_5B(base_model_type): + profiles_dir = "wan_2_2_5B" + group = "wan_2_2" + elif test_class_1_3B(base_model_type): 
+ profiles_dir = "wan_1.3B" + else: + profiles_dir = "wan" + + extra_model_def["profiles_dir"] = [profiles_dir] + extra_model_def["group"] = group if base_model_type in ["animate"]: fps = 30 @@ -108,7 +162,7 @@ def query_model_def(base_model_type, model_def): extra_model_def.update({ "frames_minimum" : frames_minimum, "frames_steps" : frames_steps, - "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class, #"ti2v_2_2", + "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "t2v_2_2", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class, #"ti2v_2_2", "multiple_submodels" : multiple_submodels, "guidance_max_phases" : 3, "skip_layer_guidance" : True, @@ -127,7 +181,7 @@ def query_model_def(base_model_type, model_def): }) - if base_model_type in ["t2v"]: + if base_model_type in ["t2v", "t2v_2_2"]: extra_model_def["guide_custom_choices"] = { "choices":[("Use Text Prompt Only", ""), ("Video to Video guided by Text Prompt", "GUV"), @@ -363,40 +417,6 @@ def query_model_def(base_model_type, model_def): return extra_model_def - @staticmethod - def query_supported_types(): - return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B", - "t2v_1.3B", "standin", "lynx_lite", "lynx", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", - "recam_1.3B", "animate", - "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] - - - @staticmethod - def query_family_maps(): - - models_eqv_map = { - "flf2v_720p" : "i2v", - "t2v_1.3B" : "t2v", - "vace_standin_14B" : "vace_14B", - "vace_lynx_14B" : "vace_14B", - } - - models_comp_map = { - "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_14B"], - "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin", "lynx_lite", "lynx"], - "i2v" : [ "fantasy", "multitalk", "flf2v_720p" ], - "i2v_2_2" : ["i2v_2_2_multitalk"], - "fantasy": ["multitalk"], - } - return models_eqv_map, models_comp_map - - @staticmethod - def query_model_family(): - return "wan" - - @staticmethod - def query_family_infos(): - return {"wan":(0, "Wan2.1"), "wan2_2":(1, "Wan2.2") } @staticmethod def get_vae_block_size(base_model_type): diff --git a/postprocessing/mmaudio/eval_utils.py b/postprocessing/mmaudio/eval_utils.py index 11dc654c2..46e0670ae 100644 --- a/postprocessing/mmaudio/eval_utils.py +++ b/postprocessing/mmaudio/eval_utils.py @@ -15,6 +15,7 @@ from .model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig from .model.utils.features_utils import FeaturesUtils from .utils.download_utils import download_model_if_needed +from shared.utils import files_locator as fl log = logging.getLogger() @@ -26,7 +27,7 @@ class ModelConfig: vae_path: Path bigvgan_16k_path: Optional[Path] mode: str - synchformer_ckpt: Path = Path('ckpts/mmaudio/synchformer_state_dict.pth') + synchformer_ckpt: Path = Path( fl.locate_file('mmaudio/synchformer_state_dict.pth')) @property def seq_cfg(self) -> SequenceConfig: @@ -64,8 +65,8 @@ def download_if_needed(self): bigvgan_16k_path=None, mode='44k') large_44k_v2 = ModelConfig(model_name='large_44k_v2', - model_path=Path('ckpts/mmaudio/mmaudio_large_44k_v2.pth'), - vae_path=Path('ckpts/mmaudio/v1-44.pth'), + 
model_path=Path( fl.locate_file('mmaudio/mmaudio_large_44k_v2.pth')), + vae_path=Path(fl.locate_file('mmaudio/v1-44.pth')), bigvgan_16k_path=None, mode='44k') all_model_cfg: dict[str, ModelConfig] = { diff --git a/postprocessing/mmaudio/model/utils/features_utils.py b/postprocessing/mmaudio/model/utils/features_utils.py index 2947c303d..285e4e224 100644 --- a/postprocessing/mmaudio/model/utils/features_utils.py +++ b/postprocessing/mmaudio/model/utils/features_utils.py @@ -12,6 +12,7 @@ from ...ext.mel_converter import get_mel_converter from ...ext.synchformer.synchformer import Synchformer from ...model.utils.distributions import DiagonalGaussianDistribution +from shared.utils import files_locator as fl def patch_clip(clip_model): @@ -31,7 +32,7 @@ def new_encode_text(self, text, normalize: bool = False): return clip_model def get_model_config(model_name): - with open("ckpts/DFN5B-CLIP-ViT-H-14-378/open_clip_config.json", 'r', encoding='utf-8') as f: + with open( fl.locate_file("DFN5B-CLIP-ViT-H-14-378/open_clip_config.json"), 'r', encoding='utf-8') as f: return json.load(f)["model_cfg"] class FeaturesUtils(nn.Module): @@ -51,10 +52,10 @@ def __init__( if enable_conditions: old_get_model_config = open_clip.factory.get_model_config open_clip.factory.get_model_config = get_model_config - with open("ckpts/DFN5B-CLIP-ViT-H-14-378/open_clip_config.json", 'r', encoding='utf-8') as f: + with open( fl.locate_file("DFN5B-CLIP-ViT-H-14-378/open_clip_config.json"), 'r', encoding='utf-8') as f: override_preprocess = json.load(f)["preprocess_cfg"] - self.clip_model = create_model('DFN5B-CLIP-ViT-H-14-378', pretrained='ckpts/DFN5B-CLIP-ViT-H-14-378/open_clip_pytorch_model.bin', force_preprocess_cfg= override_preprocess) + self.clip_model = create_model('DFN5B-CLIP-ViT-H-14-378', pretrained= fl.locate_file('DFN5B-CLIP-ViT-H-14-378/open_clip_pytorch_model.bin'), force_preprocess_cfg= override_preprocess) open_clip.factory.get_model_config = old_get_model_config # self.clip_model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384', return_transform=False) diff --git a/preprocessing/face_preprocessor.py b/preprocessing/face_preprocessor.py index bdef48b78..8c86a7870 100644 --- a/preprocessing/face_preprocessor.py +++ b/preprocessing/face_preprocessor.py @@ -10,6 +10,7 @@ from torchvision.transforms.functional import normalize from typing import Union, Optional from models.hyvideo.data_kits.face_align import AlignImage +from shared.utils import files_locator as fl def _img2tensor(img: np.ndarray, bgr2rgb: bool = True) -> torch.Tensor: @@ -53,7 +54,7 @@ def _pad_to_square(img: np.ndarray, pad_color: int = 255) -> np.ndarray: class FaceProcessor: def __init__(self): - self.align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt") + self.align_instance = AlignImage("cuda", det_path= fl.locate_file("det_align/detface.pt")) self.align_instance.facedet.model.to("cpu") diff --git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py index a5d557099..443934c30 100644 --- a/preprocessing/matanyone/app.py +++ b/preprocessing/matanyone/app.py @@ -22,6 +22,7 @@ from .matanyone_wrapper import matanyone from shared.utils.audio_video import save_video, save_image from mmgp import offload +from shared.utils import files_locator as fl arg_device = "cuda" arg_sam_model_type="vit_h" @@ -698,7 +699,7 @@ def load_unload_models(selected): model_in_GPU = True from .matanyone.model.matanyone import MatAnyone # matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone") - 
matanyone_model = MatAnyone.from_pretrained("ckpts/mask") + matanyone_model = MatAnyone.from_pretrained(fl.locate_folder("mask")) # pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model } # offload.profile(pipe) matanyone_model = matanyone_model.to("cpu").eval() diff --git a/preprocessing/matanyone/tools/base_segmenter.py b/preprocessing/matanyone/tools/base_segmenter.py index 096038e47..64f702d16 100644 --- a/preprocessing/matanyone/tools/base_segmenter.py +++ b/preprocessing/matanyone/tools/base_segmenter.py @@ -8,6 +8,7 @@ import matplotlib.pyplot as plt import PIL from .mask_painter import mask_painter +from shared.utils import files_locator as fl class BaseSegmenter: @@ -32,7 +33,7 @@ def __init__(self, SAM_checkpoint, model_type, device='cuda:0'): # self.model.to(torch.float16) # offload.save_model(self.model, "ckpts/mask/sam_vit_h_4b8939_fp16.safetensors") - offload.load_model_data(self.model, "ckpts/mask/sam_vit_h_4b8939_fp16.safetensors") + offload.load_model_data(self.model, fl.locate_file("mask/sam_vit_h_4b8939_fp16.safetensors")) self.model.to(torch.float32) # need to be optimized, if not f32 crappy precision self.model.to(device=self.device) self.predictor = SamPredictor(self.model) diff --git a/wgp.py b/wgp.py index 49d91d788..ff0a51ddf 100644 --- a/wgp.py +++ b/wgp.py @@ -29,7 +29,8 @@ from shared.utils.audio_video import save_image_metadata, read_image_metadata from shared.match_archi import match_nvidia_architecture from shared.attention import get_attention_modes, get_supported_attention_modes -from huggingface_hub import hf_hub_download, snapshot_download +from huggingface_hub import hf_hub_download, snapshot_download +from shared.utils import files_locator as fl import torch import gc import traceback @@ -63,7 +64,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.2" -WanGP_version = "8.993" +WanGP_version = "8.994" settings_version = 2.39 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -563,7 +564,7 @@ def ret(): if test_any_sliding_window(model_type) and image_mode == 0: if video_length > sliding_window_size: - if model_type in ["t2v"] and not "G" in video_prompt_type : + if model_type in ["t2v", "t2v_2_2"] and not "G" in video_prompt_type : gr.Info(f"You have requested to Generate Sliding Windows with a Text to Video model. 
Unless you use the Video to Video feature this is useless as a t2v model doesn't see past frames and it will generate the same video in each new window.") return ret() full_video_length = video_length if video_source is None else video_length + sliding_window_overlap -1 @@ -1793,16 +1794,17 @@ def get_lora_dir(model_type): if not os.path.isfile(server_config_filename) and os.path.isfile("gradio_config.json"): shutil.move("gradio_config.json", server_config_filename) -if not os.path.isdir("ckpts/umt5-xxl/"): - os.makedirs("ckpts/umt5-xxl/") -src_move = [ "ckpts/models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-bf16.safetensors", "ckpts/models_t5_umt5-xxl-enc-quanto_int8.safetensors" ] -tgt_move = [ "ckpts/xlm-roberta-large/", "ckpts/umt5-xxl/", "ckpts/umt5-xxl/"] +src_move = [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "models_t5_umt5-xxl-enc-bf16.safetensors", "models_t5_umt5-xxl-enc-quanto_int8.safetensors" ] +tgt_move = [ "xlm-roberta-large", "umt5-xxl", "umt5-xxl"] for src,tgt in zip(src_move,tgt_move): - if os.path.isfile(src): + src = fl.locate_file(src) + tgt = fl.get_download_location(tgt) + if src is not None: try: if os.path.isfile(tgt): shutil.remove(src) else: + os.makedirs(os.path.dirname(tgt)) shutil.move(src, tgt) except: pass @@ -1823,7 +1825,8 @@ def get_lora_dir(model_type): "vae_config": 0, "profile" : profile_type.LowRAM_LowVRAM, "preload_model_policy": [], - "UI_theme": "default" + "UI_theme": "default", + "checkpoints_paths": fl.default_checkpoints_paths, } with open(server_config_filename, "w", encoding="utf-8") as writer: @@ -1833,6 +1836,10 @@ def get_lora_dir(model_type): text = reader.read() server_config = json.loads(text) +checkpoints_paths = server_config.get("checkpoints_paths", None) +if checkpoints_paths is None: checkpoints_paths = server_config["checkpoints_paths"] = fl.default_checkpoints_paths +fl.set_checkpoints_paths(checkpoints_paths) + # Deprecated models for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors","sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_int8.safetensors", "sky_reels2_diffusion_forcing_720p_14B_quanto_fp16_int8.safetensors", "wan2.1_image2video_480p_14B_bf16.safetensors", "wan2.1_image2video_480p_14B_quanto_int8.safetensors", @@ -1841,11 +1848,11 @@ def get_lora_dir(model_type): "wan2.1_Vace_14B_mbf16.safetensors", "wan2.1_Vace_14B_quanto_mbf16_int8.safetensors", "wan2.1_FLF2V_720p_14B_quanto_int8.safetensors", "wan2.1_FLF2V_720p_14B_bf16.safetensors", "wan2.1_FLF2V_720p_14B_fp16.safetensors", "wan2.1_Vace_1.3B_mbf16.safetensors", "wan2.1_text2video_1.3B_bf16.safetensors", "ltxv_0.9.7_13B_dev_bf16.safetensors" ]: - if Path(os.path.join("ckpts" , path)).is_file(): + if fl.locate_file(path) is not None: print(f"Removing old version of model '{path}'. A new version of this model will be downloaded next time you use it.") - os.remove( os.path.join("ckpts" , path)) + os.remove( fl.locate_file(path)) -for f, s in [("ckpts/Florence2/modeling_florence2.py", 127287)]: +for f, s in [(fl.locate_file("Florence2/modeling_florence2.py"), 127287)]: try: if os.path.isfile(f) and os.path.getsize(f) == s: print(f"Removing old version of model '{f}'. 
A new version of this model will be downloaded next time you use it.") @@ -2050,7 +2057,6 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", modu if len(stack) > 10: raise Exception(f"Circular Reference in Model {key_name} dependencies: {stack}") return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, submodel_no = submodel_no, stack = stack + [URLs]) - # choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ] choices = URLs if len(quantization) == 0: quantization = "bf16" @@ -2259,7 +2265,7 @@ def init_model_def(model_type, model_def): base_model_type = get_base_model_type(model_type) family_handler = model_types_handlers.get(base_model_type, None) if family_handler is None: - raise Exception(f"Unknown model type {model_type}") + raise Exception(f"Unknown model type {base_model_type}") default_model_def = family_handler.query_model_def(base_model_type, model_def) if default_model_def is None: return model_def default_model_def.update(model_def) @@ -2418,13 +2424,13 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod saved_finetune_def = json.load(reader) update_model_def = False - model_filename = os.path.join("ckpts",model_filename) + model_filename_path = os.path.join(fl.get_download_folder(), model_filename) quanto_dtypestr= "bf16" if dtype == torch.bfloat16 else "fp16" if ("m" + dtypestr) in model_filename: dtypestr = "m" + dtypestr quanto_dtypestr = "m" + quanto_dtypestr - if not os.path.isfile(model_filename) and (not no_fp16_main_model or dtype == torch.bfloat16): - offload.save_model(model, model_filename, config_file_path=config_file, filter_sd=filter) + if fl.locate_file(model_filename) is None and (not no_fp16_main_model or dtype == torch.bfloat16): + offload.save_model(model, model_filename_path, config_file_path=config_file, filter_sd=filter) print(f"New model file '{model_filename}' had been created for finetune Id '{model_type}'.") del saved_finetune_def["model"][source_key] del model_def[source_key] @@ -2433,10 +2439,11 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod if is_module: quanto_filename = model_filename.replace(dtypestr, "quanto_" + quanto_dtypestr + "_int8" ) + quanto_filename_path = os.path.join(fl.get_download_folder() , quanto_filename) if hasattr(model, "_quanto_map"): print("unable to generate quantized module, the main model should at full 16 bits before quantization can be done") - elif not os.path.isfile(quanto_filename): - offload.save_model(model, quanto_filename, config_file_path=config_file, do_quantize= True, filter_sd=filter) + elif fl.locate_file(quanto_filename) is None: + offload.save_model(model, quanto_filename_path, config_file_path=config_file, do_quantize= True, filter_sd=filter) print(f"New quantized file '{quanto_filename}' had been created for finetune Id '{model_type}'.") if isinstance(model_def[url_key][0],dict): model_def[url_key][0][url_dict_key].append(quanto_filename) @@ -2473,10 +2480,11 @@ def save_quantized_model(model, model_type, model_filename, dtype, config_file, pos = model_filename.rfind(".") model_filename = model_filename[:pos] + "_quanto_int8" + model_filename[pos+1:] - if os.path.isfile(model_filename): + if fl.locate_file(model_filename) is not None: print(f"There isn't any model to quantize as quantized model '{model_filename}' aready exists") else: - offload.save_model(model, model_filename, do_quantize= True, config_file_path=config_file) + 
model_filename_path = os.path.join(fl.get_download_folder(), model_filename) + offload.save_model(model, model_filename_path, do_quantize= True, config_file_path=config_file) print(f"New quantized file '{model_filename}' had been created for finetune Id '{model_type}'.") if not model_filename in URLs: URLs.append(model_filename) @@ -2500,26 +2508,27 @@ def preprocessor_wrapper(sd): def get_local_model_filename(model_filename): if model_filename.startswith("http"): - local_model_filename = os.path.join("ckpts", os.path.basename(model_filename)) + local_model_filename =os.path.basename(model_filename) else: local_model_filename = model_filename + local_model_filename = fl.locate_file(local_model_filename) return local_model_filename def process_files_def(repoId, sourceFolderList, fileList): - targetRoot = "ckpts/" + targetRoot = fl.get_download_location() for sourceFolder, files in zip(sourceFolderList,fileList ): if len(files)==0: - if not Path(targetRoot + sourceFolder).exists(): + if fl.locate_folder(sourceFolder) is None: snapshot_download(repo_id=repoId, allow_patterns=sourceFolder +"/*", local_dir= targetRoot) else: for onefile in files: if len(sourceFolder) > 0: - if not os.path.isfile(targetRoot + sourceFolder + "/" + onefile ): + if fl.locate_file(sourceFolder + "/" + onefile) is None: hf_hub_download(repo_id=repoId, filename=onefile, local_dir = targetRoot, subfolder=sourceFolder) else: - if not os.path.isfile(targetRoot + onefile ): + if fl.locate_file(onefile) is None: hf_hub_download(repo_id=repoId, filename=onefile, local_dir = targetRoot) def download_mmaudio(): @@ -2541,15 +2550,18 @@ def download_file(url,filename): onefile = os.path.basename(url_parts[-1]) sourceFolder = os.path.dirname(url_parts[-1]) if len(sourceFolder) == 0: - hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/" if len(base_dir)==0 else base_dir) + hf_hub_download(repo_id=repoId, filename=onefile, local_dir = fl.get_download_location() if len(base_dir)==0 else base_dir) else: - target_path = "ckpts/temp/" + sourceFolder + temp_dir_path = os.path.join(fl.get_download_location(), "temp") + target_path = os.path.join(temp_dir_path, sourceFolder) if not os.path.exists(target_path): os.makedirs(target_path) - hf_hub_download(repo_id=repoId, filename=onefile, local_dir = "ckpts/temp/", subfolder=sourceFolder) - shutil.move(os.path.join( "ckpts", "temp" , sourceFolder , onefile), "ckpts/" if len(base_dir)==0 else base_dir) - shutil.rmtree("ckpts/temp") + hf_hub_download(repo_id=repoId, filename=onefile, local_dir = temp_dir_path, subfolder=sourceFolder) + shutil.move(os.path.join( target_path, onefile), fl.get_download_location() if len(base_dir)==0 else base_dir) + shutil.rmtree(temp_dir_path) else: + from urllib.request import urlretrieve + from shared.utils.download import create_progress_hook urlretrieve(url,filename, create_progress_hook(filename)) download_shared_done = False @@ -2563,8 +2575,6 @@ def computeList(filename): - from urllib.request import urlretrieve - from shared.utils.download import create_progress_hook shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", @@ -2607,12 +2617,13 @@ def computeList(filename): any_source = ("source2" if submodel_no ==2 else "source") in model_def any_module_source = ("module_source2" if submodel_no ==2 else "module_source") in model_def model_type_handler = model_types_handlers[base_model_type] - local_model_filename = get_local_model_filename(model_filename) - + if any_source and not module_type or any_module_source and module_type: 
model_filename = None else: - if not os.path.isfile(local_model_filename): + local_model_filename = get_local_model_filename(model_filename) + if local_model_filename is None: + local_model_filename = fl.get_download_location(os.path.basename(model_filename)) url = model_filename if not url.startswith("http"): @@ -2621,13 +2632,13 @@ def computeList(filename): download_file(url, local_model_filename) except Exception as e: if os.path.isfile(local_model_filename): os.remove(local_model_filename) - raise Exception(f"'{url}' is invalid for Model '{local_model_filename}' : {str(e)}'") + raise Exception(f"'{url}' is invalid for Model '{model_type}' : {str(e)}'") if module_type: return model_filename = None preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True) for url in preload_URLs: - filename = "ckpts/" + url.split("/")[-1] + filename = fl.get_download_location(url.split("/")[-1]) if not os.path.isfile(filename ): if not url.startswith("http"): raise Exception(f"File '{filename}' to preload was not found locally and no URL was provided to download it. Please add an URL in the model definition file.") @@ -2809,8 +2820,8 @@ def setup_prompt_enhancer(pipe, kwargs): model_no = server_config.get("enhancer_enabled", 0) if model_no != 0: from transformers import ( AutoModelForCausalLM, AutoProcessor, AutoTokenizer, LlamaForCausalLM ) - prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained( "ckpts/Florence2", trust_remote_code=True) - prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained( "ckpts/Florence2", trust_remote_code=True) + prompt_enhancer_image_caption_model = AutoModelForCausalLM.from_pretrained(fl.locate_folder("Florence2"), trust_remote_code=True) + prompt_enhancer_image_caption_processor = AutoProcessor.from_pretrained(fl.locate_folder("Florence2"), trust_remote_code=True) pipe["prompt_enhancer_image_caption_model"] = prompt_enhancer_image_caption_model prompt_enhancer_image_caption_model._model_dtype = torch.float # def preprocess_sd(sd, map): @@ -2825,12 +2836,12 @@ def setup_prompt_enhancer(pipe, kwargs): if model_no == 1: budget = 5000 - prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/Llama3_2/Llama3_2_quanto_bf16_int8.safetensors") - prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained("ckpts/Llama3_2") + prompt_enhancer_llm_model = offload.fast_load_transformers_model( fl.locate_file("Llama3_2/Llama3_2_quanto_bf16_int8.safetensors")) + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(fl.locate_folder("Llama3_2")) else: budget = 10000 - prompt_enhancer_llm_model = offload.fast_load_transformers_model("ckpts/llama-joycaption-beta-one-hf-llava/llama_joycaption_quanto_bf16_int8.safetensors") - prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained("ckpts/llama-joycaption-beta-one-hf-llava") + prompt_enhancer_llm_model = offload.fast_load_transformers_model(fl.locate_file("llama-joycaption-beta-one-hf-llava/llama_joycaption_quanto_bf16_int8.safetensors")) + prompt_enhancer_llm_tokenizer = AutoTokenizer.from_pretrained(fl.locate_folder("llama-joycaption-beta-one-hf-llava")) pipe["prompt_enhancer_llm_model"] = prompt_enhancer_llm_model if not "budgets" in kwargs: kwargs["budgets"] = {} kwargs["budgets"]["prompt_enhancer_llm_model"] = budget @@ -2875,7 +2886,7 @@ def load_models(model_type, override_profile = -1): if model_filename2 != None: model_file_list += [model_filename2] model_type_list += [model_type] - module_type_list += [False] + module_type_list += 
[None] model_submodel_no_list += [2] for module_type in modules: if isinstance(module_type,dict): @@ -3012,7 +3023,11 @@ def apply_changes( state, image_output_codec_choice = None, audio_output_codec_choice = None, last_resolution_choice = None, + checkpoints_paths = "", ): + checkpoints_paths = fl.default_checkpoints_paths if len(checkpoints_paths.strip()) == "" else checkpoints_paths.replace("\r", "").split("\n") + checkpoints_paths = [path.strip() for path in checkpoints_paths if len(path.strip()) > 0 ] + fl.set_checkpoints_paths(checkpoints_paths) if args.lock_config: return "
Config Locked
",*[gr.update()]*4 if gen_in_progress: @@ -3054,6 +3069,7 @@ def apply_changes( state, "last_advanced_choice": state["advanced"], "last_resolution_choice": last_resolution_choice, "last_resolution_per_group": state["last_resolution_per_group"], + "checkpoints_paths": checkpoints_paths, } if Path(server_config_filename).is_file(): @@ -3087,14 +3103,14 @@ def apply_changes( state, transformer_quantization = server_config["transformer_quantization"] transformer_dtype_policy = server_config["transformer_dtype_policy"] text_encoder_quantization = server_config["text_encoder_quantization"] - transformer_types = server_config["transformer_types"] + transformer_types = server_config["transformer_types"] model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy) state["model_filename"] = model_filename if "enhancer_enabled" in changes or "enhancer_mode" in changes: reset_prompt_enhancer() if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier", "display_stats", - "video_output_codec", "image_output_codec", "audio_output_codec"] for change in changes ): + "video_output_codec", "image_output_codec", "audio_output_codec", "checkpoints_paths"] for change in changes ): model_family = gr.Dropdown() model_choice = gr.Dropdown() else: @@ -3649,28 +3665,23 @@ def get_preprocessor(process_type, inpaint_color): if process_type=="pose": from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator cfg_dict = { - "DETECTION_MODEL": "ckpts/pose/yolox_l.onnx", - "POSE_MODEL": "ckpts/pose/dw-ll_ucoco_384.onnx", + "DETECTION_MODEL": fl.locate_file("pose/yolox_l.onnx"), + "POSE_MODEL": fl.locate_file("pose/dw-ll_ucoco_384.onnx"), "RESIZE_SIZE": 1024 } anno_ins = lambda img: PoseBodyFaceVideoAnnotator(cfg_dict).forward(img) elif process_type=="depth": - # from preprocessing.midas.depth import DepthVideoAnnotator - # cfg_dict = { - # "PRETRAINED_MODEL": "ckpts/depth/dpt_hybrid-midas-501f0c75.pt" - # } - # anno_ins = lambda img: DepthVideoAnnotator(cfg_dict).forward(img)[0] from preprocessing.depth_anything_v2.depth import DepthV2VideoAnnotator if server_config.get("depth_anything_v2_variant", "vitl") == "vitl": cfg_dict = { - "PRETRAINED_MODEL": "ckpts/depth/depth_anything_v2_vitl.pth", + "PRETRAINED_MODEL": fl.locate_file("depth/depth_anything_v2_vitl.pth"), 'MODEL_VARIANT': 'vitl' } else: cfg_dict = { - "PRETRAINED_MODEL": "ckpts/depth/depth_anything_v2_vitb.pth", + "PRETRAINED_MODEL": fl.locate_file("depth/depth_anything_v2_vitb.pth"), 'MODEL_VARIANT': 'vitb', } @@ -3682,19 +3693,19 @@ def get_preprocessor(process_type, inpaint_color): elif process_type=="canny": from preprocessing.canny import CannyVideoAnnotator cfg_dict = { - "PRETRAINED_MODEL": "ckpts/scribble/netG_A_latest.pth" + "PRETRAINED_MODEL": fl.locate_file("scribble/netG_A_latest.pth") } anno_ins = lambda img: CannyVideoAnnotator(cfg_dict).forward(img) elif process_type=="scribble": from preprocessing.scribble import ScribbleVideoAnnotator cfg_dict = { - "PRETRAINED_MODEL": "ckpts/scribble/netG_A_latest.pth" + "PRETRAINED_MODEL": fl.locate_file("scribble/netG_A_latest.pth") } anno_ins = lambda img: ScribbleVideoAnnotator(cfg_dict).forward(img) elif process_type=="flow": from preprocessing.flow import FlowVisAnnotator cfg_dict = { - "PRETRAINED_MODEL": "ckpts/flow/raft-things.pth" + "PRETRAINED_MODEL": 
fl.locate_file("flow/raft-things.pth") } anno_ins = lambda img: FlowVisAnnotator(cfg_dict).forward(img) elif process_type=="inpaint": @@ -4092,10 +4103,10 @@ def perform_temporal_upsampling(sample, previous_last_frame, temporal_upsampling if previous_last_frame != None: sample = torch.cat([previous_last_frame, sample], dim=1) previous_last_frame = sample[:, -1:].clone() - sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device) + sample = temporal_interpolation( fl.locate_file("flownet.pkl"), sample, exp, device=processing_device) sample = sample[:, 1:] else: - sample = temporal_interpolation( os.path.join("ckpts", "flownet.pkl"), sample, exp, device=processing_device) + sample = temporal_interpolation( fl.locate_file("flownet.pkl"), sample, exp, device=processing_device) previous_last_frame = sample[:, -1:].clone() output_fps = output_fps * 2**exp @@ -5838,13 +5849,12 @@ def get_new_preset_msg(advanced = True): return "Choose a Lora Preset or a Settings file in this List" def compute_lset_choices(model_type, loras_presets): - # top_dir = "settings" global_list = [] if model_type is not None: - top_dir = get_lora_dir(model_type) + top_dir = "profiles" # get_lora_dir(model_type) model_def = get_model_def(model_type) settings_dir = get_model_recursive_prop(model_type, "profiles_dir", return_list=False) - if settings_dir is None or len(settings_dir) == 0: settings_dir = ["profiles_" + get_model_family(model_type, True)] + if settings_dir is None or len(settings_dir) == 0: settings_dir = [""] for dir in settings_dir: if len(dir) == "": continue cur_path = os.path.join(top_dir, dir) @@ -6081,8 +6091,7 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m return wizard_prompt_activated, loras_choices, loras_mult_choices, prompt, get_unique_id(), gr.update(), gr.update(), gr.update() else: - # lset_path = os.path.join("settings", lset_name) if len(Path(lset_name).parts)>1 else os.path.join(get_lora_dir(current_model_type), lset_name) - lset_path = os.path.join(get_lora_dir(current_model_type), lset_name) + lset_path = os.path.join("profiles" if len(Path(lset_name).parts)>1 else get_lora_dir(current_model_type), lset_name) configs, _ = get_settings_from_file(state,lset_path , True, True, True, min_settings_version=2.38, merge_loras = "merge after" if len(Path(lset_name).parts)<=1 else "merge before" ) if configs == None: @@ -6275,7 +6284,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None inputs["base_model_type"] = base_model_type diffusion_forcing = base_model_type in ["sky_df_1.3B", "sky_df_14B"] vace = test_vace_module(base_model_type) - t2v= base_model_type in ["t2v"] + t2v= base_model_type in ["t2v", "t2v_2_2"] ltxv = base_model_type in ["ltxv_13B"] recammaster = base_model_type in ["recam_1.3B"] phantom = base_model_type in ["phantom_1.3B", "phantom_14B"] @@ -6309,10 +6318,8 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None if model_def.get("model_modes", None) is None: pop += ["model_mode"] - if not vace and not phantom and not hunyuan_video_custom: - unsaved_params = ["keep_frames_video_guide", "mask_expand"] #"video_prompt_type", - if base_model_type in ["t2v"]: unsaved_params = unsaved_params[1:] - pop += unsaved_params + if model_def.get("guide_custom_choices", None ) is None and model_def.get("guide_preprocessing", None ) is None: + pop += ["keep_frames_video_guide", "mask_expand"] if not "I" in video_prompt_type: pop += 
["remove_background_images_ref"] @@ -7545,7 +7552,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non if len(launch_loras) == 0: launch_multis_str = ui_defaults.get("loras_multipliers","") launch_loras = [os.path.basename(path) for path in ui_defaults.get("activated_loras",[])] - with gr.Row(): with gr.Column(): with gr.Column(visible=False, elem_id="image-modal-container") as modal_container: @@ -7584,7 +7590,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non cancel_lset_btn = gr.Button("Don't do it !", size="sm", min_width= 1 , visible=False) #confirm_save_lset_btn, confirm_delete_lset_btn, save_lset_btn, delete_lset_btn, cancel_lset_btn trigger_refresh_input_type = gr.Text(interactive= False, visible= False) - t2v = base_model_type in ["t2v"] + t2v = base_model_type in ["t2v", "t2v_2_2"] t2v_1_3B = base_model_type in ["t2v_1.3B"] flf2v = base_model_type == "flf2v_720p" base_model_family = get_model_family(base_model_type) @@ -9017,6 +9023,8 @@ def check(mode): value=server_config.get("max_frames_multiplier", 1), label="Increase the Max Number of Frames (needs more RAM and VRAM, usually the longer the worse the quality, needs an App restart)" ) + checkpoints_paths_text = "\n".join(server_config.get("checkpoints_paths", fl.default_checkpoints_paths)) + checkpoints_paths = gr.Text(label ="Folders where Model Checkpoints are Stored. One Path per Line. First Path is Default Download Path.", value=checkpoints_paths_text,lines=3) UI_theme_choice = gr.Dropdown( choices=[ @@ -9263,6 +9271,7 @@ def check(mode): image_output_codec_choice, audio_output_codec_choice, resolution, + checkpoints_paths, ], outputs= [msg , header, model_family, model_choice, refresh_form_trigger] ) From 7054566334b3c12d9910da728796ce3157dc68b5 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Mon, 6 Oct 2025 20:45:19 +0200 Subject: [PATCH 075/155] removed pyc files --- .../configs/__pycache__/__init__.cpython-310.pyc | Bin 950 -> 0 bytes .../__pycache__/shared_config.cpython-310.pyc | Bin 821 -> 0 bytes .../__pycache__/wan_i2v_14B.cpython-310.pyc | Bin 969 -> 0 bytes .../__pycache__/wan_t2v_14B.cpython-310.pyc | Bin 718 -> 0 bytes .../__pycache__/wan_t2v_1_3B.cpython-310.pyc | Bin 726 -> 0 bytes 5 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 models/wan/configs/__pycache__/__init__.cpython-310.pyc delete mode 100644 models/wan/configs/__pycache__/shared_config.cpython-310.pyc delete mode 100644 models/wan/configs/__pycache__/wan_i2v_14B.cpython-310.pyc delete mode 100644 models/wan/configs/__pycache__/wan_t2v_14B.cpython-310.pyc delete mode 100644 models/wan/configs/__pycache__/wan_t2v_1_3B.cpython-310.pyc diff --git a/models/wan/configs/__pycache__/__init__.cpython-310.pyc b/models/wan/configs/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 5fdfaf60617346093216ab70837ec41ba57162a2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 950 zcmZ8f&rjPh6t>+oantk%Eg+6u+X+NNr6GhS0_#wkwb`1YKJb2z{H(8_W@{ZGWTCGN_>F!T3h$=WX zJo}v2>N^1Gse*i6VI#ryopE`06dZf6HrH}2kYR%qQn^I+~ z;7*M?fQ@x~YJjPNGtSi3?0Quh4;AysLrXfUXkktj%_~z-rl?Aidsw@}ciM60yg!i1 zX(|h`HW-~;#AFo3h78WVa1^lY`+mT08WzH-5VI6>+qY0Xo|)*TeW3wNOofyusv1+R zQQeLy=RtocIcM1;r_FZ{+RvJMF0)*?sDEB@-+9HIZo|1{;`Hg+i7#XAoxF%x!F|!` z9CQyGPq|8S+1yKKw~^r61h$Y+|JPyjGv&klPbgf1?IvdWNN)D_TRpgSd#}?Qj)e51 z9vrqeNZv!#ZcS>t^x0~HR^S(DblNlb4A{bl0f22N# z5&FrB6+v-1ZL^Tx-d*W*T1)(mpJE@7n-e 
[GIT binary patch data for the deleted models/wan/configs/__pycache__ *.pyc files omitted]

From cdff82d56a71bb42d662a5f1e9e47cc33ebe14ed Mon Sep 17 00:00:00 2001
From: DeepBeepMeep
Date: Mon, 6 Oct 2025 20:46:46 +0200
Subject: [PATCH 076/155] add missing file locator

---
 shared/utils/files_locator.py | 46 ++++++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)
 create mode 100644 shared/utils/files_locator.py

diff --git a/shared/utils/files_locator.py b/shared/utils/files_locator.py
new file mode 100644
index 000000000..1e425d092
--- /dev/null
+++ b/shared/utils/files_locator.py
@@ -0,0 +1,46 @@
+from __future__ import annotations
+import os
+from pathlib import Path
+from typing import Iterable, List, Optional, Union
+
+default_checkpoints_paths = ["ckpts", "."]
+
+_checkpoints_paths = default_checkpoints_paths
+
+def set_checkpoints_paths(checkpoints_paths):
+    global _checkpoints_paths
+    _checkpoints_paths = [path.strip() for path in checkpoints_paths if len(path.strip()) > 0 ]
+    if len(checkpoints_paths) == 0:
+        _checkpoints_paths = default_checkpoints_paths
+def get_download_location(file_name = None):
+    if file_name is not None and os.path.isabs(file_name):
return file_name + if file_name is not None: + return os.path.join(_checkpoints_paths[0], file_name) + else: + return _checkpoints_paths[0] + +def locate_folder(folder_name): + if os.path.isabs(folder_name): + if os.path.isdir(folder_name): return folder_name + else: + for folder in _checkpoints_paths: + path = os.path.join(folder, folder_name) + if os.path.isdir(path): + return path + + return None + + +def locate_file(file_name, create_path_if_none = False): + if os.path.isabs(file_name): + if os.path.isfile(file_name): return file_name + else: + for folder in _checkpoints_paths: + path = os.path.join(folder, file_name) + if os.path.isfile(path): + return path + + if create_path_if_none: + return get_download_location(file_name) + return None + From eb896266188becc84360f7f201fc68996331161f Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Mon, 6 Oct 2025 20:49:50 +0200 Subject: [PATCH 077/155] restored profiles --- .../qwen}/Lightning v1.0 - 4 Steps.json | 0 .../qwen}/Lightning v1.1 - 8 Steps.json | 0 .../wan/Text2Video Cocktail - 10 Steps.json | 13 +++++++++++++ .../wan/Text2Video FusioniX - 10 Steps.json | 11 +++++++++++ ...tx2v Cfg Step Distill Rank32 - 4 Steps.json | 11 +++++++++++ ...Fun Inp HPS2 Reward 2 Phases - 4 Steps.json | 14 ++++++++++++++ .../Fun Inp MPS Reward 2 Phases - 4 Steps.json | 14 ++++++++++++++ .../Lightning v1.0 2 Phases - 4 Steps.json | 14 ++++++++++++++ .../Lightning v1.0 3 Phases - 8 Steps.json | 18 ++++++++++++++++++ .../Lightning v2025-09-28 - 4 Steps.json | 14 ++++++++++++++ ...ghtning v2025-09-28 2 Phases - 4 Steps.json | 14 ++++++++++++++ ...ghtning v2025-09-28 3 Phases - 8 Steps.json | 18 ++++++++++++++++++ .../Text2Video FusioniX - 10 Steps.json | 12 ++++++++++++ profiles/wan_2_2_5B/FastWan 3 Steps.json | 7 +++++++ .../Image2Video FusioniX - 10 Steps.json | 10 ++++++++++ 15 files changed, 170 insertions(+) rename {loras_qwen/profiles_qwen => profiles/qwen}/Lightning v1.0 - 4 Steps.json (100%) rename {loras_qwen/profiles_qwen => profiles/qwen}/Lightning v1.1 - 8 Steps.json (100%) create mode 100644 profiles/wan/Text2Video Cocktail - 10 Steps.json create mode 100644 profiles/wan/Text2Video FusioniX - 10 Steps.json create mode 100644 profiles/wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json create mode 100644 profiles/wan_2_2/Fun Inp HPS2 Reward 2 Phases - 4 Steps.json create mode 100644 profiles/wan_2_2/Fun Inp MPS Reward 2 Phases - 4 Steps.json create mode 100644 profiles/wan_2_2/Lightning v1.0 2 Phases - 4 Steps.json create mode 100644 profiles/wan_2_2/Lightning v1.0 3 Phases - 8 Steps.json create mode 100644 profiles/wan_2_2/Lightning v2025-09-28 - 4 Steps.json create mode 100644 profiles/wan_2_2/Lightning v2025-09-28 2 Phases - 4 Steps.json create mode 100644 profiles/wan_2_2/Lightning v2025-09-28 3 Phases - 8 Steps.json create mode 100644 profiles/wan_2_2/Text2Video FusioniX - 10 Steps.json create mode 100644 profiles/wan_2_2_5B/FastWan 3 Steps.json create mode 100644 profiles/wan_i2v/Image2Video FusioniX - 10 Steps.json diff --git a/loras_qwen/profiles_qwen/Lightning v1.0 - 4 Steps.json b/profiles/qwen/Lightning v1.0 - 4 Steps.json similarity index 100% rename from loras_qwen/profiles_qwen/Lightning v1.0 - 4 Steps.json rename to profiles/qwen/Lightning v1.0 - 4 Steps.json diff --git a/loras_qwen/profiles_qwen/Lightning v1.1 - 8 Steps.json b/profiles/qwen/Lightning v1.1 - 8 Steps.json similarity index 100% rename from loras_qwen/profiles_qwen/Lightning v1.1 - 8 Steps.json rename to profiles/qwen/Lightning v1.1 - 8 Steps.json diff --git 
a/profiles/wan/Text2Video Cocktail - 10 Steps.json b/profiles/wan/Text2Video Cocktail - 10 Steps.json new file mode 100644 index 000000000..087e49c92 --- /dev/null +++ b/profiles/wan/Text2Video Cocktail - 10 Steps.json @@ -0,0 +1,13 @@ +{ + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors" + ], + "loras_multipliers": "1 0.5 0.5 0.5", + "num_inference_steps": 10, + "guidance_phases": 1, + "guidance_scale": 1, + "flow_shift": 2 +} diff --git a/profiles/wan/Text2Video FusioniX - 10 Steps.json b/profiles/wan/Text2Video FusioniX - 10 Steps.json new file mode 100644 index 000000000..9fabf09eb --- /dev/null +++ b/profiles/wan/Text2Video FusioniX - 10 Steps.json @@ -0,0 +1,11 @@ +{ + + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_T2V_14B_FusionX_LoRA.safetensors" + ], + "loras_multipliers": "1", + "num_inference_steps": 10, + "guidance_phases": 1, + "guidance_scale": 1, + "flow_shift": 2 +} diff --git a/profiles/wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json b/profiles/wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json new file mode 100644 index 000000000..8ed2a4db0 --- /dev/null +++ b/profiles/wan/Text2Video Lightx2v Cfg Step Distill Rank32 - 4 Steps.json @@ -0,0 +1,11 @@ +{ + + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_lightx2v_cfg_step_distill_lora_rank32.safetensors" + ], + "loras_multipliers": "1", + "num_inference_steps": 4, + "guidance_phases": 1, + "guidance_scale": 1, + "flow_shift": 2 +} \ No newline at end of file diff --git a/profiles/wan_2_2/Fun Inp HPS2 Reward 2 Phases - 4 Steps.json b/profiles/wan_2_2/Fun Inp HPS2 Reward 2 Phases - 4 Steps.json new file mode 100644 index 000000000..1b58ec73a --- /dev/null +++ b/profiles/wan_2_2/Fun Inp HPS2 Reward 2 Phases - 4 Steps.json @@ -0,0 +1,14 @@ +{ + "num_inference_steps": 8, + "guidance_scale": 1, + "guidance2_scale": 1, + "model_switch_phase": 1, + "switch_threshold": 876, + "guidance_phases": 2, + "flow_shift": 3, + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Fun-A14B-InP-high-noise-HPS2.1.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Fun-A14B-InP-low-noise-HPS2.1.safetensors" + ], + "loras_multipliers": "1;0 0;1" +} \ No newline at end of file diff --git a/profiles/wan_2_2/Fun Inp MPS Reward 2 Phases - 4 Steps.json b/profiles/wan_2_2/Fun Inp MPS Reward 2 Phases - 4 Steps.json new file mode 100644 index 000000000..cc5742248 --- /dev/null +++ b/profiles/wan_2_2/Fun Inp MPS Reward 2 Phases - 4 Steps.json @@ -0,0 +1,14 @@ +{ + "num_inference_steps": 4, + "guidance_scale": 1, + "guidance2_scale": 1, + "model_switch_phase": 1, + "switch_threshold": 876, + "guidance_phases": 2, + "flow_shift": 3, + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Fun-A14B-InP-high-noise-MPS.safetensors", + 
"https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Fun-A14B-InP-low-noise-MPS.safetensors" + ], + "loras_multipliers": "1;0 0;1" +} \ No newline at end of file diff --git a/profiles/wan_2_2/Lightning v1.0 2 Phases - 4 Steps.json b/profiles/wan_2_2/Lightning v1.0 2 Phases - 4 Steps.json new file mode 100644 index 000000000..51c903b63 --- /dev/null +++ b/profiles/wan_2_2/Lightning v1.0 2 Phases - 4 Steps.json @@ -0,0 +1,14 @@ +{ + "num_inference_steps": 4, + "guidance_scale": 1, + "guidance2_scale": 1, + "model_switch_phase": 1, + "switch_threshold": 876, + "guidance_phases": 2, + "flow_shift": 3, + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_HIGH_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors" + ], + "loras_multipliers": "1;0 0;1" +} \ No newline at end of file diff --git a/profiles/wan_2_2/Lightning v1.0 3 Phases - 8 Steps.json b/profiles/wan_2_2/Lightning v1.0 3 Phases - 8 Steps.json new file mode 100644 index 000000000..a30834482 --- /dev/null +++ b/profiles/wan_2_2/Lightning v1.0 3 Phases - 8 Steps.json @@ -0,0 +1,18 @@ +{ + "num_inference_steps": 8, + "guidance_scale": 3.5, + "guidance2_scale": 1, + "guidance3_scale": 1, + "switch_threshold": 965, + "switch_threshold2": 800, + "guidance_phases": 3, + "model_switch_phase": 2, + "flow_shift": 3, + "sample_solver": "euler", + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_HIGH_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2.2-Lightning_T2V-v1.1-A14B-4steps-lora_LOW_fp16.safetensors" + ], + "loras_multipliers": "0;1;0 0;0;1", + "help": "This finetune uses the Lightning 4 steps Loras Accelerator for Wan 2.2 but extend them to 8 steps in order to insert a CFG phase before the 2 accelerated phases with no Guidance. The ultimate goal is reduce the slow motion effect of these Loras Accelerators." 
+} \ No newline at end of file diff --git a/profiles/wan_2_2/Lightning v2025-09-28 - 4 Steps.json b/profiles/wan_2_2/Lightning v2025-09-28 - 4 Steps.json new file mode 100644 index 000000000..31e4f0d66 --- /dev/null +++ b/profiles/wan_2_2/Lightning v2025-09-28 - 4 Steps.json @@ -0,0 +1,14 @@ +{ + "num_inference_steps": 4, + "guidance_scale": 1, + "guidance2_scale": 1, + "switch_threshold": 876, + "model_switch_phase": 1, + "guidance_phases": 2, + "flow_shift": 3, + "loras_multipliers": "1;0 0;1", + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" + ] +} \ No newline at end of file diff --git a/profiles/wan_2_2/Lightning v2025-09-28 2 Phases - 4 Steps.json b/profiles/wan_2_2/Lightning v2025-09-28 2 Phases - 4 Steps.json new file mode 100644 index 000000000..31e4f0d66 --- /dev/null +++ b/profiles/wan_2_2/Lightning v2025-09-28 2 Phases - 4 Steps.json @@ -0,0 +1,14 @@ +{ + "num_inference_steps": 4, + "guidance_scale": 1, + "guidance2_scale": 1, + "switch_threshold": 876, + "model_switch_phase": 1, + "guidance_phases": 2, + "flow_shift": 3, + "loras_multipliers": "1;0 0;1", + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" + ] +} \ No newline at end of file diff --git a/profiles/wan_2_2/Lightning v2025-09-28 3 Phases - 8 Steps.json b/profiles/wan_2_2/Lightning v2025-09-28 3 Phases - 8 Steps.json new file mode 100644 index 000000000..839e0564c --- /dev/null +++ b/profiles/wan_2_2/Lightning v2025-09-28 3 Phases - 8 Steps.json @@ -0,0 +1,18 @@ +{ + "num_inference_steps": 8, + "guidance_scale": 3.5, + "guidance2_scale": 1, + "guidance3_scale": 1, + "switch_threshold": 965, + "switch_threshold2": 800, + "guidance_phases": 3, + "model_switch_phase": 2, + "flow_shift": 3, + "sample_solver": "euler", + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" + ], + "loras_multipliers": "0;1;0 0;0;1", + "help": "This finetune uses the Lightning 250928 4 steps Loras Accelerator for Wan 2.2 but extend them to 8 steps in order to insert a CFG phase before the 2 accelerated phases with no Guidance. The ultimate goal is reduce the slow motion effect of these Loras Accelerators." 
+} \ No newline at end of file diff --git a/profiles/wan_2_2/Text2Video FusioniX - 10 Steps.json b/profiles/wan_2_2/Text2Video FusioniX - 10 Steps.json new file mode 100644 index 000000000..1ad921b63 --- /dev/null +++ b/profiles/wan_2_2/Text2Video FusioniX - 10 Steps.json @@ -0,0 +1,12 @@ +{ + + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_T2V_14B_FusionX_LoRA.safetensors" + ], + "loras_multipliers": "1", + "num_inference_steps": 10, + "guidance_scale": 1, + "guidance2_scale": 1, + "guidance_phases": 2, + "flow_shift": 2 +} diff --git a/profiles/wan_2_2_5B/FastWan 3 Steps.json b/profiles/wan_2_2_5B/FastWan 3 Steps.json new file mode 100644 index 000000000..1e0383c61 --- /dev/null +++ b/profiles/wan_2_2_5B/FastWan 3 Steps.json @@ -0,0 +1,7 @@ +{ + "activated_loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"], + "guidance_scale": 1, + "flow_shift": 3, + "num_inference_steps": 3, + "loras_multipliers": "1" +} \ No newline at end of file diff --git a/profiles/wan_i2v/Image2Video FusioniX - 10 Steps.json b/profiles/wan_i2v/Image2Video FusioniX - 10 Steps.json new file mode 100644 index 000000000..1fe4c14e5 --- /dev/null +++ b/profiles/wan_i2v/Image2Video FusioniX - 10 Steps.json @@ -0,0 +1,10 @@ +{ + + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan2.1_I2V_14B_FusionX_LoRA.safetensors" + ], + "loras_multipliers": "1", + "num_inference_steps": 10, + "guidance_phases": 1, + "guidance_scale": 1 +} From 37366ba9c4405249c25bef948c3e064898c3a8c6 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Mon, 6 Oct 2025 21:57:40 +0200 Subject: [PATCH 078/155] Palingenesis back from the shelves where it was collecting dust --- README.md | 2 +- defaults/i2v_palingenesis_2_2.json | 18 +++++++++++++++ defaults/t2v_lighting_palingenesis_2_2.json | 25 +++++++++++++++++++++ defaults/t2v_palingenesis_2_2.json | 15 +++++++++++++ 4 files changed, 59 insertions(+), 1 deletion(-) create mode 100644 defaults/i2v_palingenesis_2_2.json create mode 100644 defaults/t2v_lighting_palingenesis_2_2.json create mode 100644 defaults/t2v_palingenesis_2_2.json diff --git a/README.md b/README.md index 126e97d3c..fdaacf94f 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ WanGP supports the Wan (and derived models), Hunyuan Video and LTV Video models **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep ## 🔥 Latest Updates : -### October 6 2025: WanGP v8.993 - A few last things before the Big Unknown ... +### October 6 2025: WanGP v8.994 - A few last things before the Big Unknown ... This new version hasn't any new model... diff --git a/defaults/i2v_palingenesis_2_2.json b/defaults/i2v_palingenesis_2_2.json new file mode 100644 index 000000000..129ef654d --- /dev/null +++ b/defaults/i2v_palingenesis_2_2.json @@ -0,0 +1,18 @@ +{ + "model": + { + "name": "Wan2.2 Image2video Palingenesis 14B", + "architecture" : "i2v_2_2", + "description": "Wan 2.2 Image 2 Video model. Contrary to the Wan Image2video 2.1 this model is structurally close to the t2v model. 
Palingenesis a finetune praised for its high quality.", + "URLs": [ "https://huggingface.co/eddy1111111/WAN22.XX_Palingenesis/resolve/main/WAN22.XX_Palingenesis_high_i2v_fix.safetensors"], + "URLs2": [ "https://huggingface.co/eddy1111111/WAN22.XX_Palingenesis/resolve/main/WAN22.XX_Palingenesis_low_i2v_fix.safetensors"], + "group": "wan2_2" + }, + "ignore_unused_weights": true, + "guidance_phases": 2, + "switch_threshold" : 900, + "guidance_scale" : 3.5, + "guidance2_scale" : 3.5, + "flow_shift" : 5 + +} \ No newline at end of file diff --git a/defaults/t2v_lighting_palingenesis_2_2.json b/defaults/t2v_lighting_palingenesis_2_2.json new file mode 100644 index 000000000..e52119b4f --- /dev/null +++ b/defaults/t2v_lighting_palingenesis_2_2.json @@ -0,0 +1,25 @@ +{ + "model": + { + "name": "Wan2.2 Text2video Lightning Palingenesis 14B", + "architecture" : "t2v", + "description": "Wan 2.2 Text 2 Video Lightning Dyno model. Palingenesis a finetune praised for its high quality is used for the Low noise model whereas the High Noise model uses Lightning Finetune that offers natively Loras accelerators. ", + "URLs": [ + "https://huggingface.co/lightx2v/Wan2.2-Lightning/resolve/main/Wan2.2-T2V-A14B-4steps-250928-dyno/Wan2.2-T2V-A14B-4steps-250928-dyno-high-lightx2v.safetensors" + ], + "URLs2": ["https://huggingface.co/eddy1111111/WAN22.XX_Palingenesis/resolve/main/WAN22.XX_Palingenesis_low_t2v.safetensors"], + "loras_multipliers": ["0;1"], + "activated_loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" + ], + "profiles_dir": [ "" ], + "group": "wan2_2" + }, + "guidance_phases": 2, + "switch_threshold" : 875, + "guidance_scale" : 1, + "guidance2_scale" : 1, + "num_inference_steps": 4, + "flow_shift" : 3 + +} \ No newline at end of file diff --git a/defaults/t2v_palingenesis_2_2.json b/defaults/t2v_palingenesis_2_2.json new file mode 100644 index 000000000..e742cffab --- /dev/null +++ b/defaults/t2v_palingenesis_2_2.json @@ -0,0 +1,15 @@ +{ + "model": + { + "name": "Wan2.2 Text2video Palingenesis 14B", + "architecture" : "t2v_2_2", + "description": "Wan 2.2 Text 2 Video Palingenesis a finetune praised for its high quality.", + "URLs": ["https://huggingface.co/eddy1111111/WAN22.XX_Palingenesis/resolve/main/WAN22.XX_Palingenesis_high_t2v.safetensors"], + "URLs2": ["https://huggingface.co/eddy1111111/WAN22.XX_Palingenesis/resolve/main/WAN22.XX_Palingenesis_low_t2v.safetensors"], + "group": "wan2_2" + }, + "guidance_phases": 2, + "switch_threshold" : 875, + "flow_shift" : 3 + +} \ No newline at end of file From 222659ccd4a21d357366cd436f11b0228d4adccf Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 7 Oct 2025 19:23:22 +0200 Subject: [PATCH 079/155] some feature for some future --- defaults/animate.json | 2 +- defaults/fantasy.json | 2 +- defaults/flux_dev_umo.json | 4 +- defaults/flux_dev_uso.json | 4 +- defaults/flux_krea.json | 2 +- defaults/flux_srpo.json | 2 +- defaults/flux_srpo_uso.json | 4 +- defaults/hunyuan_t2v_accvideo.json | 2 +- defaults/hunyuan_t2v_fast.json | 2 +- defaults/infinitetalk.json | 2 +- defaults/infinitetalk_multi.json | 2 +- defaults/lucy_edit_fastwan.json | 2 +- defaults/multitalk.json | 2 +- defaults/multitalk_720p.json | 2 +- defaults/t2v_2_2.json | 2 +- defaults/t2v_lighting_palingenesis_2_2.json | 2 +- defaults/ti2v_2_2_fastwan.json | 2 +- defaults/vace_1.3B.json | 2 +- defaults/vace_14B.json | 2 +- defaults/vace_14B_2_2.json | 5 +- 
defaults/vace_fun_14B_cocktail_2_2.json | 2 +- models/wan/wan_handler.py | 7 +- wgp.py | 255 +++++++++++++++++--- 23 files changed, 246 insertions(+), 67 deletions(-) diff --git a/defaults/animate.json b/defaults/animate.json index bf45f4a92..bdcb6fefd 100644 --- a/defaults/animate.json +++ b/defaults/animate.json @@ -1,6 +1,6 @@ { "model": { - "name": "Wan2.2 Animate", + "name": "Wan2.2 Animate 14B", "architecture": "animate", "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'Animation' or 'Replacement' mode. Sliding Window of 81 frames at least are recommeded to obtain the best Style continuity.", "URLs": [ diff --git a/defaults/fantasy.json b/defaults/fantasy.json index 3917caba8..fc09cee9f 100644 --- a/defaults/fantasy.json +++ b/defaults/fantasy.json @@ -1,7 +1,7 @@ { "model": { - "name": "Fantasy Talking 720p", + "name": "Fantasy Talking 720p 14B", "architecture" : "fantasy", "modules": [ ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_fantasy_speaking_14B_bf16.safetensors"]], "description": "The Fantasy Talking model corresponds to the original Wan image 2 video model combined with the Fantasy Speaking module to process an audio Input.", diff --git a/defaults/flux_dev_umo.json b/defaults/flux_dev_umo.json index ea7368053..a5e8e7d92 100644 --- a/defaults/flux_dev_umo.json +++ b/defaults/flux_dev_umo.json @@ -1,8 +1,8 @@ { "model": { - "name": "Flux 1 Dev UMO 12B", + "name": "Flux 1 UMO Dev 12B", "architecture": "flux_dev_umo", - "description": "FLUX.1 Dev UMO is a model that can Edit Images with a specialization in combining multiple image references (resized internally at 512x512 max) to produce an Image output. Best Image preservation at 768x768 Resolution Output.", + "description": "FLUX.1 UMO Dev is a model that can Edit Images with a specialization in combining multiple image references (resized internally at 512x512 max) to produce an Image output. 
Best Image preservation at 768x768 Resolution Output.", "URLs": "flux", "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-UMO_dit_lora_bf16.safetensors"], "resolutions": [ ["1024x1024 (1:1)", "1024x1024"], diff --git a/defaults/flux_dev_uso.json b/defaults/flux_dev_uso.json index fee8f25f9..4b429210e 100644 --- a/defaults/flux_dev_uso.json +++ b/defaults/flux_dev_uso.json @@ -1,8 +1,8 @@ { "model": { - "name": "Flux 1 Dev USO 12B", + "name": "Flux 1 USO Dev 12B", "architecture": "flux_dev_uso", - "description": "FLUX.1 Dev USO is a model that can Edit Images with a specialization in Style Transfers (up to two).", + "description": "FLUX.1 USO Dev is a model that can Edit Images with a specialization in Style Transfers (up to two).", "modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]], "URLs": "flux", "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"] diff --git a/defaults/flux_krea.json b/defaults/flux_krea.json index 4fa0b3fb1..669e1a595 100644 --- a/defaults/flux_krea.json +++ b/defaults/flux_krea.json @@ -1,6 +1,6 @@ { "model": { - "name": "Flux 1 Krea Dev 12B", + "name": "Flux 1 Dev Krea 12B", "architecture": "flux", "description": "Cutting-edge output quality, with a focus on aesthetic photography..", "URLs": [ diff --git a/defaults/flux_srpo.json b/defaults/flux_srpo.json index 188c162e4..8b1c447e9 100644 --- a/defaults/flux_srpo.json +++ b/defaults/flux_srpo.json @@ -1,6 +1,6 @@ { "model": { - "name": "Flux 1 SRPO Dev 12B", + "name": "Flux 1 Dev SRPO 12B", "architecture": "flux", "description": "By fine-tuning the FLUX.1.dev model with optimized denoising and online reward adjustment, SRPO improves its human-evaluated realism and aesthetic quality by over 3x.", "URLs": [ diff --git a/defaults/flux_srpo_uso.json b/defaults/flux_srpo_uso.json index e43371daf..eed67bd36 100644 --- a/defaults/flux_srpo_uso.json +++ b/defaults/flux_srpo_uso.json @@ -1,8 +1,8 @@ { "model": { - "name": "Flux 1 SRPO USO 12B", + "name": "Flux 1 USO SRPO 12B", "architecture": "flux_dev_uso", - "description": "FLUX.1 SRPO USO is a model that can Edit Images with a specialization in Style Transfers (up to two). It leverages the improved Image quality brought by the SRPO process", + "description": "FLUX.1 USO SRPO is a model that can Edit Images with a specialization in Style Transfers (up to two). It leverages the improved Image quality brought by the SRPO process", "modules": [ "flux_dev_uso"], "URLs": "flux_srpo", "loras": "flux_dev_uso" diff --git a/defaults/hunyuan_t2v_accvideo.json b/defaults/hunyuan_t2v_accvideo.json index 23309d0ea..2da984a8d 100644 --- a/defaults/hunyuan_t2v_accvideo.json +++ b/defaults/hunyuan_t2v_accvideo.json @@ -1,6 +1,6 @@ { "model": { - "name": "Hunyuan Video AccVideo 720p 13B", + "name": "Hunyuan Video Text2video 720p AccVideo 13B", "architecture": "hunyuan", "description": " AccVideo is a novel efficient distillation method to accelerate video diffusion models with synthetic datset. 
Our method is 8.5x faster than HunyuanVideo.", "URLs": [ diff --git a/defaults/hunyuan_t2v_fast.json b/defaults/hunyuan_t2v_fast.json index 30adb17a0..4019e24ef 100644 --- a/defaults/hunyuan_t2v_fast.json +++ b/defaults/hunyuan_t2v_fast.json @@ -1,6 +1,6 @@ { "model": { - "name": "Hunyuan Video FastHunyuan 720p 13B", + "name": "Hunyuan Video Text2video 720p FastHunyuan 13B", "architecture": "hunyuan", "description": "Fast Hunyuan is an accelerated HunyuanVideo model. It can sample high quality videos with 6 diffusion steps.", "settings_dir": [ "" ], diff --git a/defaults/infinitetalk.json b/defaults/infinitetalk.json index 7fddea91a..fc28d96e5 100644 --- a/defaults/infinitetalk.json +++ b/defaults/infinitetalk.json @@ -1,6 +1,6 @@ { "model": { - "name": "Infinitetalk Single Speaker 480p", + "name": "Infinitetalk Single Speaker 480p 14B", "architecture": "infinitetalk", "modules": [ [ diff --git a/defaults/infinitetalk_multi.json b/defaults/infinitetalk_multi.json index 97251e805..229ecc778 100644 --- a/defaults/infinitetalk_multi.json +++ b/defaults/infinitetalk_multi.json @@ -1,6 +1,6 @@ { "model": { - "name": "Infinitetalk Multi Speakers 480p", + "name": "Infinitetalk Multi Speakers 480p 14B", "architecture": "infinitetalk", "modules": [ [ diff --git a/defaults/lucy_edit_fastwan.json b/defaults/lucy_edit_fastwan.json index 2a52166fc..de2830c9f 100644 --- a/defaults/lucy_edit_fastwan.json +++ b/defaults/lucy_edit_fastwan.json @@ -1,6 +1,6 @@ { "model": { - "name": "Wan2.2 FastWan Lucy Edit 5B", + "name": "Wan2.2 Lucy Edit FastWan 5B", "architecture": "lucy_edit", "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. This is the FastWan version for faster generation.", "URLs": "lucy_edit", diff --git a/defaults/multitalk.json b/defaults/multitalk.json index 1133657d9..41699b584 100644 --- a/defaults/multitalk.json +++ b/defaults/multitalk.json @@ -1,7 +1,7 @@ { "model": { - "name": "Multitalk 480p", + "name": "Multitalk 480p 14B", "architecture" : "multitalk", "modules": [ ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_multitalk_14B_mbf16.safetensors", diff --git a/defaults/multitalk_720p.json b/defaults/multitalk_720p.json index 4bdaabcce..f18bebc01 100644 --- a/defaults/multitalk_720p.json +++ b/defaults/multitalk_720p.json @@ -1,7 +1,7 @@ { "model": { - "name": "Multitalk 720p", + "name": "Multitalk 720p 14B", "architecture" : "multitalk", "modules": ["multitalk"], "description": "The Multitalk model corresponds to the original Wan image 2 video 720p model combined with the Multitalk module. 
It lets you have up to two people have a conversation.", diff --git a/defaults/t2v_2_2.json b/defaults/t2v_2_2.json index 122eeb91b..806a1bfe6 100644 --- a/defaults/t2v_2_2.json +++ b/defaults/t2v_2_2.json @@ -2,7 +2,7 @@ "model": { "name": "Wan2.2 Text2video 14B", - "architecture" : "t2v", + "architecture" : "t2v_2_2", "description": "Wan 2.2 Text 2 Video model", "URLs": [ "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_text2video_14B_high_mbf16.safetensors", diff --git a/defaults/t2v_lighting_palingenesis_2_2.json b/defaults/t2v_lighting_palingenesis_2_2.json index e52119b4f..f14f03aea 100644 --- a/defaults/t2v_lighting_palingenesis_2_2.json +++ b/defaults/t2v_lighting_palingenesis_2_2.json @@ -2,7 +2,7 @@ "model": { "name": "Wan2.2 Text2video Lightning Palingenesis 14B", - "architecture" : "t2v", + "architecture" : "t2v_2_2", "description": "Wan 2.2 Text 2 Video Lightning Dyno model. Palingenesis a finetune praised for its high quality is used for the Low noise model whereas the High Noise model uses Lightning Finetune that offers natively Loras accelerators. ", "URLs": [ "https://huggingface.co/lightx2v/Wan2.2-Lightning/resolve/main/Wan2.2-T2V-A14B-4steps-250928-dyno/Wan2.2-T2V-A14B-4steps-250928-dyno-high-lightx2v.safetensors" diff --git a/defaults/ti2v_2_2_fastwan.json b/defaults/ti2v_2_2_fastwan.json index c142dcbd2..eeb2cffbf 100644 --- a/defaults/ti2v_2_2_fastwan.json +++ b/defaults/ti2v_2_2_fastwan.json @@ -1,6 +1,6 @@ { "model": { - "name": "Wan2.2 FastWan TextImage2video 5B", + "name": "Wan2.2 TextImage2video FastWan 5B", "architecture": "ti2v_2_2", "description": "FastWan2.2-TI2V-5B-Full-Diffusers is built upon Wan-AI/Wan2.2-TI2V-5B-Diffusers. It supports efficient 3-step inference and produces high-quality videos at 121×704×1280 resolution", "URLs": "ti2v_2_2", diff --git a/defaults/vace_1.3B.json b/defaults/vace_1.3B.json index 406b9a7c6..8b18f4584 100644 --- a/defaults/vace_1.3B.json +++ b/defaults/vace_1.3B.json @@ -1,7 +1,7 @@ { "model": { - "name": "Vace ControlNet 1.3B", + "name": "Vace 1.3B", "architecture" : "vace_1.3B", "modules": [ ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_1_3B_module.safetensors"] diff --git a/defaults/vace_14B.json b/defaults/vace_14B.json index b6396641e..0304db71b 100644 --- a/defaults/vace_14B.json +++ b/defaults/vace_14B.json @@ -1,6 +1,6 @@ { "model": { - "name": "Vace ControlNet 14B", + "name": "Vace 14B", "architecture": "vace_14B", "modules": [ ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_Vace_14B_module_mbf16.safetensors", diff --git a/defaults/vace_14B_2_2.json b/defaults/vace_14B_2_2.json index 6aa8cf8fb..e93f96aaa 100644 --- a/defaults/vace_14B_2_2.json +++ b/defaults/vace_14B_2_2.json @@ -1,14 +1,13 @@ { "model": { - "name": "Wan2.2 Vace Experimental 14B", + "name": "Wan2.2 Vace 14B", "architecture": "vace_14B_2_2", "modules": [ "vace_14B" ], "description": "There is so far only PARTIAL support of Vace 2.1 which is currently used.", "URLs": "t2v_2_2", - "URLs2": "t2v_2_2", - "visible": false + "URLs2": "t2v_2_2" }, "guidance_phases": 2, "guidance_scale": 1, diff --git a/defaults/vace_fun_14B_cocktail_2_2.json b/defaults/vace_fun_14B_cocktail_2_2.json index 0a734380f..b50ddda35 100644 --- a/defaults/vace_fun_14B_cocktail_2_2.json +++ b/defaults/vace_fun_14B_cocktail_2_2.json @@ -1,7 +1,7 @@ { "model": { "name": "Wan2.2 Vace Fun Cocktail 14B", - "architecture": "vace_14B", + "architecture": "vace_14B_2_2", "description": "This model has been created on the fly using the Wan 
text 2.2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. This is the Fun Vace 2.2, that is not the official Vace 2.2", "URLs": "vace_fun_14B_2_2", "URLs2": "vace_fun_14B_2_2", diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index dfc416e6b..da0bdc454 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -126,15 +126,18 @@ def query_model_def(base_model_type, model_def): extra_model_def["lynx_class"] = test_lynx(base_model_type) vace_class = base_model_type in ["vace_14B", "vace_14B_2_2", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B"] extra_model_def["vace_class"] = vace_class + if base_model_type in ["vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B"]: + extra_model_def["parent_model_type"] = "vace_14B" + group = "wan" if base_model_type in ["t2v_2_2", "i2v_2_2", "vace_14B_2_2"]: profiles_dir = "wan_2_2" - group = "wan_2_2" + group = "wan2_2" elif i2v: profiles_dir = "wan_i2v" elif test_wan_5B(base_model_type): profiles_dir = "wan_2_2_5B" - group = "wan_2_2" + group = "wan2_2" elif test_class_1_3B(base_model_type): profiles_dir = "wan_1.3B" else: diff --git a/wgp.py b/wgp.py index ff0a51ddf..1ce0824e0 100644 --- a/wgp.py +++ b/wgp.py @@ -54,6 +54,7 @@ from tqdm import tqdm import requests from shared.gradio.gallery import AdvancedMediaGallery +from collections import defaultdict # import torch._dynamo as dynamo # dynamo.config.recompile_limit = 2000 # default is 256 @@ -1839,6 +1840,7 @@ def get_lora_dir(model_type): checkpoints_paths = server_config.get("checkpoints_paths", None) if checkpoints_paths is None: checkpoints_paths = server_config["checkpoints_paths"] = fl.default_checkpoints_paths fl.set_checkpoints_paths(checkpoints_paths) +three_levels_hierarchy = server_config.get("model_hierarchy_type", 1) == 1 # Deprecated models for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors","sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors", @@ -1897,6 +1899,12 @@ def get_base_model_type(model_type): else: return model_def["architecture"] +def get_parent_model_type(model_type): + base_model_type = get_base_model_type(model_type) + if base_model_type is None: return None + model_def = get_model_def(base_model_type) + return model_def.get("parent_model_type", base_model_type) +"vace_14B" def get_model_handler(model_type): base_model_type = get_base_model_type(model_type) if base_model_type is None: @@ -2993,6 +3001,7 @@ def release_RAM(): def apply_changes( state, transformer_types_choices, + model_hierarchy_type_choice, transformer_dtype_policy_choice, text_encoder_quantization_choice, VAE_precision_choice, @@ -3029,13 +3038,14 @@ def apply_changes( state, checkpoints_paths = [path.strip() for path in checkpoints_paths if len(path.strip()) > 0 ] fl.set_checkpoints_paths(checkpoints_paths) if args.lock_config: - return "
Config Locked
",*[gr.update()]*4 + return "
Config Locked
",*[gr.update()]*5 if gen_in_progress: - return "
Unable to change config when a generation is in progress
",*[gr.update()]*4 + return "
Unable to change config when a generation is in progress
",*[gr.update()]*5 global offloadobj, wan_model, server_config, loras, default_loras_choices, default_loras_multis_str, default_lora_preset_prompt, default_lora_preset, loras_presets server_config = { "attention_mode" : attention_choice, "transformer_types": transformer_types_choices, + "model_hierarchy_type": model_hierarchy_type_choice, "text_encoder_quantization" : text_encoder_quantization_choice, "save_path" : save_path_choice, "image_save_path" : image_save_path_choice, @@ -3066,6 +3076,7 @@ def apply_changes( state, "audio_output_codec" : audio_output_codec_choice, "last_model_type" : state["model_type"], "last_model_per_family": state["last_model_per_family"], + "last_model_per_type": state["last_model_per_type"], "last_advanced_choice": state["advanced"], "last_resolution_choice": last_resolution_choice, "last_resolution_per_group": state["last_resolution_per_group"], @@ -3090,7 +3101,8 @@ def apply_changes( state, if v != v_old: changes.append(k) - global attention_mode, default_profile, compile, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_dtype_policy, transformer_types, text_encoder_quantization, save_path + global attention_mode, default_profile, compile, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_dtype_policy, transformer_types, text_encoder_quantization, save_path , three_levels_hierarchy + three_levels_hierarchy= server_config["model_hierarchy_type"] == 1 attention_mode = server_config["attention_mode"] default_profile = server_config["profile"] compile = server_config["compile"] @@ -3104,22 +3116,24 @@ def apply_changes( state, transformer_dtype_policy = server_config["transformer_dtype_policy"] text_encoder_quantization = server_config["text_encoder_quantization"] transformer_types = server_config["transformer_types"] - model_filename = get_model_filename(transformer_type, transformer_quantization, transformer_dtype_policy) + model_type = state["model_type"] + model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) state["model_filename"] = model_filename if "enhancer_enabled" in changes or "enhancer_mode" in changes: reset_prompt_enhancer() if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", "max_frames_multiplier", "display_stats", - "video_output_codec", "image_output_codec", "audio_output_codec", "checkpoints_paths"] for change in changes ): + "video_output_codec", "image_output_codec", "audio_output_codec", "checkpoints_paths", "model_hierarchy_type_choice"] for change in changes ): model_family = gr.Dropdown() + model_base_type = gr.Dropdown() model_choice = gr.Dropdown() else: reload_needed = True - model_family, model_choice = generate_dropdown_model_list(transformer_type) + model_family, model_base_type, model_choice = generate_dropdown_model_list(model_type) header = generate_header(state["model_type"], compile=compile, attention_mode= attention_mode) mmaudio_enabled = server_config["mmaudio_enabled"] > 0 - return "
The new configuration has been succesfully applied
", header, model_family, model_choice, get_unique_id() + return "
The new configuration has been successfully applied
", header, model_family, model_base_type, model_choice, get_unique_id() def get_gen_info(state): cache = state.get("gen", None) @@ -5877,7 +5891,7 @@ def compute_lset_choices(model_type, loras_presets): indent = chr(160) * 4 lset_choices = [] if len(global_list) > 0: - lset_choices += [( (sep*16) +"Profiles" + (sep*17), ">Profiles")] + lset_choices += [( (sep*12) +"Accelerators Profiles" + (sep*13), ">profiles")] lset_choices += [ ( indent + os.path.splitext(os.path.basename(preset))[0], preset) for preset in global_list ] if len(settings_list) > 0: settings_list.sort() @@ -6089,14 +6103,14 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m state["apply_success"] = 1 wizard_prompt_activated = "on" - return wizard_prompt_activated, loras_choices, loras_mult_choices, prompt, get_unique_id(), gr.update(), gr.update(), gr.update() + return wizard_prompt_activated, loras_choices, loras_mult_choices, prompt, get_unique_id(), gr.update(), gr.update(), gr.update(), gr.update() else: lset_path = os.path.join("profiles" if len(Path(lset_name).parts)>1 else get_lora_dir(current_model_type), lset_name) configs, _ = get_settings_from_file(state,lset_path , True, True, True, min_settings_version=2.38, merge_loras = "merge after" if len(Path(lset_name).parts)<=1 else "merge before" ) if configs == None: gr.Info("File not supported") - return [gr.update()] * 8 + return [gr.update()] * 9 model_type = configs["model_type"] configs["lset_name"] = lset_name gr.Info(f"Settings File '{lset_name}' has been applied") @@ -6104,7 +6118,7 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m if help is not None: gr.Info(help) if model_type == current_model_type: set_model_settings(state, current_model_type, configs) - return *[gr.update()] * 4, gr.update(), gr.update(), gr.update(), get_unique_id() + return *[gr.update()] * 4, gr.update(), gr.update(), gr.update(), gr.update(), get_unique_id() else: set_model_settings(state, model_type, configs) return *[gr.update()] * 4, gr.update(), *generate_dropdown_model_list(model_type), gr.update() @@ -6755,12 +6769,12 @@ def load_settings_from_file(state, file_path): gen = get_gen_info(state) if file_path==None: - return gr.update(), gr.update(), None + return gr.update(), gr.update(), gr.update(), gr.update(), None configs, any_video_or_image_file = get_settings_from_file(state, file_path, True, True, True) if configs == None: gr.Info("File not supported") - return gr.update(), gr.update(), None + return gr.update(), gr.update(), gr.update(), gr.update(), None current_model_type = state["model_type"] model_type = configs["model_type"] @@ -6774,7 +6788,7 @@ def load_settings_from_file(state, file_path): if model_type == current_model_type: set_model_settings(state, current_model_type, configs) - return gr.update(), gr.update(), str(time.time()), None + return gr.update(), gr.update(), gr.update(), gr.update(), str(time.time()), None else: set_model_settings(state, model_type, configs) return *generate_dropdown_model_list(model_type), gr.update(), None @@ -6876,7 +6890,6 @@ def save_inputs( ): - model_filename = state["model_filename"] model_type = state["model_type"] if image_mask_guide is not None and image_mode >= 1 and video_prompt_type is not None and "A" in video_prompt_type and not "U" in video_prompt_type: # if image_mask_guide is not None and image_mode == 2: @@ -6981,6 +6994,11 @@ def change_model(state, model_choice): last_model_per_family = state["last_model_per_family"] 
last_model_per_family[get_model_family(model_choice, for_ui= True)] = model_choice server_config["last_model_per_family"] = last_model_per_family + + last_model_per_type = state["last_model_per_type"] + last_model_per_type[get_base_model_type(model_choice)] = model_choice + server_config["last_model_per_type"] = last_model_per_type + server_config["last_model_type"] = model_choice with open(server_config_filename, "w", encoding="utf-8") as writer: @@ -7500,7 +7518,7 @@ def download_lora(state, lora_url, progress=gr.Progress(track_tqdm=True),): -def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None): +def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_base_type_choice = None, model_choice = None, header = None, main = None, main_tabs= None): global inputs_names #, advanced if update_form: @@ -7517,6 +7535,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non state_dict["model_type"] = model_type state_dict["advanced"] = advanced_ui state_dict["last_model_per_family"] = server_config.get("last_model_per_family", {}) + state_dict["last_model_per_type"] = server_config.get("last_model_per_type", {}) state_dict["last_resolution_per_group"] = server_config.get("last_resolution_per_group", {}) gen = dict() gen["queue"] = [] @@ -8622,7 +8641,7 @@ def activate_status(state): ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None - ).then( fn=use_video_settings, inputs =[state, output, last_choice] , outputs= [model_family, model_choice, refresh_form_trigger]) + ).then( fn=use_video_settings, inputs =[state, output, last_choice] , outputs= [model_family, model_base_type_choice, model_choice, refresh_form_trigger]) prompt_enhancer_btn.click(fn=validate_wizard_prompt, @@ -8667,7 +8686,7 @@ def activate_status(state): confirm_delete_lset_btn.click(delete_lset, inputs=[state, lset_name], outputs=[lset_name, apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ]) cancel_lset_btn.click(cancel_lset, inputs=[], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn, confirm_delete_lset_btn,confirm_save_lset_btn, cancel_lset_btn,save_lset_prompt_drop ]) apply_lset_btn.click(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None).then(fn=apply_lset, - inputs=[state, wizard_prompt_activated_var, lset_name,loras_choices, loras_multipliers, prompt], outputs=[wizard_prompt_activated_var, loras_choices, loras_multipliers, prompt, fill_wizard_prompt_trigger, model_family, model_choice, refresh_form_trigger]) + inputs=[state, wizard_prompt_activated_var, lset_name,loras_choices, loras_multipliers, prompt], outputs=[wizard_prompt_activated_var, loras_choices, loras_multipliers, prompt, fill_wizard_prompt_trigger, model_family, model_base_type_choice, model_choice, refresh_form_trigger]) refresh_lora_btn.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) refresh_lora_btn2.click(refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) @@ -8706,7 +8725,7 @@ def activate_status(state): ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None - ).then(fn=load_settings_from_file, inputs =[state, settings_file] , outputs= [model_family, model_choice, refresh_form_trigger, settings_file]) + 
).then(fn=load_settings_from_file, inputs =[state, settings_file] , outputs= [model_family, model_base_type_choice, model_choice, refresh_form_trigger, settings_file]) reset_settings_btn.click(fn=validate_wizard_prompt, @@ -8735,7 +8754,8 @@ def activate_status(state): show_progress="hidden", ) - model_family.input(fn=change_model_family, inputs=[state, model_family], outputs= [model_choice]) + model_family.input(fn=change_model_family, inputs=[state, model_family], outputs= [model_base_type_choice, model_choice], show_progress="hidden") + model_base_type_choice.input(fn=change_model_base_types, inputs=[state, model_family, model_base_type_choice], outputs= [model_base_type_choice, model_choice], show_progress="hidden") model_choice.change(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , @@ -8923,13 +8943,13 @@ def generate_download_tab(lset_name,loras_choices, state): download_loras_btn.click(fn=download_loras, inputs=[], outputs=[download_status_row, download_status]).then(fn=refresh_lora_list, inputs=[state, lset_name,loras_choices], outputs=[lset_name, loras_choices]) -def generate_configuration_tab(state, blocks, header, model_family, model_choice, resolution, refresh_form_trigger): +def generate_configuration_tab(state, blocks, header, model_family, model_base_type_choice, model_choice, resolution, refresh_form_trigger): gr.Markdown("Please click Apply Changes at the bottom so that the changes are effective. Some choices below may be locked if the app has been launched by specifying a config preset.") with gr.Column(): with gr.Tabs(): # with gr.Row(visible=advanced_ui) as advanced_row: with gr.Tab("General"): - dropdown_families, dropdown_choices = get_sorted_dropdown(displayed_model_types, None) + _, _, dropdown_choices = get_sorted_dropdown(displayed_model_types, None, None, False) transformer_types_choices = gr.Dropdown( choices= dropdown_choices, @@ -8939,6 +8959,17 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice multiselect= True ) + model_hierarchy_type_choice = gr.Dropdown( + choices=[ + ("Two Levels: Model Family > Models & Finetunes", 0), + ("Three Levels: Model Family > Models > Finetunes", 1), + ], + value= server_config.get("model_hierarchy_type", 1), + label= "Models Hierarchy In User Interface", + scale= 2, + ) + + fit_canvas_choice = gr.Dropdown( choices=[ ("Dimensions correspond to the Pixels Budget (as the Prompt Image/Video will be Resized to match this pixels Budget, output video height or width may exceed the requested dimensions )", 0), @@ -9241,6 +9272,7 @@ def check(mode): inputs=[ state, transformer_types_choices, + model_hierarchy_type_choice, transformer_dtype_policy_choice, text_encoder_quantization_choice, VAE_precision_choice, @@ -9273,7 +9305,7 @@ def check(mode): resolution, checkpoints_paths, ], - outputs= [msg , header, model_family, model_choice, refresh_form_trigger] + outputs= [msg , header, model_family, model_base_type_choice, model_choice, refresh_form_trigger] ) def generate_about_tab(): @@ -9331,7 +9363,100 @@ def compact_name(family_name, model_name): return model_name[len(family_name):].strip() return model_name -def get_sorted_dropdown(dropdown_types, current_model_family): + +def create_models_hierarchy(rows): + """ + rows: list of (model_name, model_id, parent_model_id) + returns: + parents_list: list[(parent_header, parent_id)] + children_dict: dict[parent_id] -> list[(child_display_name, child_id)] + """ + toks=lambda 
s:[t for t in s.split() if t] + norm=lambda s:' '.join(s.split()).casefold() + + groups,parents,order=defaultdict(list),{},[] + for name,mid,pmid in rows: + groups[pmid].append((name,mid)) + if mid==pmid and pmid not in parents: + parents[pmid]=name; order.append(pmid) + + parents_list,children_dict=[],{} + + # --- Real parents --- + for pid in order: + p_name=parents[pid]; p_tok=toks(p_name); p_low=[w.casefold() for w in p_tok] + n=len(p_low); p_last=p_low[-1]; p_set=set(p_low) + + kids=[] + for name,mid in groups.get(pid,[]): + ot=toks(name); lt=[w.casefold() for w in ot]; st=set(lt) + kids.append((name,mid,ot,lt,st)) + + outliers={mid for _,mid,_,_,st in kids if mid!=pid and p_set.isdisjoint(st)} + + # Only parent + children that start with parent's first word contribute to prefix + prefix_non=[] + for name,mid,ot,lt,st in kids: + if mid==pid or (mid not in outliers and lt and lt[0]==p_low[0]): + prefix_non.append((ot,lt)) + + def lcp_len(a,b): + i=0; m=min(len(a),len(b)) + while i1: L=n + + shares_last=any(mid!=pid and mid not in outliers and lt and lt[-1]==p_last + for _,mid,_,lt,_ in kids) + header_tokens_disp=p_tok[:L]+([p_tok[-1]] if shares_last and L 0 else displayed_model_types if current_model_type not in dropdown_types: dropdown_types.append(current_model_type) current_model_family = get_model_family(current_model_type, for_ui= True) - sorted_familes, dropdown_choices = get_sorted_dropdown(dropdown_types, current_model_family) + sorted_familes, sorted_models, sorted_finetunes = get_sorted_dropdown(dropdown_types, current_model_family, current_model_type, three_levels=three_levels_hierarchy) dropdown_families = gr.Dropdown( choices= sorted_familes, value= current_model_family, show_label= False, - scale= 1, + scale= 2 if three_levels_hierarchy else 1, elem_id="family_list", min_width=50 ) - return dropdown_families, gr.Dropdown( - choices= dropdown_choices, + dropdown_models = gr.Dropdown( + choices= sorted_models, + value= get_base_model_type(current_model_type), + show_label= False, + scale= 3 if len(sorted_finetunes) > 1 else 7, + elem_id="model_base_types_list", + visible= three_levels_hierarchy + ) + + dropdown_finetunes = gr.Dropdown( + choices= sorted_finetunes, value= current_model_type, show_label= False, scale= 4, + visible= len(sorted_finetunes) > 1, elem_id="model_list", ) + + return dropdown_families, dropdown_models, dropdown_finetunes def change_model_family(state, current_model_family): dropdown_types= transformer_types if len(transformer_types) > 0 else displayed_model_types @@ -9381,7 +9527,33 @@ def change_model_family(state, current_model_family): last_model_per_family = state.get("last_model_per_family", {}) model_type = last_model_per_family.get(current_model_family, "") if len(model_type) == "" or model_type not in [choice[1] for choice in dropdown_choices] : model_type = dropdown_choices[0][1] - return gr.Dropdown(choices= dropdown_choices, value = model_type ) + + parent_model_type = get_parent_model_type(model_type) + if three_levels_hierarchy: + dropdown_choices = [ (*tup, get_parent_model_type(tup[1])) for tup in dropdown_choices] + dropdown_base_types_choices, finetunes_dict = create_models_hierarchy(dropdown_choices) + dropdown_choices = finetunes_dict[parent_model_type ] + model_finetunes_visible = len(dropdown_choices) > 1 + else: + model_finetunes_visible = True + dropdown_base_types_choices = list({get_parent_model_type(model[1]) for model in dropdown_choices}) + + return gr.Dropdown(choices= dropdown_base_types_choices, value = 
parent_model_type, scale=3 if model_finetunes_visible else 7), gr.Dropdown(choices= dropdown_choices, value = model_type, visible = model_finetunes_visible ) + +def change_model_base_types(state, current_model_family, model_base_type_choice): + if not three_levels_hierarchy: return gr.update() + dropdown_types= transformer_types if len(transformer_types) > 0 else displayed_model_types + current_family_name = families_infos[current_model_family][1] + dropdown_choices = [ (compact_name(current_family_name, get_model_name(model_type)), model_type, model_base_type_choice) for model_type in dropdown_types if get_parent_model_type(model_type) == model_base_type_choice and get_model_family(model_type, for_ui= True) == current_model_family] + dropdown_choices = sorted(dropdown_choices, key=lambda c: c[0]) + _, finetunes_dict = create_models_hierarchy(dropdown_choices) + dropdown_choices = finetunes_dict[model_base_type_choice ] + model_finetunes_visible = len(dropdown_choices) > 1 + last_model_per_type = state.get("last_model_per_type", {}) + model_type = last_model_per_type.get(model_base_type_choice, "") + if len(model_type) == "" or model_type not in [choice[1] for choice in dropdown_choices] : model_type = dropdown_choices[0][1] + + return gr.update(scale=3 if model_finetunes_visible else 7), gr.Dropdown(choices= dropdown_choices, value = model_type, visible=model_finetunes_visible ) def set_new_tab(tab_state, new_tab_no): global vmc_event_handler @@ -9547,11 +9719,13 @@ def create_ui(): --layout-gap: 0px !important; } .postprocess span {margin-top:4px;margin-bottom:4px} - #model_list, #family_list{ + #model_list, #family_list, #model_base_types_list { background-color:black; padding:1px} - #model_list input, #family_list input { + #model_list,#model_base_types_list { padding-left:0px} + + #model_list input, #family_list input, #model_base_types_list input { font-size:25px} #family_list div div { @@ -9562,6 +9736,10 @@ def create_ui(): border-radius: 0px 4px 4px 0px; } + #model_base_types_list div div { + border-radius: 0px 0px 0px 0px; + } + .title-with-lines { display: flex; align-items: center; @@ -9917,8 +10095,8 @@ def create_ui(): model_family = gr.Dropdown(visible=False, value= "") model_choice = gr.Dropdown(visible=False, value= transformer_type, choices= [transformer_type]) else: - gr.Markdown("
") - model_family, model_choice = generate_dropdown_model_list(transformer_type) + gr.Markdown("
") + model_family, model_base_type_choice, model_choice = generate_dropdown_model_list(transformer_type) gr.Markdown("
") with gr.Row(): header = gr.Markdown(generate_header(transformer_type, compile, attention_mode), visible= True) @@ -9927,17 +10105,16 @@ def create_ui(): with gr.Row(): ( state, loras_choices, lset_name, resolution, refresh_form_trigger, save_form_trigger - # video_guide, image_guide, video_mask, image_mask, image_refs, - ) = generate_video_tab(model_family=model_family, model_choice=model_choice, header=header, main = main, main_tabs =main_tabs) + ) = generate_video_tab(model_family=model_family, model_base_type_choice= model_base_type_choice, model_choice=model_choice, header=header, main = main, main_tabs =main_tabs) with gr.Tab("Guides", id="info") as info_tab: generate_info_tab() with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator: - matanyone_app.display(main_tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings) #, video_guide, image_guide, video_mask, image_mask, image_refs) + matanyone_app.display(main_tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings) if not args.lock_config: with gr.Tab("Downloads", id="downloads") as downloads_tab: generate_download_tab(lset_name, loras_choices, state) with gr.Tab("Configuration", id="configuration") as configuration_tab: - generate_configuration_tab(state, main, header, model_family, model_choice, resolution, refresh_form_trigger) + generate_configuration_tab(state, main, header, model_family, model_base_type_choice, model_choice, resolution, refresh_form_trigger) with gr.Tab("About"): generate_about_tab() if stats_app is not None: From 57c1a953582f4503718d967938ecacff9d1caeae Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 7 Oct 2025 19:23:42 +0200 Subject: [PATCH 080/155] some feature for some future --- wgp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgp.py b/wgp.py index 1ce0824e0..099c2103a 100644 --- a/wgp.py +++ b/wgp.py @@ -1840,7 +1840,7 @@ def get_lora_dir(model_type): checkpoints_paths = server_config.get("checkpoints_paths", None) if checkpoints_paths is None: checkpoints_paths = server_config["checkpoints_paths"] = fl.default_checkpoints_paths fl.set_checkpoints_paths(checkpoints_paths) -three_levels_hierarchy = server_config.get("model_hierarchy_type", 1) == 1 +three_levels_hierarchy = server_config.get("model_hierarchy_type", 0) == 1 # Deprecated models for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors","sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors", From cbab804c3be78a7c166f48542625c7bdda5e470b Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Tue, 7 Oct 2025 21:07:44 +0200 Subject: [PATCH 081/155] hidden betafeature --- wgp.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/wgp.py b/wgp.py index 099c2103a..42bd3219e 100644 --- a/wgp.py +++ b/wgp.py @@ -65,7 +65,7 @@ PROMPT_VARS_MAX = 10 target_mmgp_version = "3.6.2" -WanGP_version = "8.994" +WanGP_version = "8.995" settings_version = 2.39 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -3123,10 +3123,12 @@ def apply_changes( state, reset_prompt_enhancer() if all(change in ["attention_mode", "vae_config", "boost", "save_path", "metadata_type", "clear_file_list", "fit_canvas", "depth_anything_v2_variant", "notification_sound_enabled", "notification_sound_volume", "mmaudio_enabled", 
"max_frames_multiplier", "display_stats", - "video_output_codec", "image_output_codec", "audio_output_codec", "checkpoints_paths", "model_hierarchy_type_choice"] for change in changes ): + "video_output_codec", "image_output_codec", "audio_output_codec", "checkpoints_paths", "model_hierarchy_type"] for change in changes ): model_family = gr.Dropdown() model_base_type = gr.Dropdown() model_choice = gr.Dropdown() + if "model_hierarchy_type" in changes: + model_family, model_base_type, model_choice = generate_dropdown_model_list(model_type) else: reload_needed = True model_family, model_base_type, model_choice = generate_dropdown_model_list(model_type) @@ -4384,6 +4386,7 @@ def process_prompt_enhancer(prompt_enhancer, original_prompts, image_start, ori prompt_images = [] if "I" in prompt_enhancer: if image_start != None: + if not isinstance(image_start, list): image_start= [image_start] prompt_images += image_start if original_image_refs != None: prompt_images += original_image_refs[:1] @@ -8966,6 +8969,7 @@ def generate_configuration_tab(state, blocks, header, model_family, model_base_t ], value= server_config.get("model_hierarchy_type", 1), label= "Models Hierarchy In User Interface", + visible=args.betatest, scale= 2, ) @@ -9476,7 +9480,7 @@ def get_sorted_dropdown(dropdown_types, current_model_family, current_model_type return sorted_familes, sorted_choices, finetunes_dict[get_parent_model_type(current_model_type)] else: - dropdown_types_list = list({get_parent_model_type(model[2]) for model in dropdown_choices}) + dropdown_types_list = list({get_base_model_type(model[2]) for model in dropdown_choices}) dropdown_choices = [model[1:] for model in dropdown_choices] return sorted_familes, dropdown_types_list, dropdown_choices @@ -9500,7 +9504,7 @@ def generate_dropdown_model_list(current_model_type): dropdown_models = gr.Dropdown( choices= sorted_models, - value= get_base_model_type(current_model_type), + value= get_parent_model_type(current_model_type) if three_levels_hierarchy else get_base_model_type(current_model_type), show_label= False, scale= 3 if len(sorted_finetunes) > 1 else 7, elem_id="model_base_types_list", @@ -9528,15 +9532,16 @@ def change_model_family(state, current_model_family): model_type = last_model_per_family.get(current_model_family, "") if len(model_type) == "" or model_type not in [choice[1] for choice in dropdown_choices] : model_type = dropdown_choices[0][1] - parent_model_type = get_parent_model_type(model_type) if three_levels_hierarchy: + parent_model_type = get_parent_model_type(model_type) dropdown_choices = [ (*tup, get_parent_model_type(tup[1])) for tup in dropdown_choices] dropdown_base_types_choices, finetunes_dict = create_models_hierarchy(dropdown_choices) dropdown_choices = finetunes_dict[parent_model_type ] model_finetunes_visible = len(dropdown_choices) > 1 else: + parent_model_type = get_base_model_type(model_type) model_finetunes_visible = True - dropdown_base_types_choices = list({get_parent_model_type(model[1]) for model in dropdown_choices}) + dropdown_base_types_choices = list({get_base_model_type(model[1]) for model in dropdown_choices}) return gr.Dropdown(choices= dropdown_base_types_choices, value = parent_model_type, scale=3 if model_finetunes_visible else 7), gr.Dropdown(choices= dropdown_choices, value = model_type, visible = model_finetunes_visible ) From ef030c5a14f9011e1eea9ea2341dc1e08438986f Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Wed, 8 Oct 2025 08:08:36 +0200 Subject: [PATCH 082/155] fixed Vace for 2.2 ignored --- 
models/wan/any2video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/wan/any2video.py b/models/wan/any2video.py index b2ee944e7..ec9fc37a2 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -461,7 +461,7 @@ def generate(self, # if NAG_scale > 1: context = torch.cat([context, context_NAG], dim=0) if self._interrupt: return None - vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B"] + vace = model_type in ["vace_1.3B","vace_14B", "vace_14B_2_2", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B"] phantom = model_type in ["phantom_1.3B", "phantom_14B"] fantasy = model_type in ["fantasy"] multitalk = model_type in ["multitalk", "infinitetalk", "vace_multitalk_14B", "i2v_2_2_multitalk"] From 822f033caf0f5942e97201d605f154d6a62c8a5f Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Wed, 8 Oct 2025 15:57:51 +0200 Subject: [PATCH 083/155] fixed import json settings --- wgp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/wgp.py b/wgp.py index 42bd3219e..c2dd8cd6a 100644 --- a/wgp.py +++ b/wgp.py @@ -1828,6 +1828,7 @@ def get_lora_dir(model_type): "preload_model_policy": [], "UI_theme": "default", "checkpoints_paths": fl.default_checkpoints_paths, + "model_hierarchy_type": 1, } with open(server_config_filename, "w", encoding="utf-8") as writer: @@ -1840,7 +1841,7 @@ def get_lora_dir(model_type): checkpoints_paths = server_config.get("checkpoints_paths", None) if checkpoints_paths is None: checkpoints_paths = server_config["checkpoints_paths"] = fl.default_checkpoints_paths fl.set_checkpoints_paths(checkpoints_paths) -three_levels_hierarchy = server_config.get("model_hierarchy_type", 0) == 1 +three_levels_hierarchy = server_config.get("model_hierarchy_type", 0) == 1 and args.betatest # Deprecated models for path in ["wan2.1_Vace_1.3B_preview_bf16.safetensors", "sky_reels2_diffusion_forcing_1.3B_bf16.safetensors","sky_reels2_diffusion_forcing_720p_14B_bf16.safetensors", @@ -6791,7 +6792,7 @@ def load_settings_from_file(state, file_path): if model_type == current_model_type: set_model_settings(state, current_model_type, configs) - return gr.update(), gr.update(), gr.update(), gr.update(), str(time.time()), None + return gr.update(), gr.update(), gr.update(), str(time.time()), None else: set_model_settings(state, model_type, configs) return *generate_dropdown_model_list(model_type), gr.update(), None From 472c9a002c5f3351b394041f7ad25ace966dcb4a Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 9 Oct 2025 15:56:17 +0200 Subject: [PATCH 084/155] Larger FP8 Support + various bug fixes --- models/wan/any2video.py | 34 +++++++++++++------ models/wan/modules/model.py | 12 +++++-- models/wan/wan_handler.py | 1 + .../Lightning v2025-09-28 - 4 Steps.json | 14 -------- requirements.txt | 2 +- shared/utils/loras_mutipliers.py | 2 ++ wgp.py | 21 +++++++----- 7 files changed, 49 insertions(+), 37 deletions(-) delete mode 100644 profiles/wan_2_2/Lightning v2025-09-28 - 4 Steps.json diff --git a/models/wan/any2video.py b/models/wan/any2video.py index ec9fc37a2..7e0592b79 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -59,7 +59,18 @@ def timestep_transform(t, shift=5.0, num_timesteps=1000 ): new_t = new_t * num_timesteps return new_t - +def preprocess_sd(sd): + new_sd = {} + prefix_list = ["model.diffusion_model"] + for k,v in sd.items(): + for prefix in prefix_list: + if k.startswith(prefix): + k = k[len(prefix)+1:] + continue + if not k.startswith("vae."): + 
new_sd[k] = v + return new_sd + class WanAny2V: def __init__( @@ -132,15 +143,16 @@ def __init__( source2 = model_def.get("source2", None) module_source = model_def.get("module_source", None) module_source2 = model_def.get("module_source2", None) - kwargs= { "ignore_unused_weights": ignore_unused_weights, "writable_tensors": False, "default_dtype": dtype } + kwargs= { "modelClass": WanModel,"do_quantize": quantizeTransformer and not save_quantized, "defaultConfigPath": base_config_file , "ignore_unused_weights": ignore_unused_weights, "writable_tensors": False, "default_dtype": dtype, "preprocess_sd": preprocess_sd, "forcedConfigPath": forcedConfigPath, } + kwargs_light= { "modelClass": WanModel,"writable_tensors": False, "preprocess_sd": preprocess_sd , "forcedConfigPath" : base_config_file} if module_source is not None: - self.model = offload.fast_load_transformers_model(model_filename[:1] + [module_source], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, **kwargs) + self.model = offload.fast_load_transformers_model(model_filename[:1] + [module_source], **kwargs) if module_source2 is not None: - self.model2 = offload.fast_load_transformers_model(model_filename[1:2] + [module_source2], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, **kwargs) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2] + [module_source2], **kwargs) if source is not None: - self.model = offload.fast_load_transformers_model(source, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file) + self.model = offload.fast_load_transformers_model(source, **kwargs_light) if source2 is not None: - self.model2 = offload.fast_load_transformers_model(source2, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file) + self.model2 = offload.fast_load_transformers_model(source2, **kwargs_light) if self.model is not None or self.model2 is not None: from wgp import save_model @@ -152,17 +164,17 @@ def __init__( if 0 in submodel_no_list[2:]: shared_modules= {} - self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules, **kwargs) - self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, **kwargs) + self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], return_shared_modules= shared_modules, **kwargs) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, **kwargs) shared_modules = None else: modules_for_1 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==1 ] modules_for_2 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==2 ] - self.model = offload.fast_load_transformers_model(model_filename[:1], modules = modules_for_1, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= 
forcedConfigPath, **kwargs) - self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = modules_for_2, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, **kwargs) + self.model = offload.fast_load_transformers_model(model_filename[:1], modules = modules_for_1, **kwargs) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = modules_for_2, **kwargs) else: - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, **kwargs) + self.model = offload.fast_load_transformers_model(model_filename, **kwargs) if self.model is not None: diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index 2185b9cf4..8c33a17a7 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -198,7 +198,10 @@ def forward(self, x): # return F.layer_norm( # input, self.normalized_shape, self.weight, self.bias, self.eps # ) - y = super().forward(x) + if self.weight is not None: + y = super().forward(x.to(self.weight.dtype)) + else: + y = super().forward(x) x = y.type_as(x) return x # return super().forward(x).type_as(x) @@ -1141,6 +1144,8 @@ def __init__(self, def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32): + from optimum.quanto import QTensor + layer_list = [self.head, self.head.head, self.patch_embedding] target_dype= dtype @@ -1174,8 +1179,9 @@ def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32): for layer in current_layer_list: layer._lock_dtype = dtype - if hasattr(layer, "weight") and layer.weight.dtype != current_dtype : - layer.weight.data = layer.weight.data.to(current_dtype) + if hasattr(layer, "weight") and layer.weight.dtype != current_dtype: + if not isinstance(layer.weight.data, QTensor): + layer.weight.data = layer.weight.data.to(current_dtype) if hasattr(layer, "bias"): layer.bias.data = layer.bias.data.to(current_dtype) diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index da0bdc454..bb6745189 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -190,6 +190,7 @@ def query_model_def(base_model_type, model_def): ("Video to Video guided by Text Prompt", "GUV"), ("Video to Video guided by Text Prompt and Restricted to the Area of the Video Mask", "GVA")], "default": "", + "show_label" : False, "letters_filter": "GUVA", "label": "Video to Video" } diff --git a/profiles/wan_2_2/Lightning v2025-09-28 - 4 Steps.json b/profiles/wan_2_2/Lightning v2025-09-28 - 4 Steps.json deleted file mode 100644 index 31e4f0d66..000000000 --- a/profiles/wan_2_2/Lightning v2025-09-28 - 4 Steps.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "num_inference_steps": 4, - "guidance_scale": 1, - "guidance2_scale": 1, - "switch_threshold": 876, - "model_switch_phase": 1, - "guidance_phases": 2, - "flow_shift": 3, - "loras_multipliers": "1;0 0;1", - "activated_loras": [ - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_HIGH_Lightning_4steps_lora_250928_rank128_fp16.safetensors", - "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan22_A14B_T2V_LOW_Lightning_4steps_lora_250928_rank64_fp16.safetensors" - ] -} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index cfe61c57f..4fa1e10f0 100644 --- a/requirements.txt +++ 
b/requirements.txt @@ -48,7 +48,7 @@ pydantic==2.10.6 # Math & modeling torchdiffeq>=0.2.5 tensordict>=0.6.1 -mmgp==3.6.2 +mmgp==3.6.3 peft==0.15.0 matplotlib diff --git a/shared/utils/loras_mutipliers.py b/shared/utils/loras_mutipliers.py index 8ba456d4d..5e024a976 100644 --- a/shared/utils/loras_mutipliers.py +++ b/shared/utils/loras_mutipliers.py @@ -306,6 +306,8 @@ def merge_loras_settings( update that preserved multiplier and drop the duplicate from the replaced side. """ assert mode in ("merge before", "merge after") + mult_old= mult_old.strip() + mult_new= mult_new.strip() # Old split & alignment bi_old = _find_bar(mult_old) diff --git a/wgp.py b/wgp.py index c2dd8cd6a..c0b498260 100644 --- a/wgp.py +++ b/wgp.py @@ -63,9 +63,8 @@ global_queue_ref = [] AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 - -target_mmgp_version = "3.6.2" -WanGP_version = "8.995" +target_mmgp_version = "3.6.3" +WanGP_version = "8.996" settings_version = 2.39 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -3530,6 +3529,9 @@ def check(src, cond): if len(video_outpainting) >0: values += [video_outpainting] labels += ["Outpainting"] + if "G" in video_video_prompt_type and "V" in video_video_prompt_type: + values += [configs.get("denoising_strength",1)] + labels += ["Denoising Strength"] video_sample_solver = configs.get("sample_solver", "") if model_def.get("sample_solvers", None) is not None and len(video_sample_solver) > 0 : values += [video_sample_solver] @@ -6085,7 +6087,7 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m lset_name = get_lset_name(state, lset_name) if len(lset_name) == 0: gr.Info("Please choose a Lora Preset or Setting File in the list or create one") - return wizard_prompt_activated, loras_choices, loras_mult_choices, prompt, gr.update(), gr.update(), gr.update(), gr.update() + return wizard_prompt_activated, loras_choices, loras_mult_choices, prompt, gr.update(), gr.update(), gr.update(), gr.update(), gr.update() else: current_model_type = state["model_type"] ui_settings = get_current_model_settings(state) @@ -6109,7 +6111,8 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m return wizard_prompt_activated, loras_choices, loras_mult_choices, prompt, get_unique_id(), gr.update(), gr.update(), gr.update(), gr.update() else: - lset_path = os.path.join("profiles" if len(Path(lset_name).parts)>1 else get_lora_dir(current_model_type), lset_name) + accelerator_profile = len(Path(lset_name).parts)>1 + lset_path = os.path.join("profiles" if accelerator_profile else get_lora_dir(current_model_type), lset_name) configs, _ = get_settings_from_file(state,lset_path , True, True, True, min_settings_version=2.38, merge_loras = "merge after" if len(Path(lset_name).parts)<=1 else "merge before" ) if configs == None: @@ -6117,7 +6120,10 @@ def apply_lset(state, wizard_prompt_activated, lset_name, loras_choices, loras_m return [gr.update()] * 9 model_type = configs["model_type"] configs["lset_name"] = lset_name - gr.Info(f"Settings File '{lset_name}' has been applied") + if accelerator_profile: + gr.Info(f"Accelerator Profile '{os.path.splitext(os.path.basename(lset_name))[0]}' has been applied") + else: + gr.Info(f"Settings File '{os.path.basename(lset_name)}' has been applied") help = configs.get("help", None) if help is not None: gr.Info(help) if model_type == current_model_type: @@ -6703,7 
+6709,6 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw if configs is None: return None, False - current_model_filename = state["model_filename"] current_model_type = state["model_type"] model_type = configs.get("model_type", None) @@ -6711,7 +6716,7 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw model_type = configs.get("base_model_type", None) if model_type == None: - model_filename = configs.get("model_filename", current_model_filename) + model_filename = configs.get("model_filename", "") model_type = get_model_type(model_filename) if model_type == None: model_type = current_model_type From 0b69c7146a6c1a0c0e6b1df2d31b1e05c5930a38 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 10 Oct 2025 04:09:46 +1100 Subject: [PATCH 085/155] Initial commit --- README.md | 38 + requirements.txt | 3 + videoeditor/main.py | 1249 +++++++++++++++++++ videoeditor/plugins.py | 266 ++++ videoeditor/plugins/ai_frame_joiner/main.py | 308 +++++ 5 files changed, 1864 insertions(+) create mode 100644 videoeditor/main.py create mode 100644 videoeditor/plugins.py create mode 100644 videoeditor/plugins/ai_frame_joiner/main.py diff --git a/README.md b/README.md index fdaacf94f..2b409148a 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,41 @@ +# Inline AI Video Editor + +A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. It provides a multi-track timeline, a media preview window, and basic clip manipulation capabilities, all wrapped in a dockable user interface. The editor is designed to be extensible through a plugin system. + +## Features + +- **Inline AI Video Generation With WAN2GP**: Select a region to join and it will bring up a desktop port of WAN2GP for you to generate a video inline using the start and end frames in the selected region. +- **Multi-Track Timeline**: Arrange video and audio clips on separate tracks. +- **Project Management**: Create, save, and load projects in a `.json` format. +- **Clip Operations**: + - Drag and drop clips to reposition them in the timeline. + - Split clips at the playhead. + - Create selection regions for advanced operations. + - Join/remove content within selected regions across all tracks. +- **Real-time Preview**: A video preview window with playback controls (Play, Pause, Stop, Frame-by-frame stepping). +- **Dynamic Track Management**: Add or remove video and audio tracks as needed. +- **FFmpeg Integration**: + - Handles video processing for frame extraction, playback, and exporting. + - **Automatic FFmpeg Downloader (Windows)**: Automatically downloads the necessary FFmpeg executables on first run if they are not found. +- **Extensible Plugin System**: Load custom plugins to add new features and dockable widgets. +- **Customizable UI**: Features a dockable interface with resizable panels for the video preview and timeline. +- **More coming soon.. 
+ +## Installation + +**Follow the standard installation steps for WAN2GP below** + +**Run the server:** +```bash +python main.py +``` + +**Run the video editor:** +```bash +python videoeditor\main.py +``` + + # WanGP ----- diff --git a/requirements.txt b/requirements.txt index cfe61c57f..f68b37b56 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ + # Core AI stack diffusers==0.34.0 transformers==4.53.1 @@ -25,6 +26,7 @@ audio-separator==0.36.1 # UI & interaction gradio==5.29.0 +PyQt6 dashscope loguru @@ -57,6 +59,7 @@ ftfy piexif nvidia-ml-py misaki +GitPython # Optional / commented out # transformers==4.46.3 # for llamallava pre-patch diff --git a/videoeditor/main.py b/videoeditor/main.py new file mode 100644 index 000000000..717628c42 --- /dev/null +++ b/videoeditor/main.py @@ -0,0 +1,1249 @@ +import sys +import os +import uuid +import subprocess +import re +import json +import ffmpeg +from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, + QHBoxLayout, QPushButton, QFileDialog, QLabel, + QScrollArea, QFrame, QProgressBar, QDialog, + QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget) +from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction, + QPixmap, QImage) +from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread, + pyqtSignal, QTimer, QByteArray) + +from plugins import PluginManager, ManagePluginsDialog +import requests +import zipfile +from tqdm import tqdm + +def download_ffmpeg(): + if os.name != 'nt': return + exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe'] + if all(os.path.exists(e) for e in exes): return + api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest' + r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'}) + assets = r.json().get('assets', []) + zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None) + if not zip_asset: return + zip_url = zip_asset['browser_download_url'] + zip_name = zip_asset['name'] + with requests.get(zip_url, stream=True) as resp: + total = int(resp.headers.get('Content-Length', 0)) + with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar: + for chunk in resp.iter_content(chunk_size=8192): + f.write(chunk) + pbar.update(len(chunk)) + with zipfile.ZipFile(zip_name) as z: + for f in z.namelist(): + if f.endswith(tuple(exes)) and '/bin/' in f: + z.extract(f) + os.rename(f, os.path.basename(f)) + os.remove(zip_name) + +download_ffmpeg() + +class TimelineClip: + def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec, track_index, track_type, group_id): + self.id = str(uuid.uuid4()) + self.source_path = source_path + self.timeline_start_sec = timeline_start_sec + self.clip_start_sec = clip_start_sec + self.duration_sec = duration_sec + self.track_index = track_index + self.track_type = track_type + self.group_id = group_id + + @property + def timeline_end_sec(self): + return self.timeline_start_sec + self.duration_sec + +class Timeline: + def __init__(self): + self.clips = [] + self.num_video_tracks = 1 + self.num_audio_tracks = 1 + + def add_clip(self, clip): + self.clips.append(clip) + self.clips.sort(key=lambda c: c.timeline_start_sec) + + def get_total_duration(self): + if not self.clips: return 0 + return max(c.timeline_end_sec for c in self.clips) + +class ExportWorker(QObject): + progress = pyqtSignal(int) + finished = pyqtSignal(str) + def __init__(self, ffmpeg_cmd, total_duration): + super().__init__() 
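+        # Descriptive note (editor addition): keeps the prepared ffmpeg argv and the total timeline
+        # duration; run_export later parses ffmpeg's "time=HH:MM:SS" progress lines against this
+        # duration to emit a percentage.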
+ self.ffmpeg_cmd = ffmpeg_cmd + self.total_duration_secs = total_duration + def run_export(self): + try: + process = subprocess.Popen(self.ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True, encoding="utf-8") + time_pattern = re.compile(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})") + for line in iter(process.stdout.readline, ""): + match = time_pattern.search(line) + if match: + hours, minutes, seconds = [int(g) for g in match.groups()[:3]] + processed_time = hours * 3600 + minutes * 60 + seconds + if self.total_duration_secs > 0: + percentage = int((processed_time / self.total_duration_secs) * 100) + self.progress.emit(percentage) + process.stdout.close() + return_code = process.wait() + if return_code == 0: + self.progress.emit(100) + self.finished.emit("Export completed successfully!") + else: + self.finished.emit(f"Export failed! Check console for FFmpeg errors.") + except Exception as e: + self.finished.emit(f"An exception occurred during export: {e}") + +class TimelineWidget(QWidget): + PIXELS_PER_SECOND = 50 + TIMESCALE_HEIGHT = 30 + HEADER_WIDTH = 120 + TRACK_HEIGHT = 50 + AUDIO_TRACKS_SEPARATOR_Y = 15 + + split_requested = pyqtSignal(object) + playhead_moved = pyqtSignal(float) + split_region_requested = pyqtSignal(list) + split_all_regions_requested = pyqtSignal(list) + join_region_requested = pyqtSignal(list) + join_all_regions_requested = pyqtSignal(list) + add_track = pyqtSignal(str) + remove_track = pyqtSignal(str) + operation_finished = pyqtSignal() + context_menu_requested = pyqtSignal(QMenu, 'QContextMenuEvent') + + def __init__(self, timeline_model, settings, parent=None): + super().__init__(parent) + self.timeline = timeline_model + self.settings = settings + self.playhead_pos_sec = 0.0 + self.setMinimumHeight(300) + self.setMouseTracking(True) + self.selection_regions = [] + self.dragging_clip = None + self.dragging_linked_clip = None + self.dragging_playhead = False + self.creating_selection_region = False + self.dragging_selection_region = None + self.drag_start_pos = QPoint() + self.drag_original_timeline_start = 0 + self.selection_drag_start_sec = 0.0 + self.drag_selection_start_values = None + + self.highlighted_track_info = None + self.add_video_track_btn_rect = QRect() + self.remove_video_track_btn_rect = QRect() + self.add_audio_track_btn_rect = QRect() + self.remove_audio_track_btn_rect = QRect() + + self.video_tracks_y_start = 0 + self.audio_tracks_y_start = 0 + + def sec_to_x(self, sec): return self.HEADER_WIDTH + int(sec * self.PIXELS_PER_SECOND) + def x_to_sec(self, x): return float(x - self.HEADER_WIDTH) / self.PIXELS_PER_SECOND if x > self.HEADER_WIDTH else 0.0 + + def paintEvent(self, event): + painter = QPainter(self) + painter.setRenderHint(QPainter.RenderHint.Antialiasing) + painter.fillRect(self.rect(), QColor("#333")) + + self.draw_headers(painter) + self.draw_timescale(painter) + self.draw_tracks_and_clips(painter) + self.draw_selections(painter) + self.draw_playhead(painter) + + total_width = self.sec_to_x(self.timeline.get_total_duration()) + 100 + total_height = self.calculate_total_height() + self.setMinimumSize(max(self.parent().width(), total_width), total_height) + + def calculate_total_height(self): + video_tracks_height = (self.timeline.num_video_tracks + 1) * self.TRACK_HEIGHT + audio_tracks_height = (self.timeline.num_audio_tracks + 1) * self.TRACK_HEIGHT + return self.TIMESCALE_HEIGHT + video_tracks_height + self.AUDIO_TRACKS_SEPARATOR_Y + audio_tracks_height + 20 + + def draw_headers(self, painter): + 
painter.save() + painter.setPen(QColor("#AAA")) + header_font = QFont("Arial", 9, QFont.Weight.Bold) + button_font = QFont("Arial", 8) + + y_cursor = self.TIMESCALE_HEIGHT + + rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT) + painter.fillRect(rect, QColor("#3a3a3a")) + painter.drawRect(rect) + self.add_video_track_btn_rect = QRect(rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22) + painter.setFont(button_font) + painter.fillRect(self.add_video_track_btn_rect, QColor("#454")) + painter.drawText(self.add_video_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "Add Track (+)") + y_cursor += self.TRACK_HEIGHT + self.video_tracks_y_start = y_cursor + + for i in range(self.timeline.num_video_tracks): + track_number = self.timeline.num_video_tracks - i + rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT) + painter.fillRect(rect, QColor("#444")) + painter.drawRect(rect) + painter.setFont(header_font) + painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, f"Video {track_number}") + + if track_number == self.timeline.num_video_tracks and self.timeline.num_video_tracks > 1: + self.remove_video_track_btn_rect = QRect(rect.right() - 25, rect.top() + 5, 20, 20) + painter.setFont(button_font) + painter.fillRect(self.remove_video_track_btn_rect, QColor("#833")) + painter.drawText(self.remove_video_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "-") + y_cursor += self.TRACK_HEIGHT + + y_cursor += self.AUDIO_TRACKS_SEPARATOR_Y + + self.audio_tracks_y_start = y_cursor + for i in range(self.timeline.num_audio_tracks): + track_number = i + 1 + rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT) + painter.fillRect(rect, QColor("#444")) + painter.drawRect(rect) + painter.setFont(header_font) + painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, f"Audio {track_number}") + + if track_number == self.timeline.num_audio_tracks and self.timeline.num_audio_tracks > 1: + self.remove_audio_track_btn_rect = QRect(rect.right() - 25, rect.top() + 5, 20, 20) + painter.setFont(button_font) + painter.fillRect(self.remove_audio_track_btn_rect, QColor("#833")) + painter.drawText(self.remove_audio_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "-") + y_cursor += self.TRACK_HEIGHT + + rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT) + painter.fillRect(rect, QColor("#3a3a3a")) + painter.drawRect(rect) + self.add_audio_track_btn_rect = QRect(rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22) + painter.setFont(button_font) + painter.fillRect(self.add_audio_track_btn_rect, QColor("#454")) + painter.drawText(self.add_audio_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "Add Track (+)") + + painter.restore() + + def draw_timescale(self, painter): + painter.save() + painter.setPen(QColor("#AAA")) + painter.setFont(QFont("Arial", 8)) + font_metrics = QFontMetrics(painter.font()) + + painter.fillRect(QRect(self.HEADER_WIDTH, 0, self.width() - self.HEADER_WIDTH, self.TIMESCALE_HEIGHT), QColor("#222")) + painter.drawLine(self.HEADER_WIDTH, self.TIMESCALE_HEIGHT - 1, self.width(), self.TIMESCALE_HEIGHT - 1) + + max_time_to_draw = self.x_to_sec(self.width()) + major_interval_sec = 5 + minor_interval_sec = 1 + + for t_sec in range(int(max_time_to_draw) + 2): + x = self.sec_to_x(t_sec) + + if t_sec % major_interval_sec == 0: + painter.drawLine(x, self.TIMESCALE_HEIGHT - 10, x, self.TIMESCALE_HEIGHT - 1) + label = f"{t_sec}s" + label_width = font_metrics.horizontalAdvance(label) + painter.drawText(x - label_width // 2, 
self.TIMESCALE_HEIGHT - 12, label) + elif t_sec % minor_interval_sec == 0: + painter.drawLine(x, self.TIMESCALE_HEIGHT - 5, x, self.TIMESCALE_HEIGHT - 1) + painter.restore() + + def get_clip_rect(self, clip): + if clip.track_type == 'video': + visual_index = self.timeline.num_video_tracks - clip.track_index + y = self.video_tracks_y_start + visual_index * self.TRACK_HEIGHT + else: + visual_index = clip.track_index - 1 + y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT + + x = self.sec_to_x(clip.timeline_start_sec) + w = int(clip.duration_sec * self.PIXELS_PER_SECOND) + clip_height = self.TRACK_HEIGHT - 10 + y += (self.TRACK_HEIGHT - clip_height) / 2 + return QRectF(x, y, w, clip_height) + + def draw_tracks_and_clips(self, painter): + painter.save() + y_cursor = self.video_tracks_y_start + for i in range(self.timeline.num_video_tracks): + rect = QRect(self.HEADER_WIDTH, y_cursor, self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT) + painter.fillRect(rect, QColor("#444") if i % 2 == 0 else QColor("#404040")) + y_cursor += self.TRACK_HEIGHT + + y_cursor = self.audio_tracks_y_start + for i in range(self.timeline.num_audio_tracks): + rect = QRect(self.HEADER_WIDTH, y_cursor, self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT) + painter.fillRect(rect, QColor("#444") if i % 2 == 0 else QColor("#404040")) + y_cursor += self.TRACK_HEIGHT + + if self.highlighted_track_info: + track_type, track_index = self.highlighted_track_info + y = -1 + if track_type == 'video' and track_index <= self.timeline.num_video_tracks: + visual_index = self.timeline.num_video_tracks - track_index + y = self.video_tracks_y_start + visual_index * self.TRACK_HEIGHT + elif track_type == 'audio' and track_index <= self.timeline.num_audio_tracks: + visual_index = track_index - 1 + y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT + + if y != -1: + highlight_rect = QRect(self.HEADER_WIDTH, y, self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT) + painter.fillRect(highlight_rect, QColor(255, 255, 0, 40)) + + # Draw clips + for clip in self.timeline.clips: + clip_rect = self.get_clip_rect(clip) + base_color = QColor("#46A") if clip.track_type == 'video' else QColor("#48C") + color = QColor("#5A9") if self.dragging_clip and self.dragging_clip.id == clip.id else base_color + painter.fillRect(clip_rect, color) + painter.setPen(QPen(QColor("#FFF"), 1)) + font = QFont("Arial", 10) + painter.setFont(font) + text = os.path.basename(clip.source_path) + font_metrics = QFontMetrics(font) + text_width = font_metrics.horizontalAdvance(text) + if text_width > clip_rect.width() - 10: text = font_metrics.elidedText(text, Qt.TextElideMode.ElideRight, int(clip_rect.width() - 10)) + painter.drawText(QPoint(int(clip_rect.left() + 5), int(clip_rect.center().y() + 5)), text) + painter.restore() + + def draw_selections(self, painter): + for start_sec, end_sec in self.selection_regions: + x = self.sec_to_x(start_sec) + w = int((end_sec - start_sec) * self.PIXELS_PER_SECOND) + selection_rect = QRectF(x, self.TIMESCALE_HEIGHT, w, self.height() - self.TIMESCALE_HEIGHT) + painter.fillRect(selection_rect, QColor(100, 100, 255, 80)) + painter.setPen(QColor(150, 150, 255, 150)) + painter.drawRect(selection_rect) + + def draw_playhead(self, painter): + playhead_x = self.sec_to_x(self.playhead_pos_sec) + painter.setPen(QPen(QColor("red"), 2)) + painter.drawLine(playhead_x, 0, playhead_x, self.height()) + + def y_to_track_info(self, y): + video_tracks_end_y = self.video_tracks_y_start + self.timeline.num_video_tracks * 
self.TRACK_HEIGHT + if self.video_tracks_y_start <= y < video_tracks_end_y: + visual_index = (y - self.video_tracks_y_start) // self.TRACK_HEIGHT + track_index = self.timeline.num_video_tracks - visual_index + return ('video', track_index) + + audio_tracks_end_y = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT + if self.audio_tracks_y_start <= y < audio_tracks_end_y: + visual_index = (y - self.audio_tracks_y_start) // self.TRACK_HEIGHT + track_index = visual_index + 1 + return ('audio', track_index) + return None + + def get_region_at_pos(self, pos: QPoint): + if pos.y() <= self.TIMESCALE_HEIGHT or pos.x() <= self.HEADER_WIDTH: + return None + + clicked_sec = self.x_to_sec(pos.x()) + for region in reversed(self.selection_regions): + if region[0] <= clicked_sec <= region[1]: + return region + return None + + def mousePressEvent(self, event: QMouseEvent): + if event.button() == Qt.MouseButton.LeftButton: + if event.pos().x() < self.HEADER_WIDTH: + # Click in headers area + if self.add_video_track_btn_rect.contains(event.pos()): self.add_track.emit('video') + elif self.remove_video_track_btn_rect.contains(event.pos()): self.remove_track.emit('video') + elif self.add_audio_track_btn_rect.contains(event.pos()): self.add_track.emit('audio') + elif self.remove_audio_track_btn_rect.contains(event.pos()): self.remove_track.emit('audio') + return + + self.dragging_clip = None + self.dragging_linked_clip = None + self.dragging_playhead = False + self.dragging_selection_region = None + self.creating_selection_region = False + + for clip in reversed(self.timeline.clips): + clip_rect = self.get_clip_rect(clip) + if clip_rect.contains(QPointF(event.pos())): + self.dragging_clip = clip + self.dragging_linked_clip = next((c for c in self.timeline.clips if c.group_id == clip.group_id and c.id != clip.id), None) + self.drag_start_pos = event.pos() + self.drag_original_timeline_start = clip.timeline_start_sec + break + + if not self.dragging_clip: + region_to_drag = self.get_region_at_pos(event.pos()) + if region_to_drag: + self.dragging_selection_region = region_to_drag + self.drag_start_pos = event.pos() + self.drag_selection_start_values = tuple(region_to_drag) + else: + is_on_timescale = event.pos().y() <= self.TIMESCALE_HEIGHT + is_in_track_area = event.pos().y() > self.TIMESCALE_HEIGHT and event.pos().x() > self.HEADER_WIDTH + + if is_in_track_area: + self.creating_selection_region = True + self.selection_drag_start_sec = self.x_to_sec(event.pos().x()) + self.selection_regions.append([self.selection_drag_start_sec, self.selection_drag_start_sec]) + elif is_on_timescale: + self.playhead_pos_sec = max(0, self.x_to_sec(event.pos().x())) + self.playhead_moved.emit(self.playhead_pos_sec) + self.dragging_playhead = True + + self.update() + + def mouseMoveEvent(self, event: QMouseEvent): + if self.creating_selection_region: + current_sec = self.x_to_sec(event.pos().x()) + start = min(self.selection_drag_start_sec, current_sec) + end = max(self.selection_drag_start_sec, current_sec) + self.selection_regions[-1] = [start, end] + self.update() + return + + if self.dragging_selection_region: + delta_x = event.pos().x() - self.drag_start_pos.x() + time_delta = delta_x / self.PIXELS_PER_SECOND + + original_start, original_end = self.drag_selection_start_values + duration = original_end - original_start + new_start = max(0, original_start + time_delta) + + self.dragging_selection_region[0] = new_start + self.dragging_selection_region[1] = new_start + duration + + self.update() + return 
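+        # Descriptive note (editor addition): the remaining drag modes are handled below --
+        # scrubbing the playhead, or moving a clip (together with its linked audio/video partner)
+        # with per-track overlap checks and automatic track growth.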
+ + if self.dragging_playhead: + self.playhead_pos_sec = max(0, self.x_to_sec(event.pos().x())) + self.playhead_moved.emit(self.playhead_pos_sec) + self.update() + elif self.dragging_clip: + self.highlighted_track_info = None + new_track_info = self.y_to_track_info(event.pos().y()) + if new_track_info: + new_track_type, new_track_index = new_track_info + if new_track_type == self.dragging_clip.track_type: + self.dragging_clip.track_index = new_track_index + if self.dragging_linked_clip: + self.dragging_linked_clip.track_index = new_track_index + if self.dragging_clip.track_type == 'video': + if new_track_index > self.timeline.num_audio_tracks: + self.timeline.num_audio_tracks = new_track_index + self.highlighted_track_info = ('audio', new_track_index) + else: + if new_track_index > self.timeline.num_video_tracks: + self.timeline.num_video_tracks = new_track_index + self.highlighted_track_info = ('video', new_track_index) + + delta_x = event.pos().x() - self.drag_start_pos.x() + time_delta = delta_x / self.PIXELS_PER_SECOND + new_start_time = self.drag_original_timeline_start + time_delta + + for other_clip in self.timeline.clips: + if other_clip.id == self.dragging_clip.id: continue + if self.dragging_linked_clip and other_clip.id == self.dragging_linked_clip.id: continue + if (other_clip.track_type != self.dragging_clip.track_type or + other_clip.track_index != self.dragging_clip.track_index): + continue + + is_overlapping = (new_start_time < other_clip.timeline_end_sec and + new_start_time + self.dragging_clip.duration_sec > other_clip.timeline_start_sec) + + if is_overlapping: + if time_delta > 0: new_start_time = other_clip.timeline_start_sec - self.dragging_clip.duration_sec + else: new_start_time = other_clip.timeline_end_sec + break + + final_start_time = max(0, new_start_time) + self.dragging_clip.timeline_start_sec = final_start_time + if self.dragging_linked_clip: + self.dragging_linked_clip.timeline_start_sec = final_start_time + + self.update() + + def mouseReleaseEvent(self, event: QMouseEvent): + if event.button() == Qt.MouseButton.LeftButton: + if self.creating_selection_region: + self.creating_selection_region = False + if self.selection_regions: + start, end = self.selection_regions[-1] + if (end - start) * self.PIXELS_PER_SECOND < 2: + self.selection_regions.pop() + + if self.dragging_selection_region: + self.dragging_selection_region = None + self.drag_selection_start_values = None + + self.dragging_playhead = False + if self.dragging_clip: + self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) + self.highlighted_track_info = None + self.operation_finished.emit() + self.dragging_clip = None + self.dragging_linked_clip = None + + self.update() + + def contextMenuEvent(self, event: 'QContextMenuEvent'): + menu = QMenu(self) + + region_at_pos = self.get_region_at_pos(event.pos()) + if region_at_pos: + split_this_action = menu.addAction("Split This Region") + split_all_action = menu.addAction("Split All Regions") + join_this_action = menu.addAction("Join This Region") + join_all_action = menu.addAction("Join All Regions") + menu.addSeparator() + clear_this_action = menu.addAction("Clear This Region") + clear_all_action = menu.addAction("Clear All Regions") + split_this_action.triggered.connect(lambda: self.split_region_requested.emit(region_at_pos)) + split_all_action.triggered.connect(lambda: self.split_all_regions_requested.emit(self.selection_regions)) + join_this_action.triggered.connect(lambda: self.join_region_requested.emit(region_at_pos)) + 
join_all_action.triggered.connect(lambda: self.join_all_regions_requested.emit(self.selection_regions)) + clear_this_action.triggered.connect(lambda: self.clear_region(region_at_pos)) + clear_all_action.triggered.connect(self.clear_all_regions) + + clip_at_pos = None + for clip in self.timeline.clips: + if self.get_clip_rect(clip).contains(QPointF(event.pos())): + clip_at_pos = clip + break + + if clip_at_pos: + if not menu.isEmpty(): menu.addSeparator() + split_action = menu.addAction("Split Clip") + playhead_time = self.playhead_pos_sec + is_playhead_over_clip = (clip_at_pos.timeline_start_sec < playhead_time < clip_at_pos.timeline_end_sec) + split_action.setEnabled(is_playhead_over_clip) + split_action.triggered.connect(lambda: self.split_requested.emit(clip_at_pos)) + + self.context_menu_requested.emit(menu, event) + + if not menu.isEmpty(): + menu.exec(self.mapToGlobal(event.pos())) + + def clear_region(self, region_to_clear): + if region_to_clear in self.selection_regions: + self.selection_regions.remove(region_to_clear) + self.update() + + def clear_all_regions(self): + self.selection_regions.clear() + self.update() + +class SettingsDialog(QDialog): + def __init__(self, parent_settings, parent=None): + super().__init__(parent) + self.setWindowTitle("Settings") + self.setMinimumWidth(350) + layout = QVBoxLayout(self) + self.seek_anywhere_checkbox = QCheckBox("Allow seeking by clicking anywhere on the timeline") + self.seek_anywhere_checkbox.setChecked(parent_settings.get("allow_seek_anywhere", False)) + layout.addWidget(self.seek_anywhere_checkbox) + button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + button_box.accepted.connect(self.accept) + button_box.rejected.connect(self.reject) + layout.addWidget(button_box) + + def get_settings(self): + return {"allow_seek_anywhere": self.seek_anywhere_checkbox.isChecked()} + +class MainWindow(QMainWindow): + def __init__(self, project_to_load=None): + super().__init__() + self.setWindowTitle("Inline AI Video Editor") + self.setGeometry(100, 100, 1200, 800) + self.setDockOptions(QMainWindow.DockOption.AnimatedDocks | QMainWindow.DockOption.AllowNestedDocks) + + self.timeline = Timeline() + self.export_thread = None + self.export_worker = None + self.current_project_path = None + self.settings = {} + self.settings_file = "settings.json" + self.is_shutting_down = False + self._load_settings() + + self.plugin_manager = PluginManager(self) + self.plugin_manager.discover_and_load_plugins() + + self.project_fps = 25.0 + self.project_width = 1280 + self.project_height = 720 + self.playback_timer = QTimer(self) + self.playback_process = None + self.playback_clip = None + + self.splitter = QSplitter(Qt.Orientation.Vertical) + + self.preview_widget = QLabel() + self.preview_widget.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.preview_widget.setMinimumSize(640, 360) + self.preview_widget.setFrameShape(QFrame.Shape.Box) + self.preview_widget.setStyleSheet("background-color: black; color: white;") + self.splitter.addWidget(self.preview_widget) + + self.timeline_widget = TimelineWidget(self.timeline, self.settings) + self.timeline_scroll_area = QScrollArea() + self.timeline_scroll_area.setWidgetResizable(True) + self.timeline_scroll_area.setWidget(self.timeline_widget) + self.timeline_scroll_area.setFrameShape(QFrame.Shape.NoFrame) + self.timeline_scroll_area.setMinimumHeight(250) + self.splitter.addWidget(self.timeline_scroll_area) + + container_widget = QWidget() + main_layout = 
QVBoxLayout(container_widget) + main_layout.setContentsMargins(0,0,0,0) + main_layout.addWidget(self.splitter, 1) + + controls_widget = QWidget() + controls_layout = QHBoxLayout(controls_widget) + controls_layout.setContentsMargins(0, 5, 0, 5) + self.play_pause_button = QPushButton("Play") + self.stop_button = QPushButton("Stop") + self.frame_back_button = QPushButton("<") + self.frame_forward_button = QPushButton(">") + controls_layout.addStretch() + controls_layout.addWidget(self.frame_back_button) + controls_layout.addWidget(self.play_pause_button) + controls_layout.addWidget(self.stop_button) + controls_layout.addWidget(self.frame_forward_button) + controls_layout.addStretch() + main_layout.addWidget(controls_widget) + + status_layout = QHBoxLayout() + self.status_label = QLabel("Ready. Create or open a project from the File menu.") + self.progress_bar = QProgressBar() + self.progress_bar.setVisible(False) + self.progress_bar.setTextVisible(True) + self.progress_bar.setRange(0, 100) + status_layout.addWidget(self.status_label, 1) + status_layout.addWidget(self.progress_bar, 1) + main_layout.addLayout(status_layout) + + self.setCentralWidget(container_widget) + + self.managed_widgets = { + 'preview': {'widget': self.preview_widget, 'name': 'Video Preview', 'action': None}, + 'timeline': {'widget': self.timeline_scroll_area, 'name': 'Timeline', 'action': None} + } + self.plugin_menu_actions = {} + self.windows_menu = None + self._create_menu_bar() + + self.splitter_save_timer = QTimer(self) + self.splitter_save_timer.setSingleShot(True) + self.splitter_save_timer.timeout.connect(self._save_settings) + self.splitter.splitterMoved.connect(self.on_splitter_moved) + + self.timeline_widget.split_requested.connect(self.split_clip_at_playhead) + self.timeline_widget.playhead_moved.connect(self.seek_preview) + self.timeline_widget.split_region_requested.connect(self.on_split_region) + self.timeline_widget.split_all_regions_requested.connect(self.on_split_all_regions) + self.timeline_widget.join_region_requested.connect(self.on_join_region) + self.timeline_widget.join_all_regions_requested.connect(self.on_join_all_regions) + self.timeline_widget.add_track.connect(self.add_track) + self.timeline_widget.remove_track.connect(self.remove_track) + self.timeline_widget.operation_finished.connect(self.prune_empty_tracks) + + self.play_pause_button.clicked.connect(self.toggle_playback) + self.stop_button.clicked.connect(self.stop_playback) + self.frame_back_button.clicked.connect(lambda: self.step_frame(-1)) + self.frame_forward_button.clicked.connect(lambda: self.step_frame(1)) + self.playback_timer.timeout.connect(self.advance_playback_frame) + + self.plugin_manager.load_enabled_plugins_from_settings(self.settings.get("enabled_plugins", [])) + self._apply_loaded_settings() + self.seek_preview(0) + + if not self.settings_file_was_loaded: self._save_settings() + if project_to_load: QTimer.singleShot(100, lambda: self._load_project_from_path(project_to_load)) + + def prune_empty_tracks(self): + pruned_something = False + + while self.timeline.num_video_tracks > 1: + highest_track_index = self.timeline.num_video_tracks + is_track_occupied = any(c for c in self.timeline.clips + if c.track_type == 'video' and c.track_index == highest_track_index) + if is_track_occupied: + break + else: + self.timeline.num_video_tracks -= 1 + pruned_something = True + + while self.timeline.num_audio_tracks > 1: + highest_track_index = self.timeline.num_audio_tracks + is_track_occupied = any(c for c in self.timeline.clips + if 
c.track_type == 'audio' and c.track_index == highest_track_index) + if is_track_occupied: + break + else: + self.timeline.num_audio_tracks -= 1 + pruned_something = True + + if pruned_something: + self.timeline_widget.update() + + + def add_track(self, track_type): + if track_type == 'video': + self.timeline.num_video_tracks += 1 + elif track_type == 'audio': + self.timeline.num_audio_tracks += 1 + self.timeline_widget.update() + + def remove_track(self, track_type): + if track_type == 'video' and self.timeline.num_video_tracks > 1: + self.timeline.num_video_tracks -= 1 + elif track_type == 'audio' and self.timeline.num_audio_tracks > 1: + self.timeline.num_audio_tracks -= 1 + self.timeline_widget.update() + + def _create_menu_bar(self): + menu_bar = self.menuBar() + file_menu = menu_bar.addMenu("&File") + new_action = QAction("&New Project", self); new_action.triggered.connect(self.new_project) + open_action = QAction("&Open Project...", self); open_action.triggered.connect(self.open_project) + save_action = QAction("&Save Project As...", self); save_action.triggered.connect(self.save_project_as) + add_video_action = QAction("&Add Video...", self); add_video_action.triggered.connect(self.add_video_clip) + export_action = QAction("&Export Video...", self); export_action.triggered.connect(self.export_video) + settings_action = QAction("Se&ttings...", self); settings_action.triggered.connect(self.open_settings_dialog) + exit_action = QAction("E&xit", self); exit_action.triggered.connect(self.close) + file_menu.addAction(new_action); file_menu.addAction(open_action); file_menu.addAction(save_action) + file_menu.addSeparator(); file_menu.addAction(add_video_action); file_menu.addAction(export_action) + file_menu.addSeparator(); file_menu.addAction(settings_action); file_menu.addSeparator(); file_menu.addAction(exit_action) + + edit_menu = menu_bar.addMenu("&Edit") + split_action = QAction("Split Clip at Playhead", self); split_action.triggered.connect(self.split_clip_at_playhead) + edit_menu.addAction(split_action) + + plugins_menu = menu_bar.addMenu("&Plugins") + for name, data in self.plugin_manager.plugins.items(): + plugin_action = QAction(name, self, checkable=True) + plugin_action.setChecked(data['enabled']) + plugin_action.toggled.connect(lambda checked, n=name: self.toggle_plugin(n, checked)) + plugins_menu.addAction(plugin_action) + self.plugin_menu_actions[name] = plugin_action + + plugins_menu.addSeparator() + manage_action = QAction("Manage plugins...", self) + manage_action.triggered.connect(self.open_manage_plugins_dialog) + plugins_menu.addAction(manage_action) + + self.windows_menu = menu_bar.addMenu("&Windows") + for key, data in self.managed_widgets.items(): + action = QAction(data['name'], self, checkable=True) + action.toggled.connect(lambda checked, k=key: self.toggle_widget_visibility(k, checked)) + data['action'] = action + self.windows_menu.addAction(action) + + def _start_playback_stream_at(self, time_sec): + self._stop_playback_stream() + clip = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None) + if not clip: return + self.playback_clip = clip + clip_time = time_sec - clip.timeline_start_sec + clip.clip_start_sec + try: + args = (ffmpeg.input(self.playback_clip.source_path, ss=clip_time).output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile()) + self.playback_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) + except Exception as 
e: + print(f"Failed to start playback stream: {e}"); self._stop_playback_stream() + + def _stop_playback_stream(self): + if self.playback_process: + if self.playback_process.poll() is None: + self.playback_process.terminate() + try: self.playback_process.wait(timeout=0.5) + except subprocess.TimeoutExpired: self.playback_process.kill(); self.playback_process.wait() + self.playback_process = None + self.playback_clip = None + + def _set_project_properties_from_clip(self, source_path): + try: + probe = ffmpeg.probe(source_path) + video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None) + if video_stream: + self.project_width = int(video_stream['width']); self.project_height = int(video_stream['height']) + if 'r_frame_rate' in video_stream and video_stream['r_frame_rate'] != '0/0': + num, den = map(int, video_stream['r_frame_rate'].split('/')) + if den > 0: self.project_fps = num / den + print(f"Project properties set: {self.project_width}x{self.project_height} @ {self.project_fps:.2f} FPS") + return True + except Exception as e: print(f"Could not probe for project properties: {e}") + return False + + def get_frame_data_at_time(self, time_sec): + clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None) + if not clip_at_time: + return (None, 0, 0) + try: + clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec + out, _ = ( + ffmpeg + .input(clip_at_time.source_path, ss=clip_time) + .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') + .run(capture_stdout=True, quiet=True) + ) + return (out, self.project_width, self.project_height) + except ffmpeg.Error as e: + print(f"Error extracting frame data: {e.stderr}") + return (None, 0, 0) + + def get_frame_at_time(self, time_sec): + clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None) + black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black")) + if not clip_at_time: return black_pixmap + try: + clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec + out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time).output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24').run(capture_stdout=True, quiet=True)) + image = QImage(out, self.project_width, self.project_height, QImage.Format.Format_RGB888) + return QPixmap.fromImage(image) + except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return black_pixmap + + def seek_preview(self, time_sec): + self._stop_playback_stream() + self.timeline_widget.playhead_pos_sec = time_sec + self.timeline_widget.update() + frame_pixmap = self.get_frame_at_time(time_sec) + if frame_pixmap: + scaled_pixmap = frame_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) + self.preview_widget.setPixmap(scaled_pixmap) + + def toggle_playback(self): + if self.playback_timer.isActive(): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play") + else: + if not self.timeline.clips: return + if self.timeline_widget.playhead_pos_sec >= self.timeline.get_total_duration(): self.timeline_widget.playhead_pos_sec = 0.0 + self.playback_timer.start(int(1000 / self.project_fps)); self.play_pause_button.setText("Pause") + + def stop_playback(self): self.playback_timer.stop(); 
self._stop_playback_stream(); self.play_pause_button.setText("Play"); self.seek_preview(0.0) + def step_frame(self, direction): + if not self.timeline.clips: return + self.playback_timer.stop(); self.play_pause_button.setText("Play"); self._stop_playback_stream() + frame_duration = 1.0 / self.project_fps + new_time = self.timeline_widget.playhead_pos_sec + (direction * frame_duration) + self.seek_preview(max(0, min(new_time, self.timeline.get_total_duration()))) + + def advance_playback_frame(self): + frame_duration = 1.0 / self.project_fps + new_time = self.timeline_widget.playhead_pos_sec + frame_duration + if new_time > self.timeline.get_total_duration(): self.stop_playback(); return + self.timeline_widget.playhead_pos_sec = new_time; self.timeline_widget.update() + clip_at_new_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= new_time < c.timeline_end_sec), None) + if not clip_at_new_time: + self._stop_playback_stream() + black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black")) + scaled_pixmap = black_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) + self.preview_widget.setPixmap(scaled_pixmap); return + if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: self._start_playback_stream_at(new_time) + if self.playback_process: + frame_size = self.project_width * self.project_height * 3 + frame_bytes = self.playback_process.stdout.read(frame_size) + if len(frame_bytes) == frame_size: + image = QImage(frame_bytes, self.project_width, self.project_height, QImage.Format.Format_RGB888) + pixmap = QPixmap.fromImage(image) + scaled_pixmap = pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) + self.preview_widget.setPixmap(scaled_pixmap) + else: self._stop_playback_stream() + + def _load_settings(self): + self.settings_file_was_loaded = False + defaults = {"allow_seek_anywhere": False, "window_visibility": {"preview": True, "timeline": True}, "splitter_state": None, "enabled_plugins": []} + if os.path.exists(self.settings_file): + try: + with open(self.settings_file, "r") as f: self.settings = json.load(f) + self.settings_file_was_loaded = True + for key, value in defaults.items(): + if key not in self.settings: self.settings[key] = value + except (json.JSONDecodeError, IOError): self.settings = defaults + else: self.settings = defaults + + def _save_settings(self): + self.settings["splitter_state"] = self.splitter.saveState().toHex().data().decode('ascii') + visibility_to_save = { + key: data['action'].isChecked() + for key, data in self.managed_widgets.items() if data.get('action') + } + + self.settings["window_visibility"] = visibility_to_save + self.settings['enabled_plugins'] = self.plugin_manager.get_enabled_plugin_names() + try: + with open(self.settings_file, "w") as f: + json.dump(self.settings, f, indent=4) + except IOError as e: + print(f"Error saving settings: {e}") + + def _apply_loaded_settings(self): + visibility_settings = self.settings.get("window_visibility", {}) + for key, data in self.managed_widgets.items(): + if data.get('plugin'): + continue + + is_visible = visibility_settings.get(key, True) + data['widget'].setVisible(is_visible) + if data['action']: data['action'].setChecked(is_visible) + + splitter_state = self.settings.get("splitter_state") + if splitter_state: 
self.splitter.restoreState(QByteArray.fromHex(splitter_state.encode('ascii'))) + + def on_splitter_moved(self, pos, index): self.splitter_save_timer.start(500) + + def toggle_widget_visibility(self, key, checked): + if self.is_shutting_down: + return + if key in self.managed_widgets: + self.managed_widgets[key]['widget'].setVisible(checked) + self._save_settings() + + def open_settings_dialog(self): + dialog = SettingsDialog(self.settings, self) + if dialog.exec() == QDialog.DialogCode.Accepted: + self.settings.update(dialog.get_settings()); self._save_settings(); self.status_label.setText("Settings updated.") + + def new_project(self): + self.timeline.clips.clear(); self.timeline.num_video_tracks = 1; self.timeline.num_audio_tracks = 1 + self.current_project_path = None; self.stop_playback(); self.timeline_widget.update() + self.status_label.setText("New project created. Add video clips to begin.") + + def save_project_as(self): + path, _ = QFileDialog.getSaveFileName(self, "Save Project", "", "JSON Project Files (*.json)") + if not path: return + project_data = { + "clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "group_id": c.group_id} for c in self.timeline.clips], + "settings": {"num_video_tracks": self.timeline.num_video_tracks, "num_audio_tracks": self.timeline.num_audio_tracks} + } + try: + with open(path, "w") as f: json.dump(project_data, f, indent=4) + self.current_project_path = path; self.status_label.setText(f"Project saved to {os.path.basename(path)}") + except Exception as e: self.status_label.setText(f"Error saving project: {e}") + + def open_project(self): + path, _ = QFileDialog.getOpenFileName(self, "Open Project", "", "JSON Project Files (*.json)") + if path: self._load_project_from_path(path) + + def _load_project_from_path(self, path): + try: + with open(path, "r") as f: project_data = json.load(f) + self.timeline.clips.clear() + for clip_data in project_data["clips"]: + if not os.path.exists(clip_data["source_path"]): + self.status_label.setText(f"Error: Missing media file {clip_data['source_path']}"); self.timeline.clips.clear(); self.timeline_widget.update(); return + self.timeline.add_clip(TimelineClip(**clip_data)) + + project_settings = project_data.get("settings", {}) + self.timeline.num_video_tracks = project_settings.get("num_video_tracks", 1) + self.timeline.num_audio_tracks = project_settings.get("num_audio_tracks", 1) + + self.current_project_path = path + if self.timeline.clips: self._set_project_properties_from_clip(self.timeline.clips[0].source_path) + self.prune_empty_tracks() + self.timeline_widget.update(); self.stop_playback() + self.status_label.setText(f"Project '{os.path.basename(path)}' loaded.") + except Exception as e: self.status_label.setText(f"Error opening project: {e}") + + def add_video_clip(self): + file_path, _ = QFileDialog.getOpenFileName(self, "Open Video File", "", "Video Files (*.mp4 *.mov *.avi)") + if not file_path: return + if not self.timeline.clips: + if not self._set_project_properties_from_clip(file_path): self.status_label.setText("Error: Could not determine video properties from file."); return + try: + self.status_label.setText(f"Probing {os.path.basename(file_path)}..."); QApplication.processEvents() + probe = ffmpeg.probe(file_path) + video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None) + if not video_stream: raise 
ValueError("No video stream found.") + duration = float(video_stream.get('duration', probe['format'].get('duration', 0))) + has_audio = any(s['codec_type'] == 'audio' for s in probe['streams']) + timeline_start = self.timeline.get_total_duration() + + self._add_clip_to_timeline( + source_path=file_path, + timeline_start_sec=timeline_start, + duration_sec=duration, + clip_start_sec=0, + video_track_index=1, + audio_track_index=1 if has_audio else None + ) + + self.timeline_widget.update(); self.seek_preview(self.timeline_widget.playhead_pos_sec) + except Exception as e: self.status_label.setText(f"Error adding file: {e}") + + def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, clip_start_sec=0.0, video_track_index=None, audio_track_index=None): + group_id = str(uuid.uuid4()) + + if video_track_index is not None: + video_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, video_track_index, 'video', group_id) + self.timeline.add_clip(video_clip) + + if audio_track_index is not None: + audio_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, audio_track_index, 'audio', group_id) + self.timeline.add_clip(audio_clip) + + self.timeline_widget.update() + self.status_label.setText(f"Added {os.path.basename(source_path)}.") + + def _split_at_time(self, clip_to_split, time_sec, new_group_id=None): + if not (clip_to_split.timeline_start_sec < time_sec < clip_to_split.timeline_end_sec): return False + split_point = time_sec - clip_to_split.timeline_start_sec + orig_dur = clip_to_split.duration_sec + + group_id_for_new_clip = new_group_id if new_group_id is not None else clip_to_split.group_id + + new_clip = TimelineClip(clip_to_split.source_path, time_sec, clip_to_split.clip_start_sec + split_point, orig_dur - split_point, clip_to_split.track_index, clip_to_split.track_type, group_id_for_new_clip) + clip_to_split.duration_sec = split_point + self.timeline.add_clip(new_clip) + return True + + def split_clip_at_playhead(self, clip_to_split=None): + playhead_time = self.timeline_widget.playhead_pos_sec + if not clip_to_split: + clip_to_split = next((c for c in self.timeline.clips if c.timeline_start_sec < playhead_time < c.timeline_end_sec), None) + + if not clip_to_split: + self.status_label.setText("Playhead is not over a clip to split.") + return + + linked_clip = next((c for c in self.timeline.clips if c.group_id == clip_to_split.group_id and c.id != clip_to_split.id), None) + + new_right_side_group_id = str(uuid.uuid4()) + + split1 = self._split_at_time(clip_to_split, playhead_time, new_group_id=new_right_side_group_id) + split2 = True + if linked_clip: + split2 = self._split_at_time(linked_clip, playhead_time, new_group_id=new_right_side_group_id) + + if split1 and split2: + self.timeline_widget.update() + self.status_label.setText("Clip split.") + else: + self.status_label.setText("Failed to split clip.") + + def on_split_region(self, region): + start_sec, end_sec = region; + clips = list(self.timeline.clips) + for clip in clips: self._split_at_time(clip, end_sec) + for clip in clips: self._split_at_time(clip, start_sec) + self.timeline_widget.clear_region(region); self.timeline_widget.update(); self.status_label.setText("Region split.") + + def on_split_all_regions(self, regions): + split_points = set() + for start, end in regions: split_points.add(start); split_points.add(end) + for point in sorted(list(split_points)): + group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < 
point < c.timeline_end_sec} + new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} + + for clip in list(self.timeline.clips): + if clip.group_id in new_group_ids: + self._split_at_time(clip, point, new_group_ids[clip.group_id]) + + self.timeline_widget.clear_all_regions(); self.timeline_widget.update(); self.status_label.setText("All regions split.") + + def on_join_region(self, region): + start_sec, end_sec = region; duration_to_remove = end_sec - start_sec + if duration_to_remove <= 0.01: return + + for point in [start_sec, end_sec]: + group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} + new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} + for clip in list(self.timeline.clips): + if clip.group_id in new_group_ids: + self._split_at_time(clip, point, new_group_ids[clip.group_id]) + + clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] + for clip in clips_to_remove: self.timeline.clips.remove(clip) + for clip in self.timeline.clips: + if clip.timeline_start_sec >= end_sec: clip.timeline_start_sec -= duration_to_remove + self.timeline.clips.sort(key=lambda c: c.timeline_start_sec); self.timeline_widget.clear_region(region) + self.timeline_widget.update(); self.status_label.setText("Region joined (content removed).") + self.prune_empty_tracks() + + def on_join_all_regions(self, regions): + for region in sorted(regions, key=lambda r: r[0], reverse=True): + start_sec, end_sec = region; duration_to_remove = end_sec - start_sec + if duration_to_remove <= 0.01: continue + + for point in [start_sec, end_sec]: + group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} + new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} + for clip in list(self.timeline.clips): + if clip.group_id in new_group_ids: + self._split_at_time(clip, point, new_group_ids[clip.group_id]) + + clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] + for clip in clips_to_remove: + try: self.timeline.clips.remove(clip) + except ValueError: pass + for clip in self.timeline.clips: + if clip.timeline_start_sec >= end_sec: clip.timeline_start_sec -= duration_to_remove + + self.timeline.clips.sort(key=lambda c: c.timeline_start_sec); self.timeline_widget.clear_all_regions() + self.timeline_widget.update(); self.status_label.setText("All regions joined.") + self.prune_empty_tracks() + + def export_video(self): + if not self.timeline.clips: self.status_label.setText("Timeline is empty."); return + output_path, _ = QFileDialog.getSaveFileName(self, "Save Video As", "", "MP4 Files (*.mp4)") + if not output_path: return + + w, h, fr_str, total_dur = self.project_width, self.project_height, str(self.project_fps), self.timeline.get_total_duration() + sample_rate, channel_layout = '44100', 'stereo' + + input_files = {clip.source_path: ffmpeg.input(clip.source_path) for clip in self.timeline.clips} + + last_video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur}', f='lavfi') + for i in range(self.timeline.num_video_tracks): + track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'video' and c.track_index == i + 1], key=lambda c: c.timeline_start_sec) + if not track_clips: continue + + track_segments = [] + last_end = track_clips[0].timeline_start_sec + + for clip in track_clips: + gap = 
clip.timeline_start_sec - last_end + if gap > 0.01: + track_segments.append(ffmpeg.input(f'color=c=black@0.0:s={w}x{h}:r={fr_str}:d={gap}', f='lavfi').filter('format', pix_fmts='rgba')) + + v_seg = (input_files[clip.source_path].video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS') + .filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black').filter('format', pix_fmts='rgba')) + track_segments.append(v_seg) + last_end = clip.timeline_end_sec + + track_stream = ffmpeg.concat(*track_segments, v=1, a=0).filter('setpts', f'PTS-STARTPTS+{track_clips[0].timeline_start_sec}/TB') + last_video_stream = ffmpeg.overlay(last_video_stream, track_stream) + final_video = last_video_stream.filter('format', pix_fmts='yuv420p').filter('fps', fps=self.project_fps) + + track_audio_streams = [] + for i in range(self.timeline.num_audio_tracks): + track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i + 1], key=lambda c: c.timeline_start_sec) + if not track_clips: continue + + track_segments = [] + last_end = track_clips[0].timeline_start_sec + + for clip in track_clips: + gap = clip.timeline_start_sec - last_end + if gap > 0.01: + track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap}', f='lavfi')) + + a_seg = input_files[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS') + track_segments.append(a_seg) + last_end = clip.timeline_end_sec + + track_audio_streams.append(ffmpeg.concat(*track_segments, v=0, a=1).filter('adelay', f'{int(track_clips[0].timeline_start_sec * 1000)}ms', all=True)) + + if track_audio_streams: + final_audio = ffmpeg.filter(track_audio_streams, 'amix', inputs=len(track_audio_streams), duration='longest') + else: + final_audio = ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={total_dur}', f='lavfi') + + output_args = {'vcodec': 'libx264', 'acodec': 'aac', 'pix_fmt': 'yuv420p', 'b:v': '5M'} + try: + ffmpeg_cmd = ffmpeg.output(final_video, final_audio, output_path, **output_args).overwrite_output().compile() + self.progress_bar.setVisible(True); self.progress_bar.setValue(0); self.status_label.setText("Exporting...") + self.export_thread = QThread() + self.export_worker = ExportWorker(ffmpeg_cmd, total_dur) + self.export_worker.moveToThread(self.export_thread) + self.export_thread.started.connect(self.export_worker.run_export) + self.export_worker.finished.connect(self.on_export_finished) + self.export_worker.progress.connect(self.progress_bar.setValue) + self.export_worker.finished.connect(self.export_thread.quit) + self.export_worker.finished.connect(self.export_worker.deleteLater) + self.export_thread.finished.connect(self.export_thread.deleteLater) + self.export_thread.finished.connect(self.on_thread_finished_cleanup) + self.export_thread.start() + except ffmpeg.Error as e: + self.status_label.setText(f"FFmpeg error: {e.stderr}") + print(e.stderr) + + def on_export_finished(self, message): + self.status_label.setText(message) + self.progress_bar.setVisible(False) + + def on_thread_finished_cleanup(self): + self.export_thread = None + self.export_worker = None + + def add_dock_widget(self, plugin_instance, widget, title, area=Qt.DockWidgetArea.RightDockWidgetArea, show_on_creation=True): + widget_key = f"plugin_{plugin_instance.name}_{title}".replace(' ', '_').lower() + + dock = QDockWidget(title, self) + 
dock.setWidget(widget) + self.addDockWidget(area, dock) + + visibility_settings = self.settings.get("window_visibility", {}) + initial_visibility = visibility_settings.get(widget_key, show_on_creation) + + dock.setVisible(initial_visibility) + + action = QAction(title, self, checkable=True) + action.toggled.connect(lambda checked, k=widget_key: self.toggle_widget_visibility(k, checked)) + dock.visibilityChanged.connect(action.setChecked) + + action.setChecked(dock.isVisible()) + + self.windows_menu.addAction(action) + + self.managed_widgets[widget_key] = { + 'widget': dock, + 'name': title, + 'action': action, + 'plugin': plugin_instance.name + } + + return dock + + def update_plugin_ui_visibility(self, plugin_name, is_enabled): + for key, data in self.managed_widgets.items(): + if data.get('plugin') == plugin_name: + data['action'].setVisible(is_enabled) + if not is_enabled: + data['widget'].hide() + + def toggle_plugin(self, name, checked): + if checked: + self.plugin_manager.enable_plugin(name) + else: + self.plugin_manager.disable_plugin(name) + self._save_settings() + + def toggle_plugin_action(self, name, checked): + if name in self.plugin_menu_actions: + action = self.plugin_menu_actions[name] + action.blockSignals(True) + action.setChecked(checked) + action.blockSignals(False) + + def open_manage_plugins_dialog(self): + dialog = ManagePluginsDialog(self.plugin_manager, self) + dialog.app = self + dialog.exec() + + def closeEvent(self, event): + self.is_shutting_down = True + self._save_settings() + self._stop_playback_stream() + if self.export_thread and self.export_thread.isRunning(): + self.export_thread.quit() + self.export_thread.wait() + event.accept() + +if __name__ == '__main__': + app = QApplication(sys.argv) + project_to_load_on_startup = None + if len(sys.argv) > 1: + path = sys.argv[1] + if os.path.exists(path) and path.lower().endswith('.json'): + project_to_load_on_startup = path + print(f"Loading project: {path}") + window = MainWindow(project_to_load=project_to_load_on_startup) + window.show() + sys.exit(app.exec()) \ No newline at end of file diff --git a/videoeditor/plugins.py b/videoeditor/plugins.py new file mode 100644 index 000000000..6e50f0963 --- /dev/null +++ b/videoeditor/plugins.py @@ -0,0 +1,266 @@ +import os +import sys +import importlib.util +import subprocess +import shutil +import git +import json +from PyQt6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QListWidget, QListWidgetItem, + QPushButton, QLabel, QLineEdit, QMessageBox, QProgressBar, + QDialogButtonBox, QWidget, QCheckBox) +from PyQt6.QtCore import Qt, QObject, pyqtSignal, QThread + +class VideoEditorPlugin: + """ + Base class for all plugins. + Plugins should inherit from this class and be located in 'plugins/plugin_name/main.py'. + The main class in main.py must be named 'Plugin'. + """ + def __init__(self, app_instance): + self.app = app_instance + self.name = "Unnamed Plugin" + self.description = "No description provided." 
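+
+        # Illustrative sketch of the contract described in the class docstring
+        # (hypothetical plugin name): a file at plugins/my_plugin/main.py
+        # would define a subclass named 'Plugin', e.g.
+        #
+        #   class Plugin(VideoEditorPlugin):
+        #       def initialize(self):
+        #           self.name = "My Plugin"
+        #           self.description = "Example plugin that does nothing yet."
+        #
+        # PluginManager (below) imports that module, instantiates 'Plugin',
+        # and calls initialize(), enable() and disable() at the right times.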
+ + def initialize(self): + """Called once when the plugin is loaded by the PluginManager.""" + pass + + def enable(self): + """Called when the plugin is enabled by the user (e.g., checking the box in the menu).""" + pass + + def disable(self): + """Called when the plugin is disabled by the user.""" + pass + +class PluginManager: + """Manages the discovery, loading, and lifecycle of plugins.""" + def __init__(self, main_app): + self.app = main_app + self.plugins_dir = "plugins" + self.plugins = {} # { 'plugin_name': {'instance': plugin_instance, 'enabled': False, 'module_path': path} } + if not os.path.exists(self.plugins_dir): + os.makedirs(self.plugins_dir) + + def discover_and_load_plugins(self): + """Scans the plugins directory, loads valid plugins, and calls their initialize method.""" + for plugin_name in os.listdir(self.plugins_dir): + plugin_path = os.path.join(self.plugins_dir, plugin_name) + main_py_path = os.path.join(plugin_path, 'main.py') + if os.path.isdir(plugin_path) and os.path.exists(main_py_path): + try: + spec = importlib.util.spec_from_file_location(f"plugins.{plugin_name}.main", main_py_path) + module = importlib.util.module_from_spec(spec) + + sys.path.insert(0, plugin_path) + spec.loader.exec_module(module) + sys.path.pop(0) + + if hasattr(module, 'Plugin'): + plugin_class = getattr(module, 'Plugin') + instance = plugin_class(self.app) + instance.initialize() + self.plugins[instance.name] = { + 'instance': instance, + 'enabled': False, + 'module_path': plugin_path + } + print(f"Discovered and loaded plugin: {instance.name}") + else: + print(f"Warning: {main_py_path} does not have a 'Plugin' class.") + except Exception as e: + print(f"Error loading plugin {plugin_name}: {e}") + + def load_enabled_plugins_from_settings(self, enabled_plugins_list): + """Enables plugins based on the loaded settings.""" + for name in enabled_plugins_list: + if name in self.plugins: + self.enable_plugin(name) + + def get_enabled_plugin_names(self): + """Returns a list of names of all enabled plugins.""" + return [name for name, data in self.plugins.items() if data['enabled']] + + def enable_plugin(self, name): + """Enables a specific plugin by name and updates the UI.""" + if name in self.plugins and not self.plugins[name]['enabled']: + self.plugins[name]['instance'].enable() + self.plugins[name]['enabled'] = True + # Notify the app to update the menu's checkmark + self.app.toggle_plugin_action(name, True) + self.app.update_plugin_ui_visibility(name, True) + + def disable_plugin(self, name): + """Disables a specific plugin by name and updates the UI.""" + if name in self.plugins and self.plugins[name]['enabled']: + self.plugins[name]['instance'].disable() + self.plugins[name]['enabled'] = False + # Notify the app to update the menu's checkmark + self.app.toggle_plugin_action(name, False) + self.app.update_plugin_ui_visibility(name, False) + + def uninstall_plugin(self, name): + """Uninstalls (deletes) a plugin by name.""" + if name in self.plugins: + path = self.plugins[name]['module_path'] + if self.plugins[name]['enabled']: + self.disable_plugin(name) + + try: + shutil.rmtree(path) + del self.plugins[name] + return True + except OSError as e: + print(f"Error removing plugin directory {path}: {e}") + return False + return False + + +class InstallWorker(QObject): + finished = pyqtSignal(str, bool) + + def __init__(self, url, target_dir): + super().__init__() + self.url = url + self.target_dir = target_dir + + def run(self): + try: + repo_name = self.url.split('/')[-1].replace('.git', '') 
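+            # The steps below clone the repository into the plugins directory
+            # and, if the plugin ships a requirements.txt, install it with pip
+            # before emitting the result back to the dialog.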
+ clone_path = os.path.join(self.target_dir, repo_name) + if os.path.exists(clone_path): + self.finished.emit(f"Directory '{repo_name}' already exists.", False) + return + + git.Repo.clone_from(self.url, clone_path) + + req_path = os.path.join(clone_path, 'requirements.txt') + if os.path.exists(req_path): + print("Installing plugin requirements...") + subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", req_path]) + + self.finished.emit(f"Plugin '{repo_name}' installed successfully. Please restart the application.", True) + except Exception as e: + self.finished.emit(f"Installation failed: {e}", False) + + +class ManagePluginsDialog(QDialog): + def __init__(self, plugin_manager, parent=None): + super().__init__(parent) + self.plugin_manager = plugin_manager + self.setWindowTitle("Manage Plugins") + self.setMinimumSize(500, 400) + + self.plugin_checkboxes = {} + + layout = QVBoxLayout(self) + + layout.addWidget(QLabel("Installed Plugins:")) + self.list_widget = QListWidget() + layout.addWidget(self.list_widget) + + install_layout = QVBoxLayout() + install_layout.addWidget(QLabel("Install new plugin from GitHub URL:")) + self.url_input = QLineEdit() + self.url_input.setPlaceholderText("e.g., https://github.com/user/repo.git") + self.install_btn = QPushButton("Install") + install_layout.addWidget(self.url_input) + install_layout.addWidget(self.install_btn) + layout.addLayout(install_layout) + + self.status_label = QLabel("Ready.") + layout.addWidget(self.status_label) + + # Dialog buttons + self.button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Save | QDialogButtonBox.StandardButton.Cancel) + layout.addWidget(self.button_box) + + # Connections + self.install_btn.clicked.connect(self.install_plugin) + self.button_box.accepted.connect(self.save_changes) + self.button_box.rejected.connect(self.reject) + + self.populate_list() + + def populate_list(self): + self.list_widget.clear() + self.plugin_checkboxes.clear() + + for name, data in sorted(self.plugin_manager.plugins.items()): + item_widget = QWidget() + item_layout = QHBoxLayout(item_widget) + item_layout.setContentsMargins(5, 5, 5, 5) + + checkbox = QCheckBox(name) + checkbox.setChecked(data['enabled']) + checkbox.setToolTip(data['instance'].description) + self.plugin_checkboxes[name] = checkbox + item_layout.addWidget(checkbox, 1) + + uninstall_btn = QPushButton("Uninstall") + uninstall_btn.setFixedWidth(80) + uninstall_btn.clicked.connect(lambda _, n=name: self.handle_uninstall(n)) + item_layout.addWidget(uninstall_btn) + + item_widget.setLayout(item_layout) + + list_item = QListWidgetItem(self.list_widget) + list_item.setSizeHint(item_widget.sizeHint()) + self.list_widget.addItem(list_item) + self.list_widget.setItemWidget(list_item, item_widget) + + def save_changes(self): + for name, checkbox in self.plugin_checkboxes.items(): + if name not in self.plugin_manager.plugins: + continue + + is_checked = checkbox.isChecked() + is_currently_enabled = self.plugin_manager.plugins[name]['enabled'] + + if is_checked and not is_currently_enabled: + self.plugin_manager.enable_plugin(name) + elif not is_checked and is_currently_enabled: + self.plugin_manager.disable_plugin(name) + + self.plugin_manager.app._save_settings() + self.accept() + + def handle_uninstall(self, name): + reply = QMessageBox.question(self, "Confirm Uninstall", + f"Are you sure you want to permanently delete the plugin '{name}'?", + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No) + + if reply == 
QMessageBox.StandardButton.Yes: + if self.plugin_manager.uninstall_plugin(name): + self.status_label.setText(f"Uninstalled '{name}'. Restart the application to fully remove it from the menu.") + self.populate_list() + else: + self.status_label.setText(f"Failed to uninstall '{name}'.") + + def install_plugin(self): + url = self.url_input.text().strip() + if not url.endswith(".git"): + QMessageBox.warning(self, "Invalid URL", "Please provide a valid git repository URL (ending in .git).") + return + + self.install_btn.setEnabled(False) + self.status_label.setText(f"Cloning from {url}...") + + self.install_thread = QThread() + self.install_worker = InstallWorker(url, self.plugin_manager.plugins_dir) + self.install_worker.moveToThread(self.install_thread) + + self.install_thread.started.connect(self.install_worker.run) + self.install_worker.finished.connect(self.on_install_finished) + self.install_worker.finished.connect(self.install_thread.quit) + self.install_worker.finished.connect(self.install_worker.deleteLater) + self.install_thread.finished.connect(self.install_thread.deleteLater) + + self.install_thread.start() + + def on_install_finished(self, message, success): + self.status_label.setText(message) + self.install_btn.setEnabled(True) + if success: + self.url_input.clear() \ No newline at end of file diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py new file mode 100644 index 000000000..873cd5e56 --- /dev/null +++ b/videoeditor/plugins/ai_frame_joiner/main.py @@ -0,0 +1,308 @@ +import sys +import os +import tempfile +import shutil +import requests +from pathlib import Path + +from PyQt6.QtWidgets import ( + QWidget, QVBoxLayout, QHBoxLayout, QFormLayout, + QLineEdit, QPushButton, QLabel, QMessageBox, QCheckBox +) +from PyQt6.QtGui import QImage, QPixmap +from PyQt6.QtCore import QTimer, pyqtSignal, Qt, QSize + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from plugins import VideoEditorPlugin + + +API_BASE_URL = "http://127.0.0.1:5100" + +class WgpClientWidget(QWidget): + generation_complete = pyqtSignal(str) + + def __init__(self): + super().__init__() + + self.last_known_output = None + + layout = QVBoxLayout(self) + form_layout = QFormLayout() + + self.model_input = QLineEdit() + self.model_input.setPlaceholderText("Optional (default: i2v_2_2)") + + previews_layout = QHBoxLayout() + start_preview_layout = QVBoxLayout() + start_preview_layout.addWidget(QLabel("Start Frame")) + self.start_frame_preview = QLabel("N/A") + self.start_frame_preview.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.start_frame_preview.setFixedSize(160, 90) + self.start_frame_preview.setStyleSheet("background-color: #222; color: #888;") + start_preview_layout.addWidget(self.start_frame_preview) + previews_layout.addLayout(start_preview_layout) + + end_preview_layout = QVBoxLayout() + end_preview_layout.addWidget(QLabel("End Frame")) + self.end_frame_preview = QLabel("N/A") + self.end_frame_preview.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.end_frame_preview.setFixedSize(160, 90) + self.end_frame_preview.setStyleSheet("background-color: #222; color: #888;") + end_preview_layout.addWidget(self.end_frame_preview) + previews_layout.addLayout(end_preview_layout) + + self.start_frame_input = QLineEdit() + self.end_frame_input = QLineEdit() + self.duration_input = QLineEdit() + + self.autostart_checkbox = QCheckBox("Start Generation Immediately") + self.autostart_checkbox.setChecked(True) + + self.generate_button = QPushButton("Generate") + + 
self.status_label = QLabel("Status: Idle") + self.output_label = QLabel("Latest Output File: None") + self.output_label.setWordWrap(True) + + layout.addLayout(previews_layout) + form_layout.addRow("Model Type:", self.model_input) + form_layout.addRow(self.autostart_checkbox) + form_layout.addRow(self.generate_button) + + layout.addLayout(form_layout) + layout.addWidget(self.status_label) + layout.addWidget(self.output_label) + + self.generate_button.clicked.connect(self.generate) + + self.poll_timer = QTimer(self) + self.poll_timer.setInterval(3000) + self.poll_timer.timeout.connect(self.poll_for_output) + + def set_previews(self, start_pixmap, end_pixmap): + if start_pixmap: + self.start_frame_preview.setPixmap(start_pixmap) + else: + self.start_frame_preview.setText("N/A") + self.start_frame_preview.setPixmap(QPixmap()) + + if end_pixmap: + self.end_frame_preview.setPixmap(end_pixmap) + else: + self.end_frame_preview.setText("N/A") + self.end_frame_preview.setPixmap(QPixmap()) + + def start_polling(self): + self.poll_timer.start() + self.check_server_status() + + def stop_polling(self): + self.poll_timer.stop() + + def handle_api_error(self, response, action="performing action"): + try: + error_msg = response.json().get("error", "Unknown error") + except requests.exceptions.JSONDecodeError: + error_msg = response.text + self.status_label.setText(f"Status: Error {action}: {error_msg}") + QMessageBox.warning(self, "API Error", f"Failed while {action}.\n\nServer response:\n{error_msg}") + + def check_server_status(self): + try: + requests.get(f"{API_BASE_URL}/api/latest_output", timeout=1) + self.status_label.setText("Status: Connected to WanGP server.") + except requests.exceptions.ConnectionError: + self.status_label.setText("Status: Error - Cannot connect to WanGP server.") + QMessageBox.critical(self, "Connection Error", f"Could not connect to the WanGP API at {API_BASE_URL}.\n\nPlease ensure wgptool.py is running.") + + + def generate(self): + payload = {} + + model_type = self.model_input.text().strip() + if model_type: + payload['model_type'] = model_type + + if self.start_frame_input.text(): + payload['start_frame'] = self.start_frame_input.text() + if self.end_frame_input.text(): + payload['end_frame'] = self.end_frame_input.text() + + if self.duration_input.text(): + payload['duration_sec'] = self.duration_input.text() + + payload['start_generation'] = self.autostart_checkbox.isChecked() + + self.status_label.setText("Status: Sending parameters...") + try: + response = requests.post(f"{API_BASE_URL}/api/generate", json=payload) + if response.status_code == 200: + if payload['start_generation']: + self.status_label.setText("Status: Parameters set. Generation sent. Polling...") + else: + self.status_label.setText("Status: Parameters set. Waiting for manual start.") + else: + self.handle_api_error(response, "setting parameters") + except requests.exceptions.RequestException as e: + self.status_label.setText(f"Status: Connection error: {e}") + + def poll_for_output(self): + try: + response = requests.get(f"{API_BASE_URL}/api/latest_output") + if response.status_code == 200: + data = response.json() + latest_path = data.get("latest_output_path") + + if latest_path and latest_path != self.last_known_output: + self.last_known_output = latest_path + self.output_label.setText(f"Latest Output File:\n{latest_path}") + self.status_label.setText("Status: New output received! 
Inserting clip...") + self.generation_complete.emit(latest_path) + else: + if "Error" not in self.status_label.text() and "waiting" not in self.status_label.text().lower(): + self.status_label.setText("Status: Polling for output...") + except requests.exceptions.RequestException: + if "Error" not in self.status_label.text(): + self.status_label.setText("Status: Polling... (Connection issue)") + +class Plugin(VideoEditorPlugin): + def initialize(self): + self.name = "AI Frame Joiner" + self.description = "Uses a local AI server to generate a video between two frames." + self.client_widget = WgpClientWidget() + self.dock_widget = None + self.active_region = None + self.temp_dir = None + self.client_widget.generation_complete.connect(self.insert_generated_clip) + + def enable(self): + if not self.dock_widget: + self.dock_widget = self.app.add_dock_widget(self, self.client_widget, "AI Frame Joiner", show_on_creation=False) + + self.dock_widget.hide() + + self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu) + self.client_widget.start_polling() + + def disable(self): + try: + self.app.timeline_widget.context_menu_requested.disconnect(self.on_timeline_context_menu) + except TypeError: + pass + self._cleanup_temp_dir() + self.client_widget.stop_polling() + + def _cleanup_temp_dir(self): + if self.temp_dir and os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + self.temp_dir = None + + def _reset_state(self): + self.active_region = None + self._cleanup_temp_dir() + self.client_widget.status_label.setText("Status: Idle") + self.client_widget.output_label.setText("Latest Output File: None") + self.client_widget.set_previews(None, None) + + def on_timeline_context_menu(self, menu, event): + region = self.app.timeline_widget.get_region_at_pos(event.pos()) + if region: + menu.addSeparator() + action = menu.addAction("Join Frames With AI") + action.triggered.connect(lambda: self.setup_generator_for_region(region)) + + def setup_generator_for_region(self, region): + self._reset_state() + self.active_region = region + start_sec, end_sec = region + + start_data, w, h = self.app.get_frame_data_at_time(start_sec) + end_data, _, _ = self.app.get_frame_data_at_time(end_sec) + + if not start_data or not end_data: + QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames for the selected region.") + return + + preview_size = QSize(160, 90) + start_pixmap, end_pixmap = None, None + + try: + start_img = QImage(start_data, w, h, QImage.Format.Format_RGB888) + start_pixmap = QPixmap.fromImage(start_img).scaled(preview_size, Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) + + end_img = QImage(end_data, w, h, QImage.Format.Format_RGB888) + end_pixmap = QPixmap.fromImage(end_img).scaled(preview_size, Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) + except Exception as e: + QMessageBox.critical(self.app, "Image Error", f"Could not create preview images: {e}") + + self.client_widget.set_previews(start_pixmap, end_pixmap) + + try: + self.temp_dir = tempfile.mkdtemp(prefix="ai_joiner_") + start_img_path = os.path.join(self.temp_dir, "start_frame.png") + end_img_path = os.path.join(self.temp_dir, "end_frame.png") + + QImage(start_data, w, h, QImage.Format.Format_RGB888).save(start_img_path) + QImage(end_data, w, h, QImage.Format.Format_RGB888).save(end_img_path) + except Exception as e: + QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame images: {e}") + 
self._cleanup_temp_dir() + return + + duration_sec = end_sec - start_sec + self.client_widget.duration_input.setText(str(duration_sec)) + + self.client_widget.start_frame_input.setText(start_img_path) + self.client_widget.end_frame_input.setText(end_img_path) + self.client_widget.status_label.setText(f"Status: Ready for region {start_sec:.2f}s - {end_sec:.2f}s") + + self.dock_widget.show() + self.dock_widget.raise_() + + def insert_generated_clip(self, video_path): + if not self.active_region: + self.client_widget.status_label.setText("Status: Error - No active region to insert into.") + return + + if not os.path.exists(video_path): + self.client_widget.status_label.setText(f"Status: Error - Output file not found: {video_path}") + return + + start_sec, end_sec = self.active_region + duration = end_sec - start_sec + self.app.status_label.setText(f"Inserting AI clip: {os.path.basename(video_path)}") + + try: + for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec) + for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec) + + clips_to_remove = [ + c for c in self.app.timeline.clips + if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec + ] + for clip in clips_to_remove: + if clip in self.app.timeline.clips: + self.app.timeline.clips.remove(clip) + + self.app._add_clip_to_timeline( + source_path=video_path, + timeline_start_sec=start_sec, + clip_start_sec=0, + duration_sec=duration, + video_track_index=1, + audio_track_index=None + ) + + self.app.status_label.setText("AI clip inserted successfully.") + self.client_widget.status_label.setText("Status: Success! Clip inserted.") + self.app.prune_empty_tracks() + + except Exception as e: + error_message = f"Error during clip insertion: {e}" + self.app.status_label.setText(error_message) + self.client_widget.status_label.setText("Status: Failed to insert clip.") + print(error_message) + + finally: + self._cleanup_temp_dir() + self.active_region = None \ No newline at end of file From d7f0da36da3740b1e5769eab257154d9bde16189 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 10 Oct 2025 04:14:11 +1100 Subject: [PATCH 086/155] adjust run command --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 2b409148a..fb88c7acd 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,8 @@ python main.py **Run the video editor:** ```bash -python videoeditor\main.py +cd videoeditor +python main.py ``` From 7e7007e29c7cf90246cad18c45912a06ff5e54c0 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 10 Oct 2025 04:15:20 +1100 Subject: [PATCH 087/155] forgot to include server --- main.py | 2006 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 2006 insertions(+) create mode 100644 main.py diff --git a/main.py b/main.py new file mode 100644 index 000000000..06292b315 --- /dev/null +++ b/main.py @@ -0,0 +1,2006 @@ +import sys +import os +import threading +import time +import json +import re +from unittest.mock import MagicMock + +# --- Start of Gradio Hijacking --- +# This block creates a mock Gradio module. When wgp.py is imported, +# all calls to `gr.*` will be intercepted by these mock objects, +# preventing any UI from being built and allowing us to use the +# backend logic directly. 
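+#
+# Illustrative example (not part of wgp.py itself): once this block has run,
+# code in wgp.py such as
+#
+#     import gradio as gr
+#     btn = gr.Button("Generate")
+#
+# resolves `gr` to the MockGradioModule instance installed in sys.modules and
+# `btn` to a MockGradioComponent, so no real Gradio UI objects are created.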
+ +class MockGradioComponent(MagicMock): + """A smarter mock that captures constructor arguments.""" + def __init__(self, *args, **kwargs): + super().__init__(name=f"gr.{kwargs.get('elem_id', 'component')}") + # Store the kwargs so we can inspect them later + self.kwargs = kwargs + self.value = kwargs.get('value') + self.choices = kwargs.get('choices') + + # Mock chaining methods like .click(), .then(), etc. + for method in ['then', 'change', 'click', 'input', 'select', 'upload', 'mount', 'launch', 'on', 'release']: + setattr(self, method, lambda *a, **kw: self) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + +class MockGradioError(Exception): + pass + +class MockGradioModule: # No longer inherits from MagicMock + def __getattr__(self, name): + if name == 'Error': + return lambda *args, **kwargs: MockGradioError(*args) + + # Nullify functions that show pop-ups + if name in ['Info', 'Warning']: + return lambda *args, **kwargs: print(f"Intercepted gr.{name}:", *args) + + return lambda *args, **kwargs: MockGradioComponent(*args, **kwargs) + +sys.modules['gradio'] = MockGradioModule() +sys.modules['gradio.gallery'] = MockGradioModule() # Also mock any submodules used +sys.modules['shared.gradio.gallery'] = MockGradioModule() +# --- End of Gradio Hijacking --- + +# Global placeholder for the wgp module. Will be None if import fails. +wgp = None + +# Load configuration and attempt to import wgp +MAIN_CONFIG_FILE = 'main_config.json' +main_config = {} + +def load_main_config(): + global main_config + try: + with open(MAIN_CONFIG_FILE, 'r') as f: + main_config = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + current_folder = os.path.dirname(os.path.abspath(sys.argv[0])) + main_config = {'wgp_path': current_folder} + +def save_main_config(): + global main_config + try: + with open(MAIN_CONFIG_FILE, 'w') as f: + json.dump(main_config, f, indent=4) + except Exception as e: + print(f"Error saving main_config.json: {e}") + +def setup_and_import_wgp(): + """Adds configured path to sys.path and tries to import wgp.""" + global wgp + wgp_path = main_config.get('wgp_path') + if wgp_path and os.path.isdir(wgp_path) and os.path.isfile(os.path.join(wgp_path, 'wgp.py')): + if wgp_path not in sys.path: + sys.path.insert(0, wgp_path) + try: + import wgp as wgp_module + wgp = wgp_module + return True + except ImportError as e: + print(f"Error: Failed to import wgp.py from the configured path '{wgp_path}'.\nDetails: {e}") + wgp = None + return False + else: + print("Info: WAN2GP folder path not set or invalid. 
Please configure it in File > Settings.") + wgp = None + return False + +# Load config and attempt import at script start +load_main_config() +wgp_loaded = setup_and_import_wgp() + +# Now import PyQt6 components +from PyQt6.QtWidgets import ( + QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QTabWidget, + QPushButton, QLabel, QLineEdit, QTextEdit, QSlider, QCheckBox, QComboBox, + QFileDialog, QGroupBox, QFormLayout, QTableWidget, QTableWidgetItem, + QHeaderView, QProgressBar, QScrollArea, QListWidget, QListWidgetItem, + QMessageBox, QRadioButton, QDialog +) +from PyQt6.QtCore import Qt, QThread, QObject, pyqtSignal, QTimer, pyqtSlot +from PyQt6.QtGui import QPixmap, QDropEvent, QAction +from PIL.ImageQt import ImageQt + + +class SettingsDialog(QDialog): + """Dialog to configure application settings like the wgp.py path.""" + def __init__(self, config, parent=None): + super().__init__(parent) + self.config = config + self.setWindowTitle("Settings") + self.setMinimumWidth(500) + + layout = QVBoxLayout(self) + form_layout = QFormLayout() + + path_layout = QHBoxLayout() + self.wgp_path_edit = QLineEdit(self.config.get('wgp_path', '')) + self.wgp_path_edit.setPlaceholderText("Path to the folder containing wgp.py") + browse_btn = QPushButton("Browse...") + browse_btn.clicked.connect(self.browse_for_wgp_folder) + path_layout.addWidget(self.wgp_path_edit) + path_layout.addWidget(browse_btn) + + form_layout.addRow("WAN2GP Folder Path:", path_layout) + layout.addLayout(form_layout) + + button_layout = QHBoxLayout() + save_btn = QPushButton("Save") + save_restart_btn = QPushButton("Save and Restart") + cancel_btn = QPushButton("Cancel") + button_layout.addStretch() + button_layout.addWidget(save_btn) + button_layout.addWidget(save_restart_btn) + button_layout.addWidget(cancel_btn) + layout.addLayout(button_layout) + + save_btn.clicked.connect(self.save_and_close) + save_restart_btn.clicked.connect(self.save_and_restart) + cancel_btn.clicked.connect(self.reject) + + def browse_for_wgp_folder(self): + directory = QFileDialog.getExistingDirectory(self, "Select WAN2GP Folder") + if directory: + self.wgp_path_edit.setText(directory) + + def validate_path(self, path): + if not path or not os.path.isdir(path) or not os.path.isfile(os.path.join(path, 'wgp.py')): + QMessageBox.warning(self, "Invalid Path", "The selected folder does not contain 'wgp.py'. Please select the correct WAN2GP folder.") + return False + return True + + def _save_config(self): + path = self.wgp_path_edit.text() + if not self.validate_path(path): + return False + self.config['wgp_path'] = path + save_main_config() + return True + + def save_and_close(self): + if self._save_config(): + QMessageBox.information(self, "Settings Saved", "Settings have been saved. 
Please restart the application for changes to take effect.") + self.accept() + + def save_and_restart(self): + if self._save_config(): + self.parent().close() # Close the main window before restarting + os.execv(sys.executable, [sys.executable] + sys.argv) + + +class QueueTableWidget(QTableWidget): + """A QTableWidget with drag-and-drop reordering for rows.""" + rowsMoved = pyqtSignal(int, int) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setDragEnabled(True) + self.setAcceptDrops(True) + self.setDropIndicatorShown(True) + self.setDragDropMode(self.DragDropMode.InternalMove) + self.setSelectionBehavior(self.SelectionBehavior.SelectRows) + self.setSelectionMode(self.SelectionMode.SingleSelection) + + def dropEvent(self, event: QDropEvent): + if event.source() == self and event.dropAction() == Qt.DropAction.MoveAction: + source_row = self.currentRow() + target_item = self.itemAt(event.position().toPoint()) + dest_row = target_item.row() if target_item else self.rowCount() + + # Adjust destination row if moving down + if source_row < dest_row: + dest_row -=1 + + if source_row != dest_row: + self.rowsMoved.emit(source_row, dest_row) + + event.acceptProposedAction() + else: + super().dropEvent(event) + + +class Worker(QObject): + progress = pyqtSignal(list) + status = pyqtSignal(str) + preview = pyqtSignal(object) + output = pyqtSignal() + finished = pyqtSignal() + error = pyqtSignal(str) + + def __init__(self, state): + super().__init__() + self.state = state + self._is_running = True + self._last_progress_phase = None + self._last_preview = None + + def send_cmd(self, cmd, data=None): + if not self._is_running: + return + + def run(self): + def generation_target(): + try: + for _ in wgp.process_tasks(self.state): + if self._is_running: + self.output.emit() + else: + break + except Exception as e: + import traceback + print("Error in generation thread:") + traceback.print_exc() + if "gradio.Error" in str(type(e)): + self.error.emit(str(e)) + else: + self.error.emit(f"An unexpected error occurred: {e}") + finally: + self._is_running = False + + gen_thread = threading.Thread(target=generation_target, daemon=True) + gen_thread.start() + + while self._is_running: + gen = self.state.get('gen', {}) + + current_phase = gen.get("progress_phase") + if current_phase and current_phase != self._last_progress_phase: + self._last_progress_phase = current_phase + + phase_name, step = current_phase + total_steps = gen.get("num_inference_steps", 1) + high_level_status = gen.get("progress_status", "") + + status_msg = wgp.merge_status_context(high_level_status, phase_name) + + progress_args = [(step, total_steps), status_msg] + self.progress.emit(progress_args) + + preview_img = gen.get('preview') + if preview_img is not None and preview_img is not self._last_preview: + self._last_preview = preview_img + self.preview.emit(preview_img) + gen['preview'] = None + + time.sleep(0.1) + + gen_thread.join() + self.finished.emit() + +class ApiBridge(QObject): + """ + An object that lives in the main thread to receive signals from the API thread + and forward them to the MainWindow's slots. 
+ """ + # --- CHANGE: Signal now includes model_type and duration_sec --- + generateSignal = pyqtSignal(object, object, object, object, bool) + +class MainWindow(QMainWindow): + def __init__(self): + super().__init__() + + self.api_bridge = ApiBridge() + + self.widgets = {} + self.state = {} + self.worker = None + self.thread = None + self.lora_map = {} + self.full_resolution_choices = [] + self.latest_output_path = None + + self.setup_menu() + + if not wgp: + self.setWindowTitle("WanGP - Setup Required") + self.setGeometry(100, 100, 600, 200) + self.setup_placeholder_ui() + QTimer.singleShot(100, lambda: self.show_settings_dialog(first_time=True)) + else: + self.setWindowTitle(f"WanGP v{wgp.WanGP_version} - Qt Interface") + self.setGeometry(100, 100, 1400, 950) + self.setup_full_ui() + self.apply_initial_config() + self.connect_signals() + self.init_wgp_state() + + # --- CHANGE: Removed setModelSignal as it's no longer needed --- + self.api_bridge.generateSignal.connect(self._api_generate) + + @pyqtSlot(str) + def _api_set_model(self, model_type): + """This slot is executed in the main GUI thread.""" + if not model_type or not wgp: return + + # 1. Check if the model is valid by looking it up in the master model definition dictionary. + if model_type not in wgp.models_def: + print(f"API Error: Model type '{model_type}' is not a valid model.") + return + + # 2. Check if already selected to avoid unnecessary UI refreshes. + if self.state.get('model_type') == model_type: + print(f"API: Model is already set to {model_type}.") + return + + # 3. Redraw all model dropdowns to ensure the correct hierarchy is displayed + # and the target model is selected. This function handles finding the + # correct Family and Base model for the given finetune model_type. + self.update_model_dropdowns(model_type) + + # 4. Manually trigger the logic that normally runs when the user selects a model. + # This is necessary because update_model_dropdowns blocks signals. + self._on_model_changed() + + # 5. Final check to see if the model was actually set. + if self.state.get('model_type') == model_type: + print(f"API: Successfully set model to {model_type}.") + else: + # This could happen if update_model_dropdowns silently fails to find the model. + print(f"API Error: Failed to set model to '{model_type}'. The model might be hidden by your current configuration.") + + # --- CHANGE: Slot now accepts model_type and duration_sec, and calculates frame count --- + @pyqtSlot(object, object, object, object, bool) + def _api_generate(self, start_frame, end_frame, duration_sec, model_type, start_generation): + """This slot is executed in the main GUI thread.""" + # 1. Set model if a new one is provided + if model_type: + self._api_set_model(model_type) + + # 2. Set frame inputs + if start_frame: + self.widgets['mode_s'].setChecked(True) + self.widgets['image_start'].setText(start_frame) + + if end_frame: + self.widgets['image_end_checkbox'].setChecked(True) + self.widgets['image_end'].setText(end_frame) + + # 3. Calculate video length in frames based on duration and model FPS + if duration_sec is not None: + try: + duration = float(duration_sec) + + # Get base FPS by parsing the "Force FPS" dropdown's default text (e.g., "Model Default (16 fps)") + base_fps = 16 # Fallback + fps_text = self.widgets['force_fps'].itemText(0) + match = re.search(r'\((\d+)\s*fps\)', fps_text) + if match: + base_fps = int(match.group(1)) + + # Temporal upsampling creates more frames in post-processing, so we must account for it here. 
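+                # Worked example of the calculation below: duration_sec = 2.5, base_fps = 16 and "rife2" selected gives int(2.5 * 16 * 2.0) = 80 frames.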
+ upsample_setting = self.widgets['temporal_upsampling'].currentData() + multiplier = 1.0 + if upsample_setting == "rife2": + multiplier = 2.0 + elif upsample_setting == "rife4": + multiplier = 4.0 + + # The number of frames the model needs to generate + video_length_frames = int(duration * base_fps * multiplier) + + self.widgets['video_length'].setValue(video_length_frames) + print(f"API: Calculated video length: {video_length_frames} frames for {duration:.2f}s @ {base_fps*multiplier:.0f} effective FPS.") + + except (ValueError, TypeError) as e: + print(f"API Error: Invalid duration_sec '{duration_sec}': {e}") + + # 4. Conditionally start generation + if start_generation: + self.generate_btn.click() + print("API: Generation started.") + else: + print("API: Parameters set without starting generation.") + + + def setup_menu(self): + menu_bar = self.menuBar() + file_menu = menu_bar.addMenu("&File") + + settings_action = QAction("&Settings", self) + settings_action.triggered.connect(self.show_settings_dialog) + file_menu.addAction(settings_action) + + def show_settings_dialog(self, first_time=False): + dialog = SettingsDialog(main_config, self) + if first_time: + dialog.setWindowTitle("Initial Setup: Configure WAN2GP Path") + dialog.exec() + + def setup_placeholder_ui(self): + main_widget = QWidget() + self.setCentralWidget(main_widget) + layout = QVBoxLayout(main_widget) + + placeholder_label = QLabel( + "

<h3>Welcome to WanGP</h3>" + "<p>The path to your WAN2GP installation (the folder containing wgp.py) is not set.</p>" + "<p>Please go to File > Settings to configure the path.</p>
" + ) + placeholder_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + placeholder_label.setWordWrap(True) + layout.addWidget(placeholder_label) + + def setup_full_ui(self): + main_widget = QWidget() + self.setCentralWidget(main_widget) + main_layout = QVBoxLayout(main_widget) + + self.header_info = QLabel("Header Info") + main_layout.addWidget(self.header_info) + + self.tabs = QTabWidget() + main_layout.addWidget(self.tabs) + + self.setup_generator_tab() + self.setup_config_tab() + + def create_widget(self, widget_class, name, *args, **kwargs): + widget = widget_class(*args, **kwargs) + self.widgets[name] = widget + return widget + + def _create_slider_with_label(self, name, min_val, max_val, initial_val, scale=1.0, precision=1): + container = QWidget() + hbox = QHBoxLayout(container) + hbox.setContentsMargins(0, 0, 0, 0) + + slider = self.create_widget(QSlider, name, Qt.Orientation.Horizontal) + slider.setRange(min_val, max_val) + slider.setValue(int(initial_val * scale)) + + value_label = self.create_widget(QLabel, f"{name}_label", f"{initial_val:.{precision}f}") + value_label.setMinimumWidth(50) + + slider.valueChanged.connect( + lambda v, lbl=value_label, s=scale, p=precision: lbl.setText(f"{v/s:.{p}f}") + ) + + hbox.addWidget(slider) + hbox.addWidget(value_label) + return container + + def _create_file_input(self, name, label_text): + container = self.create_widget(QWidget, f"{name}_container") + hbox = QHBoxLayout(container) + hbox.setContentsMargins(0, 0, 0, 0) + + line_edit = self.create_widget(QLineEdit, name) + line_edit.setReadOnly(False) # Allow user to paste paths + line_edit.setPlaceholderText("No file selected or path pasted") + + button = QPushButton("Browse...") + + def open_dialog(): + # Allow selecting multiple files for reference images + if "refs" in name: + filenames, _ = QFileDialog.getOpenFileNames(self, f"Select {label_text}") + if filenames: + line_edit.setText(";".join(filenames)) + else: + filename, _ = QFileDialog.getOpenFileName(self, f"Select {label_text}") + if filename: + line_edit.setText(filename) + + button.clicked.connect(open_dialog) + + clear_button = QPushButton("X") + clear_button.setFixedWidth(30) + clear_button.clicked.connect(lambda: line_edit.clear()) + + hbox.addWidget(QLabel(f"{label_text}:")) + hbox.addWidget(line_edit, 1) + hbox.addWidget(button) + hbox.addWidget(clear_button) + return container + + def setup_generator_tab(self): + gen_tab = QWidget() + self.tabs.addTab(gen_tab, "Video Generator") + gen_layout = QHBoxLayout(gen_tab) + + left_panel = QWidget() + left_layout = QVBoxLayout(left_panel) + gen_layout.addWidget(left_panel, 1) + + right_panel = QWidget() + right_layout = QVBoxLayout(right_panel) + gen_layout.addWidget(right_panel, 1) + + # Left Panel (Inputs) + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + left_layout.addWidget(scroll_area) + + options_widget = QWidget() + scroll_area.setWidget(options_widget) + options_layout = QVBoxLayout(options_widget) + + # Model Selection + model_layout = QHBoxLayout() + self.widgets['model_family'] = QComboBox() + self.widgets['model_base_type_choice'] = QComboBox() + self.widgets['model_choice'] = QComboBox() + model_layout.addWidget(QLabel("Model:")) + model_layout.addWidget(self.widgets['model_family'], 2) + model_layout.addWidget(self.widgets['model_base_type_choice'], 3) + model_layout.addWidget(self.widgets['model_choice'], 3) + options_layout.addLayout(model_layout) + + # Prompt + options_layout.addWidget(QLabel("Prompt:")) + self.create_widget(QTextEdit, 
'prompt').setMinimumHeight(100) + options_layout.addWidget(self.widgets['prompt']) + + options_layout.addWidget(QLabel("Negative Prompt:")) + self.create_widget(QTextEdit, 'negative_prompt').setMinimumHeight(60) + options_layout.addWidget(self.widgets['negative_prompt']) + + # Basic controls + basic_group = QGroupBox("Basic Options") + basic_layout = QFormLayout(basic_group) + + res_container = QWidget() + res_hbox = QHBoxLayout(res_container) + res_hbox.setContentsMargins(0, 0, 0, 0) + res_hbox.addWidget(self.create_widget(QComboBox, 'resolution_group'), 2) + res_hbox.addWidget(self.create_widget(QComboBox, 'resolution'), 3) + basic_layout.addRow("Resolution:", res_container) + + basic_layout.addRow("Video Length:", self._create_slider_with_label('video_length', 1, 737, 81, 1.0, 0)) + basic_layout.addRow("Inference Steps:", self._create_slider_with_label('num_inference_steps', 1, 100, 30, 1.0, 0)) + basic_layout.addRow("Seed:", self.create_widget(QLineEdit, 'seed', '-1')) + options_layout.addWidget(basic_group) + + # Generation Mode and Input Options + mode_options_group = QGroupBox("Generation Mode & Input Options") + mode_options_layout = QVBoxLayout(mode_options_group) + + mode_hbox = QHBoxLayout() + mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_t', "Text Prompt Only")) + mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_s', "Start with Image")) + mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_v', "Continue Video")) + mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_l', "Continue Last Video")) + self.widgets['mode_t'].setChecked(True) + mode_options_layout.addLayout(mode_hbox) + + options_hbox = QHBoxLayout() + options_hbox.addWidget(self.create_widget(QCheckBox, 'image_end_checkbox', "Use End Image")) + options_hbox.addWidget(self.create_widget(QCheckBox, 'control_video_checkbox', "Use Control Video")) + options_hbox.addWidget(self.create_widget(QCheckBox, 'ref_image_checkbox', "Use Reference Image(s)")) + mode_options_layout.addLayout(options_hbox) + options_layout.addWidget(mode_options_group) + + # Dynamic Inputs + inputs_group = QGroupBox("Inputs") + inputs_layout = QVBoxLayout(inputs_group) + inputs_layout.addWidget(self._create_file_input('image_start', "Start Image")) + inputs_layout.addWidget(self._create_file_input('image_end', "End Image")) + inputs_layout.addWidget(self._create_file_input('video_source', "Source Video")) + inputs_layout.addWidget(self._create_file_input('video_guide', "Control Video")) + inputs_layout.addWidget(self._create_file_input('video_mask', "Video Mask")) + inputs_layout.addWidget(self._create_file_input('image_refs', "Reference Image(s)")) + denoising_row = QFormLayout() + denoising_row.addRow("Denoising Strength:", self._create_slider_with_label('denoising_strength', 0, 100, 50, 100.0, 2)) + inputs_layout.addLayout(denoising_row) + options_layout.addWidget(inputs_group) + + # Advanced controls + self.advanced_group = self.create_widget(QGroupBox, 'advanced_group', "Advanced Options") + self.advanced_group.setCheckable(True) + self.advanced_group.setChecked(False) + advanced_layout = QVBoxLayout(self.advanced_group) + + advanced_tabs = self.create_widget(QTabWidget, 'advanced_tabs') + advanced_layout.addWidget(advanced_tabs) + + self._setup_adv_tab_general(advanced_tabs) + self._setup_adv_tab_loras(advanced_tabs) + self._setup_adv_tab_speed(advanced_tabs) + self._setup_adv_tab_postproc(advanced_tabs) + self._setup_adv_tab_audio(advanced_tabs) + self._setup_adv_tab_quality(advanced_tabs) + 
self._setup_adv_tab_sliding_window(advanced_tabs) + self._setup_adv_tab_misc(advanced_tabs) + + options_layout.addWidget(self.advanced_group) + + # Right Panel (Output & Queue) + btn_layout = QHBoxLayout() + self.generate_btn = self.create_widget(QPushButton, 'generate_btn', "Generate") + self.add_to_queue_btn = self.create_widget(QPushButton, 'add_to_queue_btn', "Add to Queue") + self.generate_btn.setEnabled(True) + self.add_to_queue_btn.setEnabled(False) + btn_layout.addWidget(self.generate_btn) + btn_layout.addWidget(self.add_to_queue_btn) + right_layout.addLayout(btn_layout) + + self.status_label = self.create_widget(QLabel, 'status_label', "Idle") + right_layout.addWidget(self.status_label) + self.progress_bar = self.create_widget(QProgressBar, 'progress_bar') + right_layout.addWidget(self.progress_bar) + + preview_group = self.create_widget(QGroupBox, 'preview_group', "Preview") + preview_group.setCheckable(True) + preview_group.setStyleSheet("QGroupBox { border: 1px solid #cccccc; }") + preview_group_layout = QVBoxLayout(preview_group) + + self.preview_image = self.create_widget(QLabel, 'preview_image', "") + self.preview_image.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.preview_image.setMinimumSize(200, 200) + + preview_group_layout.addWidget(self.preview_image) + right_layout.addWidget(preview_group) + + right_layout.addWidget(QLabel("Output:")) + self.output_gallery = self.create_widget(QListWidget, 'output_gallery') + right_layout.addWidget(self.output_gallery) + + right_layout.addWidget(QLabel("Queue:")) + self.queue_table = self.create_widget(QueueTableWidget, 'queue_table') + right_layout.addWidget(self.queue_table) + + queue_btn_layout = QHBoxLayout() + self.remove_queue_btn = self.create_widget(QPushButton, 'remove_queue_btn', "Remove Selected") + self.clear_queue_btn = self.create_widget(QPushButton, 'clear_queue_btn', "Clear Queue") + self.abort_btn = self.create_widget(QPushButton, 'abort_btn', "Abort") + queue_btn_layout.addWidget(self.remove_queue_btn) + queue_btn_layout.addWidget(self.clear_queue_btn) + queue_btn_layout.addWidget(self.abort_btn) + right_layout.addLayout(queue_btn_layout) + + def _setup_adv_tab_general(self, tabs): + tab = QWidget() + tabs.addTab(tab, "General") + layout = QFormLayout(tab) + self.widgets['adv_general_layout'] = layout + + guidance_group = QGroupBox("Guidance") + guidance_layout = self.create_widget(QFormLayout, 'guidance_layout', guidance_group) + guidance_layout.addRow("Guidance (CFG):", self._create_slider_with_label('guidance_scale', 10, 200, 5.0, 10.0, 1)) + + self.widgets['guidance_phases_row_index'] = guidance_layout.rowCount() + guidance_layout.addRow("Guidance Phases:", self.create_widget(QComboBox, 'guidance_phases')) + + self.widgets['guidance2_row_index'] = guidance_layout.rowCount() + guidance_layout.addRow("Guidance 2:", self._create_slider_with_label('guidance2_scale', 10, 200, 5.0, 10.0, 1)) + self.widgets['guidance3_row_index'] = guidance_layout.rowCount() + guidance_layout.addRow("Guidance 3:", self._create_slider_with_label('guidance3_scale', 10, 200, 5.0, 10.0, 1)) + self.widgets['switch_thresh_row_index'] = guidance_layout.rowCount() + guidance_layout.addRow("Switch Threshold:", self._create_slider_with_label('switch_threshold', 0, 1000, 0, 1.0, 0)) + layout.addRow(guidance_group) + + nag_group = self.create_widget(QGroupBox, 'nag_group', "NAG (Negative Adversarial Guidance)") + nag_layout = QFormLayout(nag_group) + nag_layout.addRow("NAG Scale:", self._create_slider_with_label('NAG_scale', 10, 200, 1.0, 10.0, 
1)) + nag_layout.addRow("NAG Tau:", self._create_slider_with_label('NAG_tau', 10, 50, 3.5, 10.0, 1)) + nag_layout.addRow("NAG Alpha:", self._create_slider_with_label('NAG_alpha', 0, 20, 0.5, 10.0, 1)) + layout.addRow(nag_group) + + self.widgets['solver_row_container'] = QWidget() + solver_hbox = QHBoxLayout(self.widgets['solver_row_container']) + solver_hbox.setContentsMargins(0,0,0,0) + solver_hbox.addWidget(QLabel("Sampler Solver:")) + solver_hbox.addWidget(self.create_widget(QComboBox, 'sample_solver')) + layout.addRow(self.widgets['solver_row_container']) + + self.widgets['flow_shift_row_index'] = layout.rowCount() + layout.addRow("Shift Scale:", self._create_slider_with_label('flow_shift', 10, 250, 3.0, 10.0, 1)) + + self.widgets['audio_guidance_row_index'] = layout.rowCount() + layout.addRow("Audio Guidance:", self._create_slider_with_label('audio_guidance_scale', 10, 200, 4.0, 10.0, 1)) + + self.widgets['repeat_generation_row_index'] = layout.rowCount() + layout.addRow("Repeat Generations:", self._create_slider_with_label('repeat_generation', 1, 25, 1, 1.0, 0)) + + combo = self.create_widget(QComboBox, 'multi_images_gen_type') + combo.addItem("Generate all combinations", 0) + combo.addItem("Match images and texts", 1) + self.widgets['multi_images_gen_type_row_index'] = layout.rowCount() + layout.addRow("Multi-Image Mode:", combo) + + def _setup_adv_tab_loras(self, tabs): + tab = QWidget() + tabs.addTab(tab, "Loras") + layout = QVBoxLayout(tab) + layout.addWidget(QLabel("Available Loras (Ctrl+Click to select multiple):")) + lora_list = self.create_widget(QListWidget, 'activated_loras') + lora_list.setSelectionMode(QListWidget.SelectionMode.MultiSelection) + layout.addWidget(lora_list) + layout.addWidget(QLabel("Loras Multipliers:")) + layout.addWidget(self.create_widget(QTextEdit, 'loras_multipliers')) + + def _setup_adv_tab_speed(self, tabs): + tab = QWidget() + tabs.addTab(tab, "Speed") + layout = QFormLayout(tab) + combo = self.create_widget(QComboBox, 'skip_steps_cache_type') + combo.addItem("None", "") + combo.addItem("Tea Cache", "tea") + combo.addItem("Mag Cache", "mag") + layout.addRow("Cache Type:", combo) + + combo = self.create_widget(QComboBox, 'skip_steps_multiplier') + combo.addItem("x1.5 speed up", 1.5) + combo.addItem("x1.75 speed up", 1.75) + combo.addItem("x2.0 speed up", 2.0) + combo.addItem("x2.25 speed up", 2.25) + combo.addItem("x2.5 speed up", 2.5) + layout.addRow("Acceleration:", combo) + layout.addRow("Start %:", self._create_slider_with_label('skip_steps_start_step_perc', 0, 100, 0, 1.0, 0)) + + def _setup_adv_tab_postproc(self, tabs): + tab = QWidget() + tabs.addTab(tab, "Post-Processing") + layout = QFormLayout(tab) + combo = self.create_widget(QComboBox, 'temporal_upsampling') + combo.addItem("Disabled", "") + combo.addItem("Rife x2 frames/s", "rife2") + combo.addItem("Rife x4 frames/s", "rife4") + layout.addRow("Temporal Upsampling:", combo) + + combo = self.create_widget(QComboBox, 'spatial_upsampling') + combo.addItem("Disabled", "") + combo.addItem("Lanczos x1.5", "lanczos1.5") + combo.addItem("Lanczos x2.0", "lanczos2") + layout.addRow("Spatial Upsampling:", combo) + + layout.addRow("Film Grain Intensity:", self._create_slider_with_label('film_grain_intensity', 0, 100, 0, 100.0, 2)) + layout.addRow("Film Grain Saturation:", self._create_slider_with_label('film_grain_saturation', 0, 100, 0.5, 100.0, 2)) + + def _setup_adv_tab_audio(self, tabs): + tab = QWidget() + tabs.addTab(tab, "Audio") + layout = QFormLayout(tab) + combo = 
self.create_widget(QComboBox, 'MMAudio_setting') + combo.addItem("Disabled", 0) + combo.addItem("Enabled", 1) + layout.addRow("MMAudio:", combo) + layout.addWidget(self.create_widget(QLineEdit, 'MMAudio_prompt', placeholderText="MMAudio Prompt")) + layout.addWidget(self.create_widget(QLineEdit, 'MMAudio_neg_prompt', placeholderText="MMAudio Negative Prompt")) + layout.addRow(self._create_file_input('audio_source', "Custom Soundtrack")) + + def _setup_adv_tab_quality(self, tabs): + tab = QWidget() + tabs.addTab(tab, "Quality") + layout = QVBoxLayout(tab) + + slg_group = self.create_widget(QGroupBox, 'slg_group', "Skip Layer Guidance") + slg_layout = QFormLayout(slg_group) + slg_combo = self.create_widget(QComboBox, 'slg_switch') + slg_combo.addItem("OFF", 0) + slg_combo.addItem("ON", 1) + slg_layout.addRow("Enable SLG:", slg_combo) + slg_layout.addRow("Start %:", self._create_slider_with_label('slg_start_perc', 0, 100, 10, 1.0, 0)) + slg_layout.addRow("End %:", self._create_slider_with_label('slg_end_perc', 0, 100, 90, 1.0, 0)) + layout.addWidget(slg_group) + + quality_form = QFormLayout() + self.widgets['quality_form_layout'] = quality_form + + apg_combo = self.create_widget(QComboBox, 'apg_switch') + apg_combo.addItem("OFF", 0) + apg_combo.addItem("ON", 1) + self.widgets['apg_switch_row_index'] = quality_form.rowCount() + quality_form.addRow("Adaptive Projected Guidance:", apg_combo) + + cfg_star_combo = self.create_widget(QComboBox, 'cfg_star_switch') + cfg_star_combo.addItem("OFF", 0) + cfg_star_combo.addItem("ON", 1) + self.widgets['cfg_star_switch_row_index'] = quality_form.rowCount() + quality_form.addRow("Classifier-Free Guidance Star:", cfg_star_combo) + + self.widgets['cfg_zero_step_row_index'] = quality_form.rowCount() + quality_form.addRow("CFG Zero below Layer:", self._create_slider_with_label('cfg_zero_step', -1, 39, -1, 1.0, 0)) + + combo = self.create_widget(QComboBox, 'min_frames_if_references') + combo.addItem("Disabled (1 frame)", 1) + combo.addItem("Generate 5 frames", 5) + combo.addItem("Generate 9 frames", 9) + combo.addItem("Generate 13 frames", 13) + combo.addItem("Generate 17 frames", 17) + self.widgets['min_frames_if_references_row_index'] = quality_form.rowCount() + quality_form.addRow("Min Frames for Quality:", combo) + layout.addLayout(quality_form) + + def _setup_adv_tab_sliding_window(self, tabs): + tab = QWidget() + self.widgets['sliding_window_tab_index'] = tabs.count() + tabs.addTab(tab, "Sliding Window") + layout = QFormLayout(tab) + + layout.addRow("Window Size:", self._create_slider_with_label('sliding_window_size', 5, 257, 129, 1.0, 0)) + layout.addRow("Overlap:", self._create_slider_with_label('sliding_window_overlap', 1, 97, 5, 1.0, 0)) + layout.addRow("Color Correction:", self._create_slider_with_label('sliding_window_color_correction_strength', 0, 100, 0, 100.0, 2)) + layout.addRow("Overlap Noise:", self._create_slider_with_label('sliding_window_overlap_noise', 0, 150, 20, 1.0, 0)) + layout.addRow("Discard Last Frames:", self._create_slider_with_label('sliding_window_discard_last_frames', 0, 20, 0, 1.0, 0)) + + def _setup_adv_tab_misc(self, tabs): + tab = QWidget() + tabs.addTab(tab, "Misc") + layout = QFormLayout(tab) + self.widgets['misc_layout'] = layout + + riflex_combo = self.create_widget(QComboBox, 'RIFLEx_setting') + riflex_combo.addItem("Auto", 0) + riflex_combo.addItem("Always ON", 1) + riflex_combo.addItem("Always OFF", 2) + self.widgets['riflex_row_index'] = layout.rowCount() + layout.addRow("RIFLEx Setting:", riflex_combo) + + fps_combo 
= self.create_widget(QComboBox, 'force_fps') + layout.addRow("Force FPS:", fps_combo) + + profile_combo = self.create_widget(QComboBox, 'override_profile') + profile_combo.addItem("Default Profile", -1) + for text, val in wgp.memory_profile_choices: + profile_combo.addItem(text.split(':')[0], val) + layout.addRow("Override Memory Profile:", profile_combo) + + combo = self.create_widget(QComboBox, 'multi_prompts_gen_type') + combo.addItem("Generate new Video per line", 0) + combo.addItem("Use line for new Sliding Window", 1) + layout.addRow("Multi-Prompt Mode:", combo) + + def setup_config_tab(self): + config_tab = QWidget() + self.tabs.addTab(config_tab, "Configuration") + main_layout = QVBoxLayout(config_tab) + + self.config_status_label = QLabel("Apply changes for them to take effect. Some may require a restart.") + main_layout.addWidget(self.config_status_label) + + config_tabs = QTabWidget() + main_layout.addWidget(config_tabs) + + config_tabs.addTab(self._create_general_config_tab(), "General") + config_tabs.addTab(self._create_performance_config_tab(), "Performance") + config_tabs.addTab(self._create_extensions_config_tab(), "Extensions") + config_tabs.addTab(self._create_outputs_config_tab(), "Outputs") + config_tabs.addTab(self._create_notifications_config_tab(), "Notifications") + + self.apply_config_btn = QPushButton("Apply Changes") + self.apply_config_btn.clicked.connect(self._on_apply_config_changes) + main_layout.addWidget(self.apply_config_btn) + + def _create_scrollable_form_tab(self): + tab_widget = QWidget() + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + layout = QVBoxLayout(tab_widget) + layout.addWidget(scroll_area) + + content_widget = QWidget() + form_layout = QFormLayout(content_widget) + scroll_area.setWidget(content_widget) + + return tab_widget, form_layout + + def _create_config_combo(self, form_layout, label, key, choices, default_value): + combo = QComboBox() + for text, data in choices: + combo.addItem(text, data) + index = combo.findData(wgp.server_config.get(key, default_value)) + if index != -1: combo.setCurrentIndex(index) + self.widgets[f'config_{key}'] = combo + form_layout.addRow(label, combo) + + def _create_config_slider(self, form_layout, label, key, min_val, max_val, default_value, step=1): + container = QWidget() + hbox = QHBoxLayout(container) + hbox.setContentsMargins(0,0,0,0) + + slider = QSlider(Qt.Orientation.Horizontal) + slider.setRange(min_val, max_val) + slider.setSingleStep(step) + slider.setValue(wgp.server_config.get(key, default_value)) + + value_label = QLabel(str(slider.value())) + value_label.setMinimumWidth(40) + slider.valueChanged.connect(lambda v, lbl=value_label: lbl.setText(str(v))) + + hbox.addWidget(slider) + hbox.addWidget(value_label) + + self.widgets[f'config_{key}'] = slider + form_layout.addRow(label, container) + + def _create_config_checklist(self, form_layout, label, key, choices, default_value): + list_widget = QListWidget() + list_widget.setMinimumHeight(100) + current_values = wgp.server_config.get(key, default_value) + for text, data in choices: + item = QListWidgetItem(text) + item.setData(Qt.ItemDataRole.UserRole, data) + item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable) + if data in current_values: + item.setCheckState(Qt.CheckState.Checked) + else: + item.setCheckState(Qt.CheckState.Unchecked) + list_widget.addItem(item) + self.widgets[f'config_{key}'] = list_widget + form_layout.addRow(label, list_widget) + + def _create_config_textbox(self, form_layout, label, key, 
default_value, multi_line=False): + if multi_line: + textbox = QTextEdit(default_value) + textbox.setAcceptRichText(False) + else: + textbox = QLineEdit(default_value) + self.widgets[f'config_{key}'] = textbox + form_layout.addRow(label, textbox) + + def _create_general_config_tab(self): + tab, form = self._create_scrollable_form_tab() + + _, _, dropdown_choices = wgp.get_sorted_dropdown(wgp.displayed_model_types, None, None, False) + self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, wgp.transformer_types) + + self._create_config_combo(form, "Model Hierarchy:", "model_hierarchy_type", [ + ("Two Levels (Family > Model)", 0), + ("Three Levels (Family > Base > Finetune)", 1) + ], 1) + + self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [ + ("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1), + ("Dimensions are Output Width/Height (Cropped)", 2)], 0) + + self._create_config_combo(form, "Attention Type:", "attention_mode", [ + ("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"), + ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto") + + self._create_config_combo(form, "Metadata Handling:", "metadata_type", [ + ("Embed in file (Exif/Comment)", "metadata"), ("Export separate JSON", "json"), ("None", "none")], "metadata") + + self._create_config_checklist(form, "RAM Loading Policy:", "preload_model_policy", [ + ("Preload on App Launch", "P"), ("Preload on Model Switch", "S"), ("Unload when Queue is Done", "U")], []) + + self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [ + ("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10), + ("Keep last 20", 20), ("Keep last 30", 30)], 5) + + self._create_config_combo(form, "Display RAM/VRAM Stats:", "display_stats", [("Disabled", 0), ("Enabled", 1)], 0) + + self._create_config_combo(form, "Max Frames Multiplier:", "max_frames_multiplier", + [(f"x{i}", i) for i in range(1, 8)], 1) + + checkpoints_paths_text = "\n".join(wgp.server_config.get("checkpoints_paths", wgp.fl.default_checkpoints_paths)) + checkpoints_textbox = QTextEdit() + checkpoints_textbox.setPlainText(checkpoints_paths_text) + checkpoints_textbox.setAcceptRichText(False) + checkpoints_textbox.setMinimumHeight(60) + self.widgets['config_checkpoints_paths'] = checkpoints_textbox + form.addRow("Checkpoints Paths:", checkpoints_textbox) + + self._create_config_combo(form, "UI Theme (requires restart):", "UI_theme", [("Blue Sky", "default"), ("Classic Gradio", "gradio")], "default") + + return tab + + def _create_performance_config_tab(self): + tab, form = self._create_scrollable_form_tab() + self._create_config_combo(form, "Transformer Quantization:", "transformer_quantization", [("Scaled Int8 (recommended)", "int8"), ("16-bit (no quantization)", "bf16")], "int8") + self._create_config_combo(form, "Transformer Data Type:", "transformer_dtype_policy", [("Best Supported by Hardware", ""), ("FP16", "fp16"), ("BF16", "bf16")], "") + self._create_config_combo(form, "Transformer Calculation:", "mixed_precision", [("16-bit only", "0"), ("Mixed 16/32-bit (better quality)", "1")], "0") + self._create_config_combo(form, "Text Encoder:", "text_encoder_quantization", [("16-bit (more RAM, better quality)", "bf16"), ("8-bit (less RAM)", "int8")], "int8") + self._create_config_combo(form, "VAE Precision:", "vae_precision", [("16-bit (faster, less VRAM)", "16"), ("32-bit (slower, better quality)", "32")], "16") + 
self._create_config_combo(form, "Compile Transformer:", "compile", [("On (requires Triton)", "transformer"), ("Off", "")], "") + self._create_config_combo(form, "DepthAnything v2 Variant:", "depth_anything_v2_variant", [("Large (more precise)", "vitl"), ("Big (faster)", "vitb")], "vitl") + self._create_config_combo(form, "VAE Tiling:", "vae_config", [("Auto", 0), ("Disabled", 1), ("256x256 (~8GB VRAM)", 2), ("128x128 (~6GB VRAM)", 3)], 0) + self._create_config_combo(form, "Boost:", "boost", [("On", 1), ("Off", 2)], 1) + self._create_config_combo(form, "Memory Profile:", "profile", wgp.memory_profile_choices, wgp.profile_type.LowRAM_LowVRAM) + self._create_config_slider(form, "Preload in VRAM (MB):", "preload_in_VRAM", 0, 40000, 0, 100) + + release_ram_btn = QPushButton("Force Release Models from RAM") + release_ram_btn.clicked.connect(self._on_release_ram) + form.addRow(release_ram_btn) + + return tab + + def _create_extensions_config_tab(self): + tab, form = self._create_scrollable_form_tab() + self._create_config_combo(form, "Prompt Enhancer:", "enhancer_enabled", [("Off", 0), ("Florence 2 + Llama 3.2", 1), ("Florence 2 + Joy Caption (uncensored)", 2)], 0) + self._create_config_combo(form, "Enhancer Mode:", "enhancer_mode", [("Automatic on Generate", 0), ("On Demand Only", 1)], 0) + self._create_config_combo(form, "MMAudio:", "mmaudio_enabled", [("Off", 0), ("Enabled (unloaded after use)", 1), ("Enabled (persistent in RAM)", 2)], 0) + return tab + + def _create_outputs_config_tab(self): + tab, form = self._create_scrollable_form_tab() + self._create_config_combo(form, "Video Codec:", "video_output_codec", [("x265 Balanced", 'libx265_28'), ("x264 Balanced", 'libx264_8'), ("x265 High Quality", 'libx265_8'), ("x264 High Quality", 'libx264_10'), ("x264 Lossless", 'libx264_lossless')], 'libx264_8') + self._create_config_combo(form, "Image Codec:", "image_output_codec", [("JPEG Q85", 'jpeg_85'), ("WEBP Q85", 'webp_85'), ("JPEG Q95", 'jpeg_95'), ("WEBP Q95", 'webp_95'), ("WEBP Lossless", 'webp_lossless'), ("PNG Lossless", 'png')], 'jpeg_95') + self._create_config_textbox(form, "Video Output Folder:", "save_path", "outputs") + self._create_config_textbox(form, "Image Output Folder:", "image_save_path", "outputs") + return tab + + def _create_notifications_config_tab(self): + tab, form = self._create_scrollable_form_tab() + self._create_config_combo(form, "Notification Sound:", "notification_sound_enabled", [("On", 1), ("Off", 0)], 0) + self._create_config_slider(form, "Sound Volume:", "notification_sound_volume", 0, 100, 50, 5) + return tab + + def init_wgp_state(self): + initial_model = wgp.server_config.get("last_model_type", wgp.transformer_type) + all_models, _, _ = wgp.get_sorted_dropdown(wgp.displayed_model_types, None, None, False) + all_model_ids = [m[1] for m in all_models] + if initial_model not in all_model_ids: + initial_model = wgp.transformer_type + + state_dict = {} + state_dict["model_filename"] = wgp.get_model_filename(initial_model, wgp.transformer_quantization, wgp.transformer_dtype_policy) + state_dict["model_type"] = initial_model + state_dict["advanced"] = wgp.advanced + state_dict["last_model_per_family"] = wgp.server_config.get("last_model_per_family", {}) + state_dict["last_model_per_type"] = wgp.server_config.get("last_model_per_type", {}) + state_dict["last_resolution_per_group"] = wgp.server_config.get("last_resolution_per_group", {}) + state_dict["gen"] = {"queue": []} + + self.state = state_dict + self.advanced_group.setChecked(wgp.advanced) + + 
self.update_model_dropdowns(initial_model) + self.refresh_ui_from_model_change(initial_model) + self._update_input_visibility() # Set initial visibility + + def update_model_dropdowns(self, current_model_type): + family_mock, base_type_mock, choice_mock = wgp.generate_dropdown_model_list(current_model_type) + + self.widgets['model_family'].blockSignals(True) + self.widgets['model_base_type_choice'].blockSignals(True) + self.widgets['model_choice'].blockSignals(True) + + self.widgets['model_family'].clear() + if family_mock.choices: + for display_name, internal_key in family_mock.choices: + self.widgets['model_family'].addItem(display_name, internal_key) + index = self.widgets['model_family'].findData(family_mock.value) + if index != -1: self.widgets['model_family'].setCurrentIndex(index) + + self.widgets['model_base_type_choice'].clear() + if base_type_mock.choices: + for label, value in base_type_mock.choices: + self.widgets['model_base_type_choice'].addItem(label, value) + index = self.widgets['model_base_type_choice'].findData(base_type_mock.value) + if index != -1: self.widgets['model_base_type_choice'].setCurrentIndex(index) + self.widgets['model_base_type_choice'].setVisible(base_type_mock.kwargs.get('visible', True)) + + self.widgets['model_choice'].clear() + if choice_mock.choices: + for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value) + index = self.widgets['model_choice'].findData(choice_mock.value) + if index != -1: self.widgets['model_choice'].setCurrentIndex(index) + self.widgets['model_choice'].setVisible(choice_mock.kwargs.get('visible', True)) + + self.widgets['model_family'].blockSignals(False) + self.widgets['model_base_type_choice'].blockSignals(False) + self.widgets['model_choice'].blockSignals(False) + + def refresh_ui_from_model_change(self, model_type): + """Update UI controls with default settings when the model is changed.""" + self.header_info.setText(wgp.generate_header(model_type, wgp.compile, wgp.attention_mode)) + ui_defaults = wgp.get_default_settings(model_type) + wgp.set_model_settings(self.state, model_type, ui_defaults) + + model_def = wgp.get_model_def(model_type) + base_model_type = wgp.get_base_model_type(model_type) + model_filename = self.state.get('model_filename', '') + + image_outputs = model_def.get("image_outputs", False) + vace = wgp.test_vace_module(model_type) + t2v = base_model_type in ['t2v', 't2v_2_2'] + i2v = wgp.test_class_i2v(model_type) + fantasy = base_model_type in ["fantasy"] + multitalk = model_def.get("multitalk_class", False) + any_audio_guidance = fantasy or multitalk + sliding_window_enabled = wgp.test_any_sliding_window(model_type) + recammaster = base_model_type in ["recam_1.3B"] + ltxv = "ltxv" in model_filename + diffusion_forcing = "diffusion_forcing" in model_filename + any_skip_layer_guidance = model_def.get("skip_layer_guidance", False) + any_cfg_zero = model_def.get("cfg_zero", False) + any_cfg_star = model_def.get("cfg_star", False) + any_apg = model_def.get("adaptive_projected_guidance", False) + v2i_switch_supported = model_def.get("v2i_switch_supported", False) + + self._update_generation_mode_visibility(model_def) + + for widget in self.widgets.values(): + if hasattr(widget, 'blockSignals'): widget.blockSignals(True) + + self.widgets['prompt'].setText(ui_defaults.get("prompt", "")) + self.widgets['negative_prompt'].setText(ui_defaults.get("negative_prompt", "")) + self.widgets['seed'].setText(str(ui_defaults.get("seed", -1))) + + video_length_val = 
ui_defaults.get("video_length", 81) + self.widgets['video_length'].setValue(video_length_val) + self.widgets['video_length_label'].setText(str(video_length_val)) + + steps_val = ui_defaults.get("num_inference_steps", 30) + self.widgets['num_inference_steps'].setValue(steps_val) + self.widgets['num_inference_steps_label'].setText(str(steps_val)) + + self.widgets['resolution_group'].blockSignals(True) + self.widgets['resolution'].blockSignals(True) + + current_res_choice = ui_defaults.get("resolution") + model_resolutions = model_def.get("resolutions", None) + self.full_resolution_choices, current_res_choice = wgp.get_resolution_choices(current_res_choice, model_resolutions) + available_groups, selected_group_resolutions, selected_group = wgp.group_resolutions(model_def, self.full_resolution_choices, current_res_choice) + + self.widgets['resolution_group'].clear() + self.widgets['resolution_group'].addItems(available_groups) + group_index = self.widgets['resolution_group'].findText(selected_group) + if group_index != -1: + self.widgets['resolution_group'].setCurrentIndex(group_index) + + self.widgets['resolution'].clear() + for label, value in selected_group_resolutions: + self.widgets['resolution'].addItem(label, value) + res_index = self.widgets['resolution'].findData(current_res_choice) + if res_index != -1: + self.widgets['resolution'].setCurrentIndex(res_index) + + self.widgets['resolution_group'].blockSignals(False) + self.widgets['resolution'].blockSignals(False) + + + for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'audio_source']: + if name in self.widgets: + self.widgets[name].clear() + + guidance_layout = self.widgets['guidance_layout'] + guidance_max = model_def.get("guidance_max_phases", 1) + guidance_layout.setRowVisible(self.widgets['guidance_phases_row_index'], guidance_max > 1) + + adv_general_layout = self.widgets['adv_general_layout'] + adv_general_layout.setRowVisible(self.widgets['flow_shift_row_index'], not image_outputs) + adv_general_layout.setRowVisible(self.widgets['audio_guidance_row_index'], any_audio_guidance) + adv_general_layout.setRowVisible(self.widgets['repeat_generation_row_index'], not image_outputs) + adv_general_layout.setRowVisible(self.widgets['multi_images_gen_type_row_index'], i2v) + + self.widgets['slg_group'].setVisible(any_skip_layer_guidance) + quality_form_layout = self.widgets['quality_form_layout'] + quality_form_layout.setRowVisible(self.widgets['apg_switch_row_index'], any_apg) + quality_form_layout.setRowVisible(self.widgets['cfg_star_switch_row_index'], any_cfg_star) + quality_form_layout.setRowVisible(self.widgets['cfg_zero_step_row_index'], any_cfg_zero) + quality_form_layout.setRowVisible(self.widgets['min_frames_if_references_row_index'], v2i_switch_supported and image_outputs) + + self.widgets['advanced_tabs'].setTabVisible(self.widgets['sliding_window_tab_index'], sliding_window_enabled and not image_outputs) + + misc_layout = self.widgets['misc_layout'] + misc_layout.setRowVisible(self.widgets['riflex_row_index'], not (recammaster or ltxv or diffusion_forcing)) + + + index = self.widgets['multi_images_gen_type'].findData(ui_defaults.get('multi_images_gen_type', 0)) + if index != -1: self.widgets['multi_images_gen_type'].setCurrentIndex(index) + + guidance_val = ui_defaults.get("guidance_scale", 5.0) + self.widgets['guidance_scale'].setValue(int(guidance_val * 10)) + self.widgets['guidance_scale_label'].setText(f"{guidance_val:.1f}") + + guidance2_val = ui_defaults.get("guidance2_scale", 5.0) + 
self.widgets['guidance2_scale'].setValue(int(guidance2_val * 10)) + self.widgets['guidance2_scale_label'].setText(f"{guidance2_val:.1f}") + + guidance3_val = ui_defaults.get("guidance3_scale", 5.0) + self.widgets['guidance3_scale'].setValue(int(guidance3_val * 10)) + self.widgets['guidance3_scale_label'].setText(f"{guidance3_val:.1f}") + self.widgets['guidance_phases'].clear() + + if guidance_max >= 1: self.widgets['guidance_phases'].addItem("One Phase", 1) + if guidance_max >= 2: self.widgets['guidance_phases'].addItem("Two Phases", 2) + if guidance_max >= 3: self.widgets['guidance_phases'].addItem("Three Phases", 3) + + index = self.widgets['guidance_phases'].findData(ui_defaults.get("guidance_phases", 1)) + if index != -1: self.widgets['guidance_phases'].setCurrentIndex(index) + + switch_thresh_val = ui_defaults.get("switch_threshold", 0) + self.widgets['switch_threshold'].setValue(switch_thresh_val) + self.widgets['switch_threshold_label'].setText(str(switch_thresh_val)) + + nag_scale_val = ui_defaults.get('NAG_scale', 1.0) + self.widgets['NAG_scale'].setValue(int(nag_scale_val * 10)) + self.widgets['NAG_scale_label'].setText(f"{nag_scale_val:.1f}") + + nag_tau_val = ui_defaults.get('NAG_tau', 3.5) + self.widgets['NAG_tau'].setValue(int(nag_tau_val * 10)) + self.widgets['NAG_tau_label'].setText(f"{nag_tau_val:.1f}") + + nag_alpha_val = ui_defaults.get('NAG_alpha', 0.5) + self.widgets['NAG_alpha'].setValue(int(nag_alpha_val * 10)) + self.widgets['NAG_alpha_label'].setText(f"{nag_alpha_val:.1f}") + + self.widgets['nag_group'].setVisible(vace or t2v or i2v) + + self.widgets['sample_solver'].clear() + sampler_choices = model_def.get("sample_solvers", []) + self.widgets['solver_row_container'].setVisible(bool(sampler_choices)) + if sampler_choices: + for label, value in sampler_choices: self.widgets['sample_solver'].addItem(label, value) + solver_val = ui_defaults.get('sample_solver', sampler_choices[0][1]) + index = self.widgets['sample_solver'].findData(solver_val) + if index != -1: self.widgets['sample_solver'].setCurrentIndex(index) + + flow_val = ui_defaults.get("flow_shift", 3.0) + self.widgets['flow_shift'].setValue(int(flow_val * 10)) + self.widgets['flow_shift_label'].setText(f"{flow_val:.1f}") + + audio_guidance_val = ui_defaults.get("audio_guidance_scale", 4.0) + self.widgets['audio_guidance_scale'].setValue(int(audio_guidance_val * 10)) + self.widgets['audio_guidance_scale_label'].setText(f"{audio_guidance_val:.1f}") + + repeat_val = ui_defaults.get("repeat_generation", 1) + self.widgets['repeat_generation'].setValue(repeat_val) + self.widgets['repeat_generation_label'].setText(str(repeat_val)) + + available_loras, _, _, _, _, _ = wgp.setup_loras(model_type, None, wgp.get_lora_dir(model_type), "") + self.state['loras'] = available_loras + self.lora_map = {os.path.basename(p): p for p in available_loras} + lora_list_widget = self.widgets['activated_loras'] + lora_list_widget.clear() + lora_list_widget.addItems(sorted(self.lora_map.keys())) + selected_loras = ui_defaults.get('activated_loras', []) + for i in range(lora_list_widget.count()): + item = lora_list_widget.item(i) + is_selected = any(item.text() == os.path.basename(p) for p in selected_loras) + if is_selected: + item.setSelected(True) + self.widgets['loras_multipliers'].setText(ui_defaults.get('loras_multipliers', '')) + + skip_cache_val = ui_defaults.get('skip_steps_cache_type', "") + index = self.widgets['skip_steps_cache_type'].findData(skip_cache_val) + if index != -1: 
self.widgets['skip_steps_cache_type'].setCurrentIndex(index) + + skip_mult = ui_defaults.get('skip_steps_multiplier', 1.5) + index = self.widgets['skip_steps_multiplier'].findData(skip_mult) + if index != -1: self.widgets['skip_steps_multiplier'].setCurrentIndex(index) + + skip_perc_val = ui_defaults.get('skip_steps_start_step_perc', 0) + self.widgets['skip_steps_start_step_perc'].setValue(skip_perc_val) + self.widgets['skip_steps_start_step_perc_label'].setText(str(skip_perc_val)) + + temp_up_val = ui_defaults.get('temporal_upsampling', "") + index = self.widgets['temporal_upsampling'].findData(temp_up_val) + if index != -1: self.widgets['temporal_upsampling'].setCurrentIndex(index) + + spat_up_val = ui_defaults.get('spatial_upsampling', "") + index = self.widgets['spatial_upsampling'].findData(spat_up_val) + if index != -1: self.widgets['spatial_upsampling'].setCurrentIndex(index) + + film_grain_i = ui_defaults.get('film_grain_intensity', 0) + self.widgets['film_grain_intensity'].setValue(int(film_grain_i * 100)) + self.widgets['film_grain_intensity_label'].setText(f"{film_grain_i:.2f}") + + film_grain_s = ui_defaults.get('film_grain_saturation', 0.5) + self.widgets['film_grain_saturation'].setValue(int(film_grain_s * 100)) + self.widgets['film_grain_saturation_label'].setText(f"{film_grain_s:.2f}") + + self.widgets['MMAudio_setting'].setCurrentIndex(ui_defaults.get('MMAudio_setting', 0)) + self.widgets['MMAudio_prompt'].setText(ui_defaults.get('MMAudio_prompt', '')) + self.widgets['MMAudio_neg_prompt'].setText(ui_defaults.get('MMAudio_neg_prompt', '')) + + self.widgets['slg_switch'].setCurrentIndex(ui_defaults.get('slg_switch', 0)) + slg_start_val = ui_defaults.get('slg_start_perc', 10) + self.widgets['slg_start_perc'].setValue(slg_start_val) + self.widgets['slg_start_perc_label'].setText(str(slg_start_val)) + slg_end_val = ui_defaults.get('slg_end_perc', 90) + self.widgets['slg_end_perc'].setValue(slg_end_val) + self.widgets['slg_end_perc_label'].setText(str(slg_end_val)) + + self.widgets['apg_switch'].setCurrentIndex(ui_defaults.get('apg_switch', 0)) + self.widgets['cfg_star_switch'].setCurrentIndex(ui_defaults.get('cfg_star_switch', 0)) + + cfg_zero_val = ui_defaults.get('cfg_zero_step', -1) + self.widgets['cfg_zero_step'].setValue(cfg_zero_val) + self.widgets['cfg_zero_step_label'].setText(str(cfg_zero_val)) + + min_frames_val = ui_defaults.get('min_frames_if_references', 1) + index = self.widgets['min_frames_if_references'].findData(min_frames_val) + if index != -1: self.widgets['min_frames_if_references'].setCurrentIndex(index) + + self.widgets['RIFLEx_setting'].setCurrentIndex(ui_defaults.get('RIFLEx_setting', 0)) + + fps = wgp.get_model_fps(model_type) + force_fps_choices = [ + (f"Model Default ({fps} fps)", ""), ("Auto", "auto"), ("Control Video fps", "control"), + ("Source Video fps", "source"), ("15", "15"), ("16", "16"), ("23", "23"), + ("24", "24"), ("25", "25"), ("30", "30") + ] + self.widgets['force_fps'].clear() + for label, value in force_fps_choices: self.widgets['force_fps'].addItem(label, value) + force_fps_val = ui_defaults.get('force_fps', "") + index = self.widgets['force_fps'].findData(force_fps_val) + if index != -1: self.widgets['force_fps'].setCurrentIndex(index) + + override_prof_val = ui_defaults.get('override_profile', -1) + index = self.widgets['override_profile'].findData(override_prof_val) + if index != -1: self.widgets['override_profile'].setCurrentIndex(index) + + 
self.widgets['multi_prompts_gen_type'].setCurrentIndex(ui_defaults.get('multi_prompts_gen_type', 0)) + + denoising_val = ui_defaults.get("denoising_strength", 0.5) + self.widgets['denoising_strength'].setValue(int(denoising_val * 100)) + self.widgets['denoising_strength_label'].setText(f"{denoising_val:.2f}") + + sw_size = ui_defaults.get("sliding_window_size", 129) + self.widgets['sliding_window_size'].setValue(sw_size) + self.widgets['sliding_window_size_label'].setText(str(sw_size)) + + sw_overlap = ui_defaults.get("sliding_window_overlap", 5) + self.widgets['sliding_window_overlap'].setValue(sw_overlap) + self.widgets['sliding_window_overlap_label'].setText(str(sw_overlap)) + + sw_color = ui_defaults.get("sliding_window_color_correction_strength", 0) + self.widgets['sliding_window_color_correction_strength'].setValue(int(sw_color * 100)) + self.widgets['sliding_window_color_correction_strength_label'].setText(f"{sw_color:.2f}") + + sw_noise = ui_defaults.get("sliding_window_overlap_noise", 20) + self.widgets['sliding_window_overlap_noise'].setValue(sw_noise) + self.widgets['sliding_window_overlap_noise_label'].setText(str(sw_noise)) + + sw_discard = ui_defaults.get("sliding_window_discard_last_frames", 0) + self.widgets['sliding_window_discard_last_frames'].setValue(sw_discard) + self.widgets['sliding_window_discard_last_frames_label'].setText(str(sw_discard)) + + for widget in self.widgets.values(): + if hasattr(widget, 'blockSignals'): widget.blockSignals(False) + self._update_dynamic_ui() + self._update_input_visibility() + + def _update_dynamic_ui(self): + """Update UI visibility based on current selections.""" + phases = self.widgets['guidance_phases'].currentData() or 1 + guidance_layout = self.widgets['guidance_layout'] + guidance_layout.setRowVisible(self.widgets['guidance2_row_index'], phases >= 2) + guidance_layout.setRowVisible(self.widgets['guidance3_row_index'], phases >= 3) + guidance_layout.setRowVisible(self.widgets['switch_thresh_row_index'], phases >= 2) + + def _update_generation_mode_visibility(self, model_def): + """Shows/hides the main generation mode options based on the selected model.""" + allowed = model_def.get("image_prompt_types_allowed", "") + + choices = [] + if "T" in allowed or not allowed: + choices.append(("Text Prompt Only" if "S" in allowed else "New Video", "T")) + if "S" in allowed: + choices.append(("Start Video with Image", "S")) + if "V" in allowed: + choices.append(("Continue Video", "V")) + if "L" in allowed: + choices.append(("Continue Last Video", "L")) + + button_map = { + "T": self.widgets['mode_t'], + "S": self.widgets['mode_s'], + "V": self.widgets['mode_v'], + "L": self.widgets['mode_l'], + } + + for btn in button_map.values(): + btn.setVisible(False) + + allowed_values = [c[1] for c in choices] + for label, value in choices: + if value in button_map: + btn = button_map[value] + btn.setText(label) + btn.setVisible(True) + + current_checked_value = None + for value, btn in button_map.items(): + if btn.isChecked(): + current_checked_value = value + break + + # If the currently selected mode is now hidden, reset to a visible default + if current_checked_value is None or not button_map[current_checked_value].isVisible(): + if allowed_values: + button_map[allowed_values[0]].setChecked(True) + + + end_image_visible = "E" in allowed + self.widgets['image_end_checkbox'].setVisible(end_image_visible) + if not end_image_visible: + self.widgets['image_end_checkbox'].setChecked(False) + + # Control Video Checkbox (Based on 
model_def.get("guide_preprocessing")) + control_video_visible = model_def.get("guide_preprocessing") is not None + self.widgets['control_video_checkbox'].setVisible(control_video_visible) + if not control_video_visible: + self.widgets['control_video_checkbox'].setChecked(False) + + # Reference Image Checkbox (Based on model_def.get("image_ref_choices")) + ref_image_visible = model_def.get("image_ref_choices") is not None + self.widgets['ref_image_checkbox'].setVisible(ref_image_visible) + if not ref_image_visible: + self.widgets['ref_image_checkbox'].setChecked(False) + + + def _update_input_visibility(self): + """Shows/hides input fields based on the selected generation mode.""" + is_s_mode = self.widgets['mode_s'].isChecked() + is_v_mode = self.widgets['mode_v'].isChecked() + is_l_mode = self.widgets['mode_l'].isChecked() + + use_end = self.widgets['image_end_checkbox'].isChecked() and self.widgets['image_end_checkbox'].isVisible() + use_control = self.widgets['control_video_checkbox'].isChecked() and self.widgets['control_video_checkbox'].isVisible() + use_ref = self.widgets['ref_image_checkbox'].isChecked() and self.widgets['ref_image_checkbox'].isVisible() + + self.widgets['image_start_container'].setVisible(is_s_mode) + self.widgets['video_source_container'].setVisible(is_v_mode) + + end_checkbox_enabled = is_s_mode or is_v_mode or is_l_mode + self.widgets['image_end_checkbox'].setEnabled(end_checkbox_enabled) + self.widgets['image_end_container'].setVisible(use_end and end_checkbox_enabled) + + self.widgets['video_guide_container'].setVisible(use_control) + self.widgets['video_mask_container'].setVisible(use_control) + self.widgets['image_refs_container'].setVisible(use_ref) + + def connect_signals(self): + self.widgets['model_family'].currentIndexChanged.connect(self._on_family_changed) + self.widgets['model_base_type_choice'].currentIndexChanged.connect(self._on_base_type_changed) + self.widgets['model_choice'].currentIndexChanged.connect(self._on_model_changed) + self.widgets['resolution_group'].currentIndexChanged.connect(self._on_resolution_group_changed) + self.widgets['guidance_phases'].currentIndexChanged.connect(self._update_dynamic_ui) + + self.widgets['mode_t'].toggled.connect(self._update_input_visibility) + self.widgets['mode_s'].toggled.connect(self._update_input_visibility) + self.widgets['mode_v'].toggled.connect(self._update_input_visibility) + self.widgets['mode_l'].toggled.connect(self._update_input_visibility) + + self.widgets['image_end_checkbox'].toggled.connect(self._update_input_visibility) + self.widgets['control_video_checkbox'].toggled.connect(self._update_input_visibility) + self.widgets['ref_image_checkbox'].toggled.connect(self._update_input_visibility) + self.widgets['preview_group'].toggled.connect(self._on_preview_toggled) + + self.generate_btn.clicked.connect(self._on_generate) + self.add_to_queue_btn.clicked.connect(self._on_add_to_queue) + self.remove_queue_btn.clicked.connect(self._on_remove_selected_from_queue) + self.clear_queue_btn.clicked.connect(self._on_clear_queue) + self.abort_btn.clicked.connect(self._on_abort) + self.queue_table.rowsMoved.connect(self._on_queue_rows_moved) + + def apply_initial_config(self): + is_visible = main_config.get('preview_visible', True) + self.widgets['preview_group'].setChecked(is_visible) + self.widgets['preview_image'].setVisible(is_visible) + + def _on_preview_toggled(self, checked): + self.widgets['preview_image'].setVisible(checked) + main_config['preview_visible'] = checked + save_main_config() + + def 
_on_family_changed(self, index): + family = self.widgets['model_family'].currentData() + if not family or not self.state: return + + base_type_mock, choice_mock = wgp.change_model_family(self.state, family) + + self.widgets['model_base_type_choice'].blockSignals(True) + self.widgets['model_base_type_choice'].clear() + if base_type_mock.choices: + for label, value in base_type_mock.choices: + self.widgets['model_base_type_choice'].addItem(label, value) + index = self.widgets['model_base_type_choice'].findData(base_type_mock.value) + if index != -1: + self.widgets['model_base_type_choice'].setCurrentIndex(index) + self.widgets['model_base_type_choice'].setVisible(base_type_mock.kwargs.get('visible', True)) + self.widgets['model_base_type_choice'].blockSignals(False) + + self.widgets['model_choice'].blockSignals(True) + self.widgets['model_choice'].clear() + if choice_mock.choices: + for label, value in choice_mock.choices: + self.widgets['model_choice'].addItem(label, value) + index = self.widgets['model_choice'].findData(choice_mock.value) + if index != -1: + self.widgets['model_choice'].setCurrentIndex(index) + self.widgets['model_choice'].setVisible(choice_mock.kwargs.get('visible', True)) + self.widgets['model_choice'].blockSignals(False) + + self._on_model_changed() + + def _on_base_type_changed(self, index): + family = self.widgets['model_family'].currentData() + base_type = self.widgets['model_base_type_choice'].currentData() + if not family or not base_type or not self.state: return + + base_type_mock, choice_mock = wgp.change_model_base_types(self.state, family, base_type) + + self.widgets['model_choice'].blockSignals(True) + self.widgets['model_choice'].clear() + if choice_mock.choices: + for label, value in choice_mock.choices: + self.widgets['model_choice'].addItem(label, value) + index = self.widgets['model_choice'].findData(choice_mock.value) + if index != -1: + self.widgets['model_choice'].setCurrentIndex(index) + self.widgets['model_choice'].setVisible(choice_mock.kwargs.get('visible', True)) + self.widgets['model_choice'].blockSignals(False) + + self._on_model_changed() + + def _on_model_changed(self): + model_type = self.widgets['model_choice'].currentData() + if not model_type or model_type == self.state['model_type']: return + wgp.change_model(self.state, model_type) + self.refresh_ui_from_model_change(model_type) + + def _on_resolution_group_changed(self): + selected_group = self.widgets['resolution_group'].currentText() + if not selected_group or not hasattr(self, 'full_resolution_choices'): + return + + model_type = self.state['model_type'] + model_def = wgp.get_model_def(model_type) + model_resolutions = model_def.get("resolutions", None) + + group_resolution_choices = [] + if model_resolutions is None: + group_resolution_choices = [res for res in self.full_resolution_choices if wgp.categorize_resolution(res[1]) == selected_group] + else: + return + + last_resolution_per_group = self.state.get("last_resolution_per_group", {}) + last_resolution = last_resolution_per_group.get(selected_group, "") + + is_last_res_valid = any(last_resolution == res[1] for res in group_resolution_choices) + if not is_last_res_valid and group_resolution_choices: + last_resolution = group_resolution_choices[0][1] + + self.widgets['resolution'].blockSignals(True) + self.widgets['resolution'].clear() + for label, value in group_resolution_choices: + self.widgets['resolution'].addItem(label, value) + + index = self.widgets['resolution'].findData(last_resolution) + if index != -1: + 
self.widgets['resolution'].setCurrentIndex(index) + self.widgets['resolution'].blockSignals(False) + + def collect_inputs(self): + """Gather all settings from UI widgets into a dictionary.""" + # Start with all possible defaults. This dictionary will be modified and returned. + full_inputs = wgp.get_current_model_settings(self.state).copy() + + # Add dummy/default values for UI elements present in Gradio but not yet in PyQt. + # These are expected by the backend logic. + full_inputs['lset_name'] = "" + # The PyQt UI is focused on generating videos, so image_mode is 0. + full_inputs['image_mode'] = 0 + + # Defensively initialize keys that are accessed directly in wgp.py but may not + # be in the saved model settings or fully implemented in the UI yet. + # This prevents both KeyErrors and TypeErrors for missing arguments. + expected_keys = { + "audio_guide": None, "audio_guide2": None, "image_guide": None, + "image_mask": None, "speakers_locations": "", "frames_positions": "", + "keep_frames_video_guide": "", "keep_frames_video_source": "", + "video_guide_outpainting": "", "switch_threshold2": 0, + "model_switch_phase": 1, "batch_size": 1, + "control_net_weight_alt": 1.0, + "image_refs_relative_size": 50, + } + for key, default_value in expected_keys.items(): + if key not in full_inputs: + full_inputs[key] = default_value + + # Overwrite defaults with values from the PyQt UI widgets + full_inputs['prompt'] = self.widgets['prompt'].toPlainText() + full_inputs['negative_prompt'] = self.widgets['negative_prompt'].toPlainText() + full_inputs['resolution'] = self.widgets['resolution'].currentData() + full_inputs['video_length'] = self.widgets['video_length'].value() + full_inputs['num_inference_steps'] = self.widgets['num_inference_steps'].value() + full_inputs['seed'] = int(self.widgets['seed'].text()) + + # Build prompt_type strings based on mode selections + image_prompt_type = "" + video_prompt_type = "" + + if self.widgets['mode_s'].isChecked(): + image_prompt_type = 'S' + elif self.widgets['mode_v'].isChecked(): + image_prompt_type = 'V' + elif self.widgets['mode_l'].isChecked(): + image_prompt_type = 'L' + else: # mode_t is checked + image_prompt_type = '' + + if self.widgets['image_end_checkbox'].isVisible() and self.widgets['image_end_checkbox'].isChecked(): + image_prompt_type += 'E' + + if self.widgets['control_video_checkbox'].isVisible() and self.widgets['control_video_checkbox'].isChecked(): + video_prompt_type += 'V' # This 'V' is for Control Video (V2V) + if self.widgets['ref_image_checkbox'].isVisible() and self.widgets['ref_image_checkbox'].isChecked(): + video_prompt_type += 'I' # 'I' for Reference Image + + full_inputs['image_prompt_type'] = image_prompt_type + full_inputs['video_prompt_type'] = video_prompt_type + + # File Inputs + for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'audio_source']: + if name in self.widgets: + path = self.widgets[name].text() + full_inputs[name] = path if path else None + + paths = self.widgets['image_refs'].text().split(';') + full_inputs['image_refs'] = [p.strip() for p in paths if p.strip()] if paths and paths[0] else None + + full_inputs['denoising_strength'] = self.widgets['denoising_strength'].value() / 100.0 + + if self.advanced_group.isChecked(): + full_inputs['guidance_scale'] = self.widgets['guidance_scale'].value() / 10.0 + full_inputs['guidance_phases'] = self.widgets['guidance_phases'].currentData() + full_inputs['guidance2_scale'] = self.widgets['guidance2_scale'].value() / 10.0 + 
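            # The integer sliders store scaled values: the guidance, NAG, flow-shift and
            # audio-guidance sliders are divided by 10.0, while 0-1 parameters (denoising
            # strength, film grain, color correction) are divided by 100.0 — e.g. a guidance
            # slider reading of 75 maps back to 7.5.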
full_inputs['guidance3_scale'] = self.widgets['guidance3_scale'].value() / 10.0 + full_inputs['switch_threshold'] = self.widgets['switch_threshold'].value() + + full_inputs['NAG_scale'] = self.widgets['NAG_scale'].value() / 10.0 + full_inputs['NAG_tau'] = self.widgets['NAG_tau'].value() / 10.0 + full_inputs['NAG_alpha'] = self.widgets['NAG_alpha'].value() / 10.0 + + full_inputs['sample_solver'] = self.widgets['sample_solver'].currentData() + full_inputs['flow_shift'] = self.widgets['flow_shift'].value() / 10.0 + full_inputs['audio_guidance_scale'] = self.widgets['audio_guidance_scale'].value() / 10.0 + full_inputs['repeat_generation'] = self.widgets['repeat_generation'].value() + full_inputs['multi_images_gen_type'] = self.widgets['multi_images_gen_type'].currentData() + + lora_list_widget = self.widgets['activated_loras'] + selected_items = lora_list_widget.selectedItems() + full_inputs['activated_loras'] = [self.lora_map[item.text()] for item in selected_items if item.text() in self.lora_map] + full_inputs['loras_multipliers'] = self.widgets['loras_multipliers'].toPlainText() + + full_inputs['skip_steps_cache_type'] = self.widgets['skip_steps_cache_type'].currentData() + full_inputs['skip_steps_multiplier'] = self.widgets['skip_steps_multiplier'].currentData() + full_inputs['skip_steps_start_step_perc'] = self.widgets['skip_steps_start_step_perc'].value() + + full_inputs['temporal_upsampling'] = self.widgets['temporal_upsampling'].currentData() + full_inputs['spatial_upsampling'] = self.widgets['spatial_upsampling'].currentData() + full_inputs['film_grain_intensity'] = self.widgets['film_grain_intensity'].value() / 100.0 + full_inputs['film_grain_saturation'] = self.widgets['film_grain_saturation'].value() / 100.0 + + full_inputs['MMAudio_setting'] = self.widgets['MMAudio_setting'].currentData() + full_inputs['MMAudio_prompt'] = self.widgets['MMAudio_prompt'].text() + full_inputs['MMAudio_neg_prompt'] = self.widgets['MMAudio_neg_prompt'].text() + + full_inputs['RIFLEx_setting'] = self.widgets['RIFLEx_setting'].currentData() + full_inputs['force_fps'] = self.widgets['force_fps'].currentData() + full_inputs['override_profile'] = self.widgets['override_profile'].currentData() + full_inputs['multi_prompts_gen_type'] = self.widgets['multi_prompts_gen_type'].currentData() + + full_inputs['slg_switch'] = self.widgets['slg_switch'].currentData() + full_inputs['slg_start_perc'] = self.widgets['slg_start_perc'].value() + full_inputs['slg_end_perc'] = self.widgets['slg_end_perc'].value() + full_inputs['apg_switch'] = self.widgets['apg_switch'].currentData() + full_inputs['cfg_star_switch'] = self.widgets['cfg_star_switch'].currentData() + full_inputs['cfg_zero_step'] = self.widgets['cfg_zero_step'].value() + full_inputs['min_frames_if_references'] = self.widgets['min_frames_if_references'].currentData() + + full_inputs['sliding_window_size'] = self.widgets['sliding_window_size'].value() + full_inputs['sliding_window_overlap'] = self.widgets['sliding_window_overlap'].value() + full_inputs['sliding_window_color_correction_strength'] = self.widgets['sliding_window_color_correction_strength'].value() / 100.0 + full_inputs['sliding_window_overlap_noise'] = self.widgets['sliding_window_overlap_noise'].value() + full_inputs['sliding_window_discard_last_frames'] = self.widgets['sliding_window_discard_last_frames'].value() + + return full_inputs + + def _prepare_state_for_generation(self): + if 'gen' in self.state: + self.state['gen'].pop('abort', None) + self.state['gen'].pop('in_progress', None) + + def 
_on_generate(self): + try: + is_running = self.thread and self.thread.isRunning() + self._add_task_to_queue_and_update_ui() + if not is_running: + self.start_generation() + except Exception: + import traceback + traceback.print_exc() + + + def _on_add_to_queue(self): + try: + self._add_task_to_queue_and_update_ui() + except Exception: + import traceback + traceback.print_exc() + + def _add_task_to_queue_and_update_ui(self): + self._add_task_to_queue() + self.update_queue_table() + + def _add_task_to_queue(self): + queue_size_before = len(self.state["gen"]["queue"]) + all_inputs = self.collect_inputs() + keys_to_remove = ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality'] + for key in keys_to_remove: + all_inputs.pop(key, None) + + all_inputs['state'] = self.state + wgp.set_model_settings(self.state, self.state['model_type'], all_inputs) + + self.state["validate_success"] = 1 + wgp.process_prompt_and_add_tasks(self.state, self.state['model_type']) + + + def start_generation(self): + if not self.state['gen']['queue']: + return + self._prepare_state_for_generation() + self.generate_btn.setEnabled(False) + self.add_to_queue_btn.setEnabled(True) + + self.thread = QThread() + self.worker = Worker(self.state) + self.worker.moveToThread(self.thread) + + self.thread.started.connect(self.worker.run) + self.worker.finished.connect(self.thread.quit) + self.worker.finished.connect(self.worker.deleteLater) + self.thread.finished.connect(self.thread.deleteLater) + self.thread.finished.connect(self.on_generation_finished) + + self.worker.status.connect(self.status_label.setText) + self.worker.progress.connect(self.update_progress) + self.worker.preview.connect(self.update_preview) + self.worker.output.connect(self.update_queue_and_gallery) + self.worker.error.connect(self.on_generation_error) + + self.thread.start() + self.update_queue_table() + + def on_generation_finished(self): + time.sleep(0.1) + self.status_label.setText("Finished.") + self.progress_bar.setValue(0) + self.generate_btn.setEnabled(True) + self.add_to_queue_btn.setEnabled(False) + self.thread = None + self.worker = None + self.update_queue_table() + + def on_generation_error(self, err_msg): + msg_box = QMessageBox() + msg_box.setIcon(QMessageBox.Icon.Critical) + msg_box.setText("Generation Error") + msg_box.setInformativeText(str(err_msg)) + msg_box.setWindowTitle("Error") + msg_box.exec() + self.on_generation_finished() + + def update_progress(self, data): + if len(data) > 1 and isinstance(data[0], tuple): + step, total = data[0] + self.progress_bar.setMaximum(total) + self.progress_bar.setValue(step) + self.status_label.setText(str(data[1])) + if step <= 1: + self.update_queue_table() + elif len(data) > 1: + self.status_label.setText(str(data[1])) + + def update_preview(self, pil_image): + if pil_image: + q_image = ImageQt(pil_image) + pixmap = QPixmap.fromImage(q_image) + self.preview_image.setPixmap(pixmap.scaled( + self.preview_image.size(), + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation + )) + + def update_queue_and_gallery(self): + self.update_queue_table() + + file_list = self.state.get('gen', {}).get('file_list', []) + self.output_gallery.clear() + self.output_gallery.addItems(file_list) + if file_list: + self.output_gallery.setCurrentRow(len(file_list) - 1) + self.latest_output_path = file_list[-1] # <-- ADDED for API access + + def update_queue_table(self): + with wgp.lock: + queue = self.state.get('gen', {}).get('queue', []) + is_running = self.thread and 
self.thread.isRunning() + queue_to_display = queue if is_running else [None] + queue + + table_data = wgp.get_queue_table(queue_to_display) + + self.queue_table.setRowCount(0) + self.queue_table.setRowCount(len(table_data)) + self.queue_table.setColumnCount(4) + self.queue_table.setHorizontalHeaderLabels(["Qty", "Prompt", "Length", "Steps"]) + + for row_idx, row_data in enumerate(table_data): + prompt_html = row_data[1] + try: + prompt_text = prompt_html.split('>')[1].split('<')[0] + except IndexError: + prompt_text = str(row_data[1]) + + self.queue_table.setItem(row_idx, 0, QTableWidgetItem(str(row_data[0]))) + self.queue_table.setItem(row_idx, 1, QTableWidgetItem(prompt_text)) + self.queue_table.setItem(row_idx, 2, QTableWidgetItem(str(row_data[2]))) + self.queue_table.setItem(row_idx, 3, QTableWidgetItem(str(row_data[3]))) + + self.queue_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + self.queue_table.resizeColumnsToContents() + + def _on_remove_selected_from_queue(self): + selected_row = self.queue_table.currentRow() + if selected_row < 0: + return + + with wgp.lock: + is_running = self.thread and self.thread.isRunning() + offset = 1 if is_running else 0 + + queue = self.state.get('gen', {}).get('queue', []) + if len(queue) > selected_row + offset: + queue.pop(selected_row + offset) + self.update_queue_table() + + def _on_queue_rows_moved(self, source_row, dest_row): + with wgp.lock: + queue = self.state.get('gen', {}).get('queue', []) + is_running = self.thread and self.thread.isRunning() + offset = 1 if is_running else 0 + + real_source_idx = source_row + offset + real_dest_idx = dest_row + offset + + moved_item = queue.pop(real_source_idx) + queue.insert(real_dest_idx, moved_item) + self.update_queue_table() + + def _on_clear_queue(self): + wgp.clear_queue_action(self.state) + self.update_queue_table() + + def _on_abort(self): + if self.worker: + wgp.abort_generation(self.state) + self.status_label.setText("Aborting...") + self.worker._is_running = False + + def _on_release_ram(self): + wgp.release_RAM() + QMessageBox.information(self, "RAM Released", "Models stored in RAM have been released.") + + def _on_apply_config_changes(self): + changes = {} + + list_widget = self.widgets['config_transformer_types'] + checked_items = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] + changes['transformer_types_choices'] = checked_items + + list_widget = self.widgets['config_preload_model_policy'] + checked_items = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] + changes['preload_model_policy_choice'] = checked_items + + changes['model_hierarchy_type_choice'] = self.widgets['config_model_hierarchy_type'].currentData() + changes['checkpoints_paths'] = self.widgets['config_checkpoints_paths'].toPlainText() + + for key in ["fit_canvas", "attention_mode", "metadata_type", "clear_file_list", "display_stats", "max_frames_multiplier", "UI_theme"]: + changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() + + for key in ["transformer_quantization", "transformer_dtype_policy", "mixed_precision", "text_encoder_quantization", "vae_precision", "compile", "depth_anything_v2_variant", "vae_config", "boost", "profile"]: + changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() + changes['preload_in_VRAM_choice'] = self.widgets['config_preload_in_VRAM'].value() + + for key 
in ["enhancer_enabled", "enhancer_mode", "mmaudio_enabled"]: + changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() + + for key in ["video_output_codec", "image_output_codec", "save_path", "image_save_path"]: + widget = self.widgets[f'config_{key}'] + changes[f'{key}_choice'] = widget.currentData() if isinstance(widget, QComboBox) else widget.text() + + changes['notification_sound_enabled_choice'] = self.widgets['config_notification_sound_enabled'].currentData() + changes['notification_sound_volume_choice'] = self.widgets['config_notification_sound_volume'].value() + + changes['last_resolution_choice'] = self.widgets['resolution'].currentData() + + try: + msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = wgp.apply_changes(self.state, **changes) + self.config_status_label.setText("Changes applied successfully. Some settings may require a restart.") + + self.header_info.setText(wgp.generate_header(self.state['model_type'], wgp.compile, wgp.attention_mode)) + + if family_mock.choices is not None or choice_mock.choices is not None: + self.update_model_dropdowns(wgp.transformer_type) + self.refresh_ui_from_model_change(wgp.transformer_type) + + except Exception as e: + self.config_status_label.setText(f"Error applying changes: {e}") + import traceback + traceback.print_exc() + + def closeEvent(self, event): + if wgp: + wgp.autosave_queue() + save_main_config() + event.accept() + +# ===================================================================== +# --- START OF API SERVER ADDITION --- +# ===================================================================== +try: + from flask import Flask, request, jsonify + FLASK_AVAILABLE = True +except ImportError: + FLASK_AVAILABLE = False + print("Flask not installed. API server will not be available. Please run: pip install Flask") + +# Global reference to the main window, to be populated after instantiation. 
+main_window_instance = None +api_server = Flask(__name__) if FLASK_AVAILABLE else None + +def run_api_server(): + """Function to run the Flask server.""" + if api_server: + print("Starting API server on http://127.0.0.1:5100") + api_server.run(port=5100, host='127.0.0.1', debug=False) + +if FLASK_AVAILABLE: + # --- CHANGE: Removed /api/set_model endpoint --- + + @api_server.route('/api/generate', methods=['POST']) + def generate(): + if not main_window_instance: + return jsonify({"error": "Application not ready"}), 503 + + data = request.json + start_frame = data.get('start_frame') + end_frame = data.get('end_frame') + # --- CHANGE: Get duration_sec and model_type --- + duration_sec = data.get('duration_sec') + model_type = data.get('model_type') + start_generation = data.get('start_generation', False) + + # --- CHANGE: Emit the signal with the new parameters --- + main_window_instance.api_bridge.generateSignal.emit( + start_frame, end_frame, duration_sec, model_type, start_generation + ) + + if start_generation: + return jsonify({"message": "Parameters set and generation request sent."}) + else: + return jsonify({"message": "Parameters set without starting generation."}) + + + @api_server.route('/api/latest_output', methods=['GET']) + def get_latest_output(): + if not main_window_instance: + return jsonify({"error": "Application not ready"}), 503 + + path = main_window_instance.latest_output_path + return jsonify({"latest_output_path": path}) + +# ===================================================================== +# --- END OF API SERVER ADDITION --- +# ===================================================================== + +if __name__ == '__main__': + app = QApplication(sys.argv) + + # Create and show the main window + window = MainWindow() + main_window_instance = window # Assign to global for API access + window.show() + + # Start the Flask API server in a separate thread + if FLASK_AVAILABLE: + api_thread = threading.Thread(target=run_api_server, daemon=True) + api_thread.start() + + sys.exit(app.exec()) \ No newline at end of file From 93991812289a3e617b7a0d496557e8d9c21b8c99 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 10 Oct 2025 04:42:13 +1100 Subject: [PATCH 088/155] Update README.md --- README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index fb88c7acd..7a5a7f788 100644 --- a/README.md +++ b/README.md @@ -36,8 +36,16 @@ cd videoeditor python main.py ``` +## Screenshots +image +image +image +image -# WanGP + + + +# What follows is the standard WanGP Readme -----
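For reference, a minimal sketch of how an external tool might drive the Flask endpoints defined in the PyQt front-end above. The port (5100) and the JSON fields (`start_frame`, `end_frame`, `duration_sec`, `model_type`, `start_generation`) are taken from that code; the file paths and model id below are hypothetical placeholders:

```python
import requests

BASE = "http://127.0.0.1:5100"

payload = {
    "start_frame": "C:/frames/start.png",   # hypothetical path
    "end_frame": "C:/frames/end.png",       # hypothetical path
    "duration_sec": 5,
    "model_type": "i2v",                    # assumption: a model id known to the UI
    "start_generation": True,
}

# Queue the task; with start_generation=True the UI also begins processing.
print(requests.post(f"{BASE}/api/generate", json=payload, timeout=10).json())

# Later, ask the UI for the most recently written output file.
print(requests.get(f"{BASE}/api/latest_output", timeout=10).json()["latest_output_path"])
```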

From 4033106e458836f6215882b83fcbb8da3c1bb9a8 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 10 Oct 2025 04:46:18 +1100 Subject: [PATCH 089/155] Update README.md --- README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 7a5a7f788..6464a18f0 100644 --- a/README.md +++ b/README.md @@ -37,10 +37,11 @@ python main.py ``` ## Screenshots -image -image -image -image +image +image +image +image + From d239ceb2272517779628aa8eacc075a7e52accbc Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Thu, 9 Oct 2025 22:01:00 +0200 Subject: [PATCH 090/155] fixed extract settings crash --- shared/utils/loras_mutipliers.py | 2 +- wgp.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/shared/utils/loras_mutipliers.py b/shared/utils/loras_mutipliers.py index 5e024a976..dda805627 100644 --- a/shared/utils/loras_mutipliers.py +++ b/shared/utils/loras_mutipliers.py @@ -9,7 +9,7 @@ def preparse_loras_multipliers(loras_multipliers): loras_mult_choices_list = loras_multipliers.replace("\r", "").split("\n") loras_mult_choices_list = [multi.strip() for multi in loras_mult_choices_list if len(multi)>0 and not multi.startswith("#")] loras_multipliers = " ".join(loras_mult_choices_list) - return loras_multipliers.replace("|"," ").split(" ") + return loras_multipliers.replace("|"," ").strip().split(" ") def expand_slist(slists_dict, mult_no, num_inference_steps, model_switch_step, model_switch_step2 ): def expand_one(slist, num_inference_steps): diff --git a/wgp.py b/wgp.py index c0b498260..3bc3e54d2 100644 --- a/wgp.py +++ b/wgp.py @@ -6636,13 +6636,13 @@ def use_video_settings(state, input_file_list, choice): else: gr.Info(f"Settings Loaded from Video with prompt '{prompt[:100]}'") if models_compatible: - return gr.update(), gr.update(), str(time.time()) + return gr.update(), gr.update(), gr.update(), str(time.time()) else: return *generate_dropdown_model_list(model_type), gr.update() else: gr.Info(f"No Video is Selected") - return gr.update(), gr.update() + return gr.update(), gr.update(), gr.update(), gr.update() loras_url_cache = None def update_loras_url_cache(lora_dir, loras_selected): if loras_selected is None: return None From e3ad47e7a39674ae5f77c2be88cbd609d3949767 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 10 Oct 2025 07:59:47 +1100 Subject: [PATCH 091/155] add undo/redo, delete clips, project media window, recent projects. 
fixed ai plugin stalling when no server present --- videoeditor/main.py | 842 +++++++++++++++----- videoeditor/plugins/ai_frame_joiner/main.py | 5 +- videoeditor/undo.py | 114 +++ 3 files changed, 779 insertions(+), 182 deletions(-) create mode 100644 videoeditor/undo.py diff --git a/videoeditor/main.py b/videoeditor/main.py index 717628c42..a3f155e09 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -5,45 +5,18 @@ import re import json import ffmpeg +import copy from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QScrollArea, QFrame, QProgressBar, QDialog, - QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget) + QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget, QListWidget, QListWidgetItem, QMessageBox) from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction, - QPixmap, QImage) + QPixmap, QImage, QDrag) from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread, - pyqtSignal, QTimer, QByteArray) + pyqtSignal, QTimer, QByteArray, QMimeData) from plugins import PluginManager, ManagePluginsDialog -import requests -import zipfile -from tqdm import tqdm - -def download_ffmpeg(): - if os.name != 'nt': return - exes = ['ffmpeg.exe', 'ffprobe.exe', 'ffplay.exe'] - if all(os.path.exists(e) for e in exes): return - api_url = 'https://api.github.com/repos/GyanD/codexffmpeg/releases/latest' - r = requests.get(api_url, headers={'Accept': 'application/vnd.github+json'}) - assets = r.json().get('assets', []) - zip_asset = next((a for a in assets if 'essentials_build.zip' in a['name']), None) - if not zip_asset: return - zip_url = zip_asset['browser_download_url'] - zip_name = zip_asset['name'] - with requests.get(zip_url, stream=True) as resp: - total = int(resp.headers.get('Content-Length', 0)) - with open(zip_name, 'wb') as f, tqdm(total=total, unit='B', unit_scale=True) as pbar: - for chunk in resp.iter_content(chunk_size=8192): - f.write(chunk) - pbar.update(len(chunk)) - with zipfile.ZipFile(zip_name) as z: - for f in z.namelist(): - if f.endswith(tuple(exes)) and '/bin/' in f: - z.extract(f) - os.rename(f, os.path.basename(f)) - os.remove(zip_name) - -download_ffmpeg() +from undo import UndoStack, TimelineStateChangeCommand, MoveClipsCommand class TimelineClip: def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec, track_index, track_type, group_id): @@ -111,15 +84,19 @@ class TimelineWidget(QWidget): AUDIO_TRACKS_SEPARATOR_Y = 15 split_requested = pyqtSignal(object) + delete_clip_requested = pyqtSignal(object) playhead_moved = pyqtSignal(float) split_region_requested = pyqtSignal(list) split_all_regions_requested = pyqtSignal(list) join_region_requested = pyqtSignal(list) join_all_regions_requested = pyqtSignal(list) + delete_region_requested = pyqtSignal(list) + delete_all_regions_requested = pyqtSignal(list) add_track = pyqtSignal(str) remove_track = pyqtSignal(str) operation_finished = pyqtSignal() - context_menu_requested = pyqtSignal(QMenu, 'QContextMenuEvent') + context_menu_requested = pyqtSignal(QMenu, 'QContextMenuEvent') + clips_moved = pyqtSignal(list) def __init__(self, timeline_model, settings, parent=None): super().__init__(parent) @@ -128,6 +105,7 @@ def __init__(self, timeline_model, settings, parent=None): self.playhead_pos_sec = 0.0 self.setMinimumHeight(300) self.setMouseTracking(True) + self.setAcceptDrops(True) self.selection_regions = [] self.dragging_clip = None 
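        # Drag-and-drop contract (see MediaListWidget.startDrag and dropEvent below): the project
        # media list packs {"path", "duration", "has_audio"} as JSON under the custom MIME type
        # "application/x-vnd.video.filepath"; dragMoveEvent previews the drop target and
        # dropEvent places the clip via MainWindow._add_clip_to_timeline.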
self.dragging_linked_clip = None @@ -135,7 +113,7 @@ def __init__(self, timeline_model, settings, parent=None): self.creating_selection_region = False self.dragging_selection_region = None self.drag_start_pos = QPoint() - self.drag_original_timeline_start = 0 + self.drag_original_clip_states = {} # Store {'clip_id': (start_sec, track_index)} self.selection_drag_start_sec = 0.0 self.drag_selection_start_values = None @@ -147,6 +125,11 @@ def __init__(self, timeline_model, settings, parent=None): self.video_tracks_y_start = 0 self.audio_tracks_y_start = 0 + + self.drag_over_active = False + self.drag_over_rect = QRectF() + self.drag_over_audio_rect = QRectF() + self.drag_has_audio = False def sec_to_x(self, sec): return self.HEADER_WIDTH + int(sec * self.PIXELS_PER_SECOND) def x_to_sec(self, x): return float(x - self.HEADER_WIDTH) / self.PIXELS_PER_SECOND if x > self.HEADER_WIDTH else 0.0 @@ -160,6 +143,16 @@ def paintEvent(self, event): self.draw_timescale(painter) self.draw_tracks_and_clips(painter) self.draw_selections(painter) + + if self.drag_over_active: + painter.fillRect(self.drag_over_rect, QColor(0, 255, 0, 80)) + if self.drag_has_audio: + painter.fillRect(self.drag_over_audio_rect, QColor(0, 255, 0, 80)) + painter.setPen(QColor(0, 255, 0, 150)) + painter.drawRect(self.drag_over_rect) + if self.drag_has_audio: + painter.drawRect(self.drag_over_audio_rect) + self.draw_playhead(painter) total_width = self.sec_to_x(self.timeline.get_total_duration()) + 100 @@ -368,14 +361,20 @@ def mousePressEvent(self, event: QMouseEvent): self.dragging_playhead = False self.dragging_selection_region = None self.creating_selection_region = False + self.drag_original_clip_states.clear() for clip in reversed(self.timeline.clips): clip_rect = self.get_clip_rect(clip) if clip_rect.contains(QPointF(event.pos())): self.dragging_clip = clip + self.drag_original_clip_states[clip.id] = (clip.timeline_start_sec, clip.track_index) + self.dragging_linked_clip = next((c for c in self.timeline.clips if c.group_id == clip.group_id and c.id != clip.id), None) + if self.dragging_linked_clip: + self.drag_original_clip_states[self.dragging_linked_clip.id] = \ + (self.dragging_linked_clip.timeline_start_sec, self.dragging_linked_clip.track_index) + self.drag_start_pos = event.pos() - self.drag_original_timeline_start = clip.timeline_start_sec break if not self.dragging_clip: @@ -429,25 +428,23 @@ def mouseMoveEvent(self, event: QMouseEvent): elif self.dragging_clip: self.highlighted_track_info = None new_track_info = self.y_to_track_info(event.pos().y()) + + original_start_sec, _ = self.drag_original_clip_states[self.dragging_clip.id] + if new_track_info: new_track_type, new_track_index = new_track_info if new_track_type == self.dragging_clip.track_type: self.dragging_clip.track_index = new_track_index if self.dragging_linked_clip: + # For now, linked clips move to same-number track index self.dragging_linked_clip.track_index = new_track_index - if self.dragging_clip.track_type == 'video': - if new_track_index > self.timeline.num_audio_tracks: - self.timeline.num_audio_tracks = new_track_index - self.highlighted_track_info = ('audio', new_track_index) - else: - if new_track_index > self.timeline.num_video_tracks: - self.timeline.num_video_tracks = new_track_index - self.highlighted_track_info = ('video', new_track_index) + delta_x = event.pos().x() - self.drag_start_pos.x() time_delta = delta_x / self.PIXELS_PER_SECOND - new_start_time = self.drag_original_timeline_start + time_delta + new_start_time = 
original_start_sec + time_delta + # Basic collision detection (can be improved) for other_clip in self.timeline.clips: if other_clip.id == self.dragging_clip.id: continue if self.dragging_linked_clip and other_clip.id == self.dragging_linked_clip.id: continue @@ -485,14 +482,130 @@ def mouseReleaseEvent(self, event: QMouseEvent): self.dragging_playhead = False if self.dragging_clip: + move_data = [] + # Check main dragged clip + orig_start, orig_track = self.drag_original_clip_states[self.dragging_clip.id] + if orig_start != self.dragging_clip.timeline_start_sec or orig_track != self.dragging_clip.track_index: + move_data.append({ + 'clip_id': self.dragging_clip.id, + 'old_start': orig_start, 'new_start': self.dragging_clip.timeline_start_sec, + 'old_track': orig_track, 'new_track': self.dragging_clip.track_index + }) + # Check linked clip + if self.dragging_linked_clip: + orig_start_link, orig_track_link = self.drag_original_clip_states[self.dragging_linked_clip.id] + if orig_start_link != self.dragging_linked_clip.timeline_start_sec or orig_track_link != self.dragging_linked_clip.track_index: + move_data.append({ + 'clip_id': self.dragging_linked_clip.id, + 'old_start': orig_start_link, 'new_start': self.dragging_linked_clip.timeline_start_sec, + 'old_track': orig_track_link, 'new_track': self.dragging_linked_clip.track_index + }) + + if move_data: + self.clips_moved.emit(move_data) + self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) self.highlighted_track_info = None self.operation_finished.emit() + self.dragging_clip = None self.dragging_linked_clip = None + self.drag_original_clip_states.clear() self.update() + def dragEnterEvent(self, event): + if event.mimeData().hasFormat('application/x-vnd.video.filepath'): + event.acceptProposedAction() + else: + event.ignore() + + def dragLeaveEvent(self, event): + self.drag_over_active = False + self.update() + + def dragMoveEvent(self, event): + mime_data = event.mimeData() + if not mime_data.hasFormat('application/x-vnd.video.filepath'): + event.ignore() + return + + event.acceptProposedAction() + + json_data = json.loads(mime_data.data('application/x-vnd.video.filepath').data().decode('utf-8')) + duration = json_data['duration'] + self.drag_has_audio = json_data['has_audio'] + + pos = event.position() + track_info = self.y_to_track_info(pos.y()) + start_sec = self.x_to_sec(pos.x()) + + if track_info: + self.drag_over_active = True + track_type, track_index = track_info + + width = int(duration * self.PIXELS_PER_SECOND) + x = self.sec_to_x(start_sec) + + if track_type == 'video': + visual_index = self.timeline.num_video_tracks - track_index + y = self.video_tracks_y_start + visual_index * self.TRACK_HEIGHT + self.drag_over_rect = QRectF(x, y, width, self.TRACK_HEIGHT) + if self.drag_has_audio: + audio_y = self.audio_tracks_y_start # Default to track 1 + self.drag_over_audio_rect = QRectF(x, audio_y, width, self.TRACK_HEIGHT) + + elif track_type == 'audio': + visual_index = track_index - 1 + y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT + self.drag_over_audio_rect = QRectF(x, y, width, self.TRACK_HEIGHT) + # For now, just show video on track 1 if dragging onto audio + video_y = self.video_tracks_y_start + (self.timeline.num_video_tracks - 1) * self.TRACK_HEIGHT + self.drag_over_rect = QRectF(x, video_y, width, self.TRACK_HEIGHT) + else: + self.drag_over_active = False + + self.update() + + def dropEvent(self, event): + self.drag_over_active = False + self.update() + + mime_data = event.mimeData() + if not 
mime_data.hasFormat('application/x-vnd.video.filepath'): + return + + json_data = json.loads(mime_data.data('application/x-vnd.video.filepath').data().decode('utf-8')) + file_path = json_data['path'] + duration = json_data['duration'] + has_audio = json_data['has_audio'] + + pos = event.position() + start_sec = self.x_to_sec(pos.x()) + track_info = self.y_to_track_info(pos.y()) + + if not track_info: + return + + drop_track_type, drop_track_index = track_info + video_track_idx = 1 + audio_track_idx = 1 if has_audio else None + + if drop_track_type == 'video': + video_track_idx = drop_track_index + elif drop_track_type == 'audio': + audio_track_idx = drop_track_index + + main_window = self.window() + main_window._add_clip_to_timeline( + source_path=file_path, + timeline_start_sec=start_sec, + duration_sec=duration, + clip_start_sec=0, + video_track_index=video_track_idx, + audio_track_index=audio_track_idx + ) + def contextMenuEvent(self, event: 'QContextMenuEvent'): menu = QMenu(self) @@ -502,6 +615,8 @@ def contextMenuEvent(self, event: 'QContextMenuEvent'): split_all_action = menu.addAction("Split All Regions") join_this_action = menu.addAction("Join This Region") join_all_action = menu.addAction("Join All Regions") + delete_this_action = menu.addAction("Delete This Region") + delete_all_action = menu.addAction("Delete All Regions") menu.addSeparator() clear_this_action = menu.addAction("Clear This Region") clear_all_action = menu.addAction("Clear All Regions") @@ -509,6 +624,8 @@ def contextMenuEvent(self, event: 'QContextMenuEvent'): split_all_action.triggered.connect(lambda: self.split_all_regions_requested.emit(self.selection_regions)) join_this_action.triggered.connect(lambda: self.join_region_requested.emit(region_at_pos)) join_all_action.triggered.connect(lambda: self.join_all_regions_requested.emit(self.selection_regions)) + delete_this_action.triggered.connect(lambda: self.delete_region_requested.emit(region_at_pos)) + delete_all_action.triggered.connect(lambda: self.delete_all_regions_requested.emit(self.selection_regions)) clear_this_action.triggered.connect(lambda: self.clear_region(region_at_pos)) clear_all_action.triggered.connect(self.clear_all_regions) @@ -521,10 +638,12 @@ def contextMenuEvent(self, event: 'QContextMenuEvent'): if clip_at_pos: if not menu.isEmpty(): menu.addSeparator() split_action = menu.addAction("Split Clip") + delete_action = menu.addAction("Delete Clip") playhead_time = self.playhead_pos_sec is_playhead_over_clip = (clip_at_pos.timeline_start_sec < playhead_time < clip_at_pos.timeline_end_sec) split_action.setEnabled(is_playhead_over_clip) split_action.triggered.connect(lambda: self.split_requested.emit(clip_at_pos)) + delete_action.triggered.connect(lambda: self.delete_clip_requested.emit(clip_at_pos)) self.context_menu_requested.emit(menu, event) @@ -540,22 +659,111 @@ def clear_all_regions(self): self.selection_regions.clear() self.update() + class SettingsDialog(QDialog): def __init__(self, parent_settings, parent=None): super().__init__(parent) self.setWindowTitle("Settings") self.setMinimumWidth(350) layout = QVBoxLayout(self) - self.seek_anywhere_checkbox = QCheckBox("Allow seeking by clicking anywhere on the timeline") - self.seek_anywhere_checkbox.setChecked(parent_settings.get("allow_seek_anywhere", False)) - layout.addWidget(self.seek_anywhere_checkbox) + self.confirm_on_exit_checkbox = QCheckBox("Confirm before exiting") + self.confirm_on_exit_checkbox.setChecked(parent_settings.get("confirm_on_exit", True)) + 
layout.addWidget(self.confirm_on_exit_checkbox) button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) button_box.accepted.connect(self.accept) button_box.rejected.connect(self.reject) layout.addWidget(button_box) def get_settings(self): - return {"allow_seek_anywhere": self.seek_anywhere_checkbox.isChecked()} + return { + "confirm_on_exit": self.confirm_on_exit_checkbox.isChecked() + } + +class MediaListWidget(QListWidget): + def __init__(self, main_window, parent=None): + super().__init__(parent) + self.main_window = main_window + self.setDragEnabled(True) + self.setAcceptDrops(False) + self.setSelectionMode(QListWidget.SelectionMode.ExtendedSelection) + + def startDrag(self, supportedActions): + drag = QDrag(self) + mime_data = QMimeData() + + item = self.currentItem() + if not item: return + + path = item.data(Qt.ItemDataRole.UserRole) + media_info = self.main_window.media_properties.get(path) + if not media_info: return + + payload = { + "path": path, + "duration": media_info['duration'], + "has_audio": media_info['has_audio'] + } + + mime_data.setData('application/x-vnd.video.filepath', QByteArray(json.dumps(payload).encode('utf-8'))) + drag.setMimeData(mime_data) + drag.exec(Qt.DropAction.CopyAction) + +class ProjectMediaWidget(QWidget): + media_removed = pyqtSignal(str) + add_media_requested = pyqtSignal() + add_to_timeline_requested = pyqtSignal(str) + + def __init__(self, parent=None): + super().__init__(parent) + self.main_window = parent + layout = QVBoxLayout(self) + layout.setContentsMargins(2, 2, 2, 2) + + self.media_list = MediaListWidget(self.main_window, self) + self.media_list.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + self.media_list.customContextMenuRequested.connect(self.show_context_menu) + layout.addWidget(self.media_list) + + button_layout = QHBoxLayout() + add_button = QPushButton("Add") + remove_button = QPushButton("Remove") + button_layout.addWidget(add_button) + button_layout.addWidget(remove_button) + layout.addLayout(button_layout) + + add_button.clicked.connect(self.add_media_requested.emit) + remove_button.clicked.connect(self.remove_selected_media) + def show_context_menu(self, pos): + item = self.media_list.itemAt(pos) + if not item: + return + + menu = QMenu() + add_action = menu.addAction("Add to timeline at playhead") + + action = menu.exec(self.media_list.mapToGlobal(pos)) + + if action == add_action: + file_path = item.data(Qt.ItemDataRole.UserRole) + self.add_to_timeline_requested.emit(file_path) + + def add_media_item(self, file_path): + if not any(self.media_list.item(i).data(Qt.ItemDataRole.UserRole) == file_path for i in range(self.media_list.count())): + item = QListWidgetItem(os.path.basename(file_path)) + item.setData(Qt.ItemDataRole.UserRole, file_path) + self.media_list.addItem(item) + + def remove_selected_media(self): + selected_items = self.media_list.selectedItems() + if not selected_items: return + + for item in selected_items: + file_path = item.data(Qt.ItemDataRole.UserRole) + self.media_removed.emit(file_path) + self.media_list.takeItem(self.media_list.row(item)) + + def clear_list(self): + self.media_list.clear() class MainWindow(QMainWindow): def __init__(self, project_to_load=None): @@ -565,6 +773,9 @@ def __init__(self, project_to_load=None): self.setDockOptions(QMainWindow.DockOption.AnimatedDocks | QMainWindow.DockOption.AllowNestedDocks) self.timeline = Timeline() + self.undo_stack = UndoStack() + self.media_pool = [] + self.media_properties = {} # To 
store duration, has_audio etc. self.export_thread = None self.export_worker = None self.current_project_path = None @@ -583,6 +794,30 @@ def __init__(self, project_to_load=None): self.playback_process = None self.playback_clip = None + self._setup_ui() + self._connect_signals() + + self.plugin_manager.load_enabled_plugins_from_settings(self.settings.get("enabled_plugins", [])) + self._apply_loaded_settings() + self.seek_preview(0) + + if not self.settings_file_was_loaded: self._save_settings() + if project_to_load: QTimer.singleShot(100, lambda: self._load_project_from_path(project_to_load)) + + def _get_current_timeline_state(self): + """Helper to capture a snapshot of the timeline for undo/redo.""" + return ( + copy.deepcopy(self.timeline.clips), + self.timeline.num_video_tracks, + self.timeline.num_audio_tracks + ) + + def _setup_ui(self): + self.media_dock = QDockWidget("Project Media", self) + self.project_media_widget = ProjectMediaWidget(self) + self.media_dock.setWidget(self.project_media_widget) + self.addDockWidget(Qt.DockWidgetArea.LeftDockWidgetArea, self.media_dock) + self.splitter = QSplitter(Qt.Orientation.Vertical) self.preview_widget = QLabel() @@ -592,14 +827,14 @@ def __init__(self, project_to_load=None): self.preview_widget.setStyleSheet("background-color: black; color: white;") self.splitter.addWidget(self.preview_widget) - self.timeline_widget = TimelineWidget(self.timeline, self.settings) + self.timeline_widget = TimelineWidget(self.timeline, self.settings, self) self.timeline_scroll_area = QScrollArea() self.timeline_scroll_area.setWidgetResizable(True) self.timeline_scroll_area.setWidget(self.timeline_widget) self.timeline_scroll_area.setFrameShape(QFrame.Shape.NoFrame) self.timeline_scroll_area.setMinimumHeight(250) self.splitter.addWidget(self.timeline_scroll_area) - + container_widget = QWidget() main_layout = QVBoxLayout(container_widget) main_layout.setContentsMargins(0,0,0,0) @@ -634,7 +869,8 @@ def __init__(self, project_to_load=None): self.managed_widgets = { 'preview': {'widget': self.preview_widget, 'name': 'Video Preview', 'action': None}, - 'timeline': {'widget': self.timeline_scroll_area, 'name': 'Timeline', 'action': None} + 'timeline': {'widget': self.timeline_scroll_area, 'name': 'Timeline', 'action': None}, + 'project_media': {'widget': self.media_dock, 'name': 'Project Media', 'action': None} } self.plugin_menu_actions = {} self.windows_menu = None @@ -643,14 +879,20 @@ def __init__(self, project_to_load=None): self.splitter_save_timer = QTimer(self) self.splitter_save_timer.setSingleShot(True) self.splitter_save_timer.timeout.connect(self._save_settings) + + def _connect_signals(self): self.splitter.splitterMoved.connect(self.on_splitter_moved) self.timeline_widget.split_requested.connect(self.split_clip_at_playhead) + self.timeline_widget.delete_clip_requested.connect(self.delete_clip) self.timeline_widget.playhead_moved.connect(self.seek_preview) + self.timeline_widget.clips_moved.connect(self.on_clips_moved) self.timeline_widget.split_region_requested.connect(self.on_split_region) self.timeline_widget.split_all_regions_requested.connect(self.on_split_all_regions) self.timeline_widget.join_region_requested.connect(self.on_join_region) self.timeline_widget.join_all_regions_requested.connect(self.on_join_all_regions) + self.timeline_widget.delete_region_requested.connect(self.on_delete_region) + self.timeline_widget.delete_all_regions_requested.connect(self.on_delete_all_regions) self.timeline_widget.add_track.connect(self.add_track) 
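        # Track add/remove requests raised by the TimelineWidget are routed to MainWindow so the
        # change can be recorded on the undo stack (see add_track / remove_track below).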
self.timeline_widget.remove_track.connect(self.remove_track) self.timeline_widget.operation_finished.connect(self.prune_empty_tracks) @@ -661,16 +903,57 @@ def __init__(self, project_to_load=None): self.frame_forward_button.clicked.connect(lambda: self.step_frame(1)) self.playback_timer.timeout.connect(self.advance_playback_frame) - self.plugin_manager.load_enabled_plugins_from_settings(self.settings.get("enabled_plugins", [])) - self._apply_loaded_settings() - self.seek_preview(0) + self.project_media_widget.add_media_requested.connect(self.add_video_clip) + self.project_media_widget.media_removed.connect(self.on_media_removed_from_pool) + self.project_media_widget.add_to_timeline_requested.connect(self.on_add_to_timeline_at_playhead) - if not self.settings_file_was_loaded: self._save_settings() - if project_to_load: QTimer.singleShot(100, lambda: self._load_project_from_path(project_to_load)) + self.undo_stack.history_changed.connect(self.update_undo_redo_actions) + self.undo_stack.timeline_changed.connect(self.on_timeline_changed_by_undo) + + def on_timeline_changed_by_undo(self): + """Called when undo/redo modifies the timeline.""" + self.prune_empty_tracks() + self.timeline_widget.update() + # You may want to update other UI elements here as well + self.status_label.setText("Operation undone/redone.") + + def update_undo_redo_actions(self): + self.undo_action.setEnabled(self.undo_stack.can_undo()) + self.undo_action.setText(f"Undo {self.undo_stack.undo_text()}" if self.undo_stack.can_undo() else "Undo") + + self.redo_action.setEnabled(self.undo_stack.can_redo()) + self.redo_action.setText(f"Redo {self.undo_stack.redo_text()}" if self.undo_stack.can_redo() else "Redo") + + def on_clips_moved(self, move_data): + command = MoveClipsCommand("Move Clip", self.timeline, move_data) + # The change is already applied visually, so we push without redoing + self.undo_stack.undo_stack.append(command) + self.undo_stack.redo_stack.clear() + self.undo_stack.history_changed.emit() + self.status_label.setText("Clip moved.") + + def on_add_to_timeline_at_playhead(self, file_path): + media_info = self.media_properties.get(file_path) + if not media_info: + self.status_label.setText(f"Error: Could not find properties for {os.path.basename(file_path)}") + return + + playhead_pos = self.timeline_widget.playhead_pos_sec + duration = media_info['duration'] + has_audio = media_info['has_audio'] + + self._add_clip_to_timeline( + source_path=file_path, + timeline_start_sec=playhead_pos, + duration_sec=duration, + clip_start_sec=0.0, + video_track_index=1, + audio_track_index=1 if has_audio else None + ) def prune_empty_tracks(self): pruned_something = False - + # Prune Video Tracks while self.timeline.num_video_tracks > 1: highest_track_index = self.timeline.num_video_tracks is_track_occupied = any(c for c in self.timeline.clips @@ -680,7 +963,8 @@ def prune_empty_tracks(self): else: self.timeline.num_video_tracks -= 1 pruned_something = True - + + # Prune Audio Tracks while self.timeline.num_audio_tracks > 1: highest_track_index = self.timeline.num_audio_tracks is_track_occupied = any(c for c in self.timeline.clips @@ -690,42 +974,64 @@ def prune_empty_tracks(self): else: self.timeline.num_audio_tracks -= 1 pruned_something = True - + if pruned_something: self.timeline_widget.update() - def add_track(self, track_type): + old_state = self._get_current_timeline_state() if track_type == 'video': self.timeline.num_video_tracks += 1 elif track_type == 'audio': self.timeline.num_audio_tracks += 1 - 
self.timeline_widget.update() + new_state = self._get_current_timeline_state() + + command = TimelineStateChangeCommand(f"Add {track_type.capitalize()} Track", self.timeline, *old_state, *new_state) + self.undo_stack.push(command) + # Manually undo the direct change, because push() will redo() it. + command.undo() + self.undo_stack.push(command) def remove_track(self, track_type): + old_state = self._get_current_timeline_state() if track_type == 'video' and self.timeline.num_video_tracks > 1: self.timeline.num_video_tracks -= 1 elif track_type == 'audio' and self.timeline.num_audio_tracks > 1: self.timeline.num_audio_tracks -= 1 - self.timeline_widget.update() + else: + return # No change made + new_state = self._get_current_timeline_state() + + command = TimelineStateChangeCommand(f"Remove {track_type.capitalize()} Track", self.timeline, *old_state, *new_state) + command.undo() # Invert state because we already made the change + self.undo_stack.push(command) + def _create_menu_bar(self): menu_bar = self.menuBar() file_menu = menu_bar.addMenu("&File") new_action = QAction("&New Project", self); new_action.triggered.connect(self.new_project) open_action = QAction("&Open Project...", self); open_action.triggered.connect(self.open_project) + self.recent_menu = file_menu.addMenu("Recent") save_action = QAction("&Save Project As...", self); save_action.triggered.connect(self.save_project_as) - add_video_action = QAction("&Add Video...", self); add_video_action.triggered.connect(self.add_video_clip) + add_video_action = QAction("&Add Media to Project...", self); add_video_action.triggered.connect(self.add_video_clip) export_action = QAction("&Export Video...", self); export_action.triggered.connect(self.export_video) settings_action = QAction("Se&ttings...", self); settings_action.triggered.connect(self.open_settings_dialog) exit_action = QAction("E&xit", self); exit_action.triggered.connect(self.close) - file_menu.addAction(new_action); file_menu.addAction(open_action); file_menu.addAction(save_action) + file_menu.addAction(new_action); file_menu.addAction(open_action); file_menu.addSeparator() + file_menu.addAction(save_action) file_menu.addSeparator(); file_menu.addAction(add_video_action); file_menu.addAction(export_action) file_menu.addSeparator(); file_menu.addAction(settings_action); file_menu.addSeparator(); file_menu.addAction(exit_action) - + self._update_recent_files_menu() + edit_menu = menu_bar.addMenu("&Edit") + self.undo_action = QAction("Undo", self); self.undo_action.setShortcut("Ctrl+Z"); self.undo_action.triggered.connect(self.undo_stack.undo) + self.redo_action = QAction("Redo", self); self.redo_action.setShortcut("Ctrl+Y"); self.redo_action.triggered.connect(self.undo_stack.redo) + edit_menu.addAction(self.undo_action); edit_menu.addAction(self.redo_action); edit_menu.addSeparator() + split_action = QAction("Split Clip at Playhead", self); split_action.triggered.connect(self.split_clip_at_playhead) edit_menu.addAction(split_action) + self.update_undo_redo_actions() # Set initial state plugins_menu = menu_bar.addMenu("&Plugins") for name, data in self.plugin_manager.plugins.items(): @@ -742,8 +1048,14 @@ def _create_menu_bar(self): self.windows_menu = menu_bar.addMenu("&Windows") for key, data in self.managed_widgets.items(): + if data['widget'] is self.preview_widget: continue # Preview is part of splitter, not a dock action = QAction(data['name'], self, checkable=True) - action.toggled.connect(lambda checked, k=key: self.toggle_widget_visibility(k, checked)) + if 
hasattr(data['widget'], 'visibilityChanged'): + action.toggled.connect(data['widget'].setVisible) + data['widget'].visibilityChanged.connect(action.setChecked) + else: + action.toggled.connect(lambda checked, k=key: self.toggle_widget_visibility(k, checked)) + data['action'] = action self.windows_menu.addAction(action) @@ -783,6 +1095,10 @@ def _set_project_properties_from_clip(self, source_path): return False def get_frame_data_at_time(self, time_sec): + """ + Plugin API: Extracts raw frame data at a specific time. + Returns a tuple of (bytes, width, height) or (None, 0, 0) on failure. + """ clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None) if not clip_at_time: return (None, 0, 0) @@ -858,7 +1174,7 @@ def advance_playback_frame(self): def _load_settings(self): self.settings_file_was_loaded = False - defaults = {"allow_seek_anywhere": False, "window_visibility": {"preview": True, "timeline": True}, "splitter_state": None, "enabled_plugins": []} + defaults = {"window_visibility": {"project_media": True}, "splitter_state": None, "enabled_plugins": [], "recent_files": [], "confirm_on_exit": True} if os.path.exists(self.settings_file): try: with open(self.settings_file, "r") as f: self.settings = json.load(f) @@ -874,33 +1190,27 @@ def _save_settings(self): key: data['action'].isChecked() for key, data in self.managed_widgets.items() if data.get('action') } - self.settings["window_visibility"] = visibility_to_save self.settings['enabled_plugins'] = self.plugin_manager.get_enabled_plugin_names() try: - with open(self.settings_file, "w") as f: - json.dump(self.settings, f, indent=4) - except IOError as e: - print(f"Error saving settings: {e}") + with open(self.settings_file, "w") as f: json.dump(self.settings, f, indent=4) + except IOError as e: print(f"Error saving settings: {e}") def _apply_loaded_settings(self): visibility_settings = self.settings.get("window_visibility", {}) for key, data in self.managed_widgets.items(): - if data.get('plugin'): - continue - + if data.get('plugin'): continue is_visible = visibility_settings.get(key, True) - data['widget'].setVisible(is_visible) + if data['widget'] is not self.preview_widget: # preview handled by splitter state + data['widget'].setVisible(is_visible) if data['action']: data['action'].setChecked(is_visible) - splitter_state = self.settings.get("splitter_state") if splitter_state: self.splitter.restoreState(QByteArray.fromHex(splitter_state.encode('ascii'))) def on_splitter_moved(self, pos, index): self.splitter_save_timer.start(500) def toggle_widget_visibility(self, key, checked): - if self.is_shutting_down: - return + if self.is_shutting_down: return if key in self.managed_widgets: self.managed_widgets[key]['widget'].setVisible(checked) self._save_settings() @@ -912,19 +1222,25 @@ def open_settings_dialog(self): def new_project(self): self.timeline.clips.clear(); self.timeline.num_video_tracks = 1; self.timeline.num_audio_tracks = 1 + self.media_pool.clear(); self.media_properties.clear(); self.project_media_widget.clear_list() self.current_project_path = None; self.stop_playback(); self.timeline_widget.update() - self.status_label.setText("New project created. Add video clips to begin.") + self.undo_stack = UndoStack() + self.undo_stack.history_changed.connect(self.update_undo_redo_actions) + self.update_undo_redo_actions() + self.status_label.setText("New project created. 
Add media to begin.") def save_project_as(self): path, _ = QFileDialog.getSaveFileName(self, "Save Project", "", "JSON Project Files (*.json)") if not path: return project_data = { + "media_pool": self.media_pool, "clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "group_id": c.group_id} for c in self.timeline.clips], "settings": {"num_video_tracks": self.timeline.num_video_tracks, "num_audio_tracks": self.timeline.num_audio_tracks} } try: with open(path, "w") as f: json.dump(project_data, f, indent=4) self.current_project_path = path; self.status_label.setText(f"Project saved to {os.path.basename(path)}") + self._add_to_recent_files(path) except Exception as e: self.status_label.setText(f"Error saving project: {e}") def open_project(self): @@ -934,10 +1250,14 @@ def open_project(self): def _load_project_from_path(self, path): try: with open(path, "r") as f: project_data = json.load(f) - self.timeline.clips.clear() + self.new_project() # Clear existing state + + self.media_pool = project_data.get("media_pool", []) + for p in self.media_pool: self._add_media_to_pool(p) + for clip_data in project_data["clips"]: if not os.path.exists(clip_data["source_path"]): - self.status_label.setText(f"Error: Missing media file {clip_data['source_path']}"); self.timeline.clips.clear(); self.timeline_widget.update(); return + self.status_label.setText(f"Error: Missing media file {clip_data['source_path']}"); self.new_project(); return self.timeline.add_clip(TimelineClip(**clip_data)) project_settings = project_data.get("settings", {}) @@ -949,35 +1269,77 @@ def _load_project_from_path(self, path): self.prune_empty_tracks() self.timeline_widget.update(); self.stop_playback() self.status_label.setText(f"Project '{os.path.basename(path)}' loaded.") + self._add_to_recent_files(path) except Exception as e: self.status_label.setText(f"Error opening project: {e}") - def add_video_clip(self): - file_path, _ = QFileDialog.getOpenFileName(self, "Open Video File", "", "Video Files (*.mp4 *.mov *.avi)") - if not file_path: return - if not self.timeline.clips: - if not self._set_project_properties_from_clip(file_path): self.status_label.setText("Error: Could not determine video properties from file."); return + def _add_to_recent_files(self, path): + recent = self.settings.get("recent_files", []) + if path in recent: recent.remove(path) + recent.insert(0, path) + self.settings["recent_files"] = recent[:10] + self._update_recent_files_menu() + self._save_settings() + + def _update_recent_files_menu(self): + self.recent_menu.clear() + recent_files = self.settings.get("recent_files", []) + for path in recent_files: + if os.path.exists(path): + action = QAction(os.path.basename(path), self) + action.triggered.connect(lambda checked, p=path: self._load_project_from_path(p)) + self.recent_menu.addAction(action) + + def _add_media_to_pool(self, file_path): + if file_path in self.media_pool: return True try: self.status_label.setText(f"Probing {os.path.basename(file_path)}..."); QApplication.processEvents() probe = ffmpeg.probe(file_path) video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None) if not video_stream: raise ValueError("No video stream found.") + duration = float(video_stream.get('duration', probe['format'].get('duration', 0))) has_audio = any(s['codec_type'] == 'audio' for s in probe['streams']) - timeline_start = 
self.timeline.get_total_duration() - - self._add_clip_to_timeline( - source_path=file_path, - timeline_start_sec=timeline_start, - duration_sec=duration, - clip_start_sec=0, - video_track_index=1, - audio_track_index=1 if has_audio else None - ) + + self.media_properties[file_path] = {'duration': duration, 'has_audio': has_audio} + self.media_pool.append(file_path) + self.project_media_widget.add_media_item(file_path) + + self.status_label.setText(f"Added {os.path.basename(file_path)} to project.") + return True + except Exception as e: + self.status_label.setText(f"Error probing file: {e}") + return False + + def on_media_removed_from_pool(self, file_path): + old_state = self._get_current_timeline_state() + + if file_path in self.media_pool: self.media_pool.remove(file_path) + if file_path in self.media_properties: del self.media_properties[file_path] + + clips_to_remove = [c for c in self.timeline.clips if c.source_path == file_path] + for clip in clips_to_remove: self.timeline.clips.remove(clip) + + new_state = self._get_current_timeline_state() + command = TimelineStateChangeCommand("Remove Media From Project", self.timeline, *old_state, *new_state) + command.undo() + self.undo_stack.push(command) + - self.timeline_widget.update(); self.seek_preview(self.timeline_widget.playhead_pos_sec) - except Exception as e: self.status_label.setText(f"Error adding file: {e}") + def add_video_clip(self): + file_path, _ = QFileDialog.getOpenFileName(self, "Open Video File", "", "Video Files (*.mp4 *.mov *.avi)") + if not file_path: return + + if not self.timeline.clips and not self.media_pool: + if not self._set_project_properties_from_clip(file_path): + self.status_label.setText("Error: Could not determine video properties from file."); return + + if self._add_media_to_pool(file_path): + self.status_label.setText(f"Added {os.path.basename(file_path)} to project media.") + else: + self.status_label.setText(f"Failed to add {os.path.basename(file_path)}.") def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, clip_start_sec=0.0, video_track_index=None, audio_track_index=None): + old_state = self._get_current_timeline_state() group_id = str(uuid.uuid4()) if video_track_index is not None: @@ -988,14 +1350,16 @@ def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, c audio_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, audio_track_index, 'audio', group_id) self.timeline.add_clip(audio_clip) - self.timeline_widget.update() - self.status_label.setText(f"Added {os.path.basename(source_path)}.") + new_state = self._get_current_timeline_state() + command = TimelineStateChangeCommand("Add Clip", self.timeline, *old_state, *new_state) + # We already performed the action, so we need to undo it before letting the stack redo it. 
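As the comment above notes, `UndoStack.push()` immediately calls `redo()` on the command it receives, so every call site that has already mutated the timeline reverts the change once before pushing; the `command.undo()` / `push()` calls that follow re-apply it exactly once. A minimal sketch of that call-site pattern, assuming the `undo` module added later in this patch is importable and treating `get_state`/`mutate` as illustrative placeholders:

```python
# Sketch only: the "mutate first, then register as undoable" pattern used by
# _add_clip_to_timeline and the other call sites in this patch.
from undo import TimelineStateChangeCommand  # module introduced later in this patch

def push_state_change(undo_stack, timeline, description, get_state, mutate):
    old_state = get_state()      # (clips, num_video_tracks, num_audio_tracks)
    mutate()                     # edit the timeline directly
    new_state = get_state()
    if old_state == new_state:
        return                   # same no-op check as _perform_complex_timeline_change
    command = TimelineStateChangeCommand(description, timeline, *old_state, *new_state)
    command.undo()               # revert, because push() will call redo() once
    undo_stack.push(command)     # redo() re-applies new_state and emits the signals
```

An alternative would be a `push_without_redo()` entry point on the stack; the snapshot-then-revert approach keeps `push()`'s contract uniform at the cost of one extra state swap.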
+ command.undo() + self.undo_stack.push(command) def _split_at_time(self, clip_to_split, time_sec, new_group_id=None): if not (clip_to_split.timeline_start_sec < time_sec < clip_to_split.timeline_end_sec): return False split_point = time_sec - clip_to_split.timeline_start_sec orig_dur = clip_to_split.duration_sec - group_id_for_new_clip = new_group_id if new_group_id is not None else clip_to_split.group_id new_clip = TimelineClip(clip_to_split.source_path, time_sec, clip_to_split.clip_start_sec + split_point, orig_dur - split_point, clip_to_split.track_index, clip_to_split.track_type, group_id_for_new_clip) @@ -1006,88 +1370,199 @@ def _split_at_time(self, clip_to_split, time_sec, new_group_id=None): def split_clip_at_playhead(self, clip_to_split=None): playhead_time = self.timeline_widget.playhead_pos_sec if not clip_to_split: - clip_to_split = next((c for c in self.timeline.clips if c.timeline_start_sec < playhead_time < c.timeline_end_sec), None) - - if not clip_to_split: - self.status_label.setText("Playhead is not over a clip to split.") - return - + clips_at_playhead = [c for c in self.timeline.clips if c.timeline_start_sec < playhead_time < c.timeline_end_sec] + if not clips_at_playhead: + self.status_label.setText("Playhead is not over a clip to split.") + return + clip_to_split = clips_at_playhead[0] + + old_state = self._get_current_timeline_state() + linked_clip = next((c for c in self.timeline.clips if c.group_id == clip_to_split.group_id and c.id != clip_to_split.id), None) - new_right_side_group_id = str(uuid.uuid4()) split1 = self._split_at_time(clip_to_split, playhead_time, new_group_id=new_right_side_group_id) - split2 = True if linked_clip: - split2 = self._split_at_time(linked_clip, playhead_time, new_group_id=new_right_side_group_id) + self._split_at_time(linked_clip, playhead_time, new_group_id=new_right_side_group_id) - if split1 and split2: - self.timeline_widget.update() - self.status_label.setText("Clip split.") + if split1: + new_state = self._get_current_timeline_state() + command = TimelineStateChangeCommand("Split Clip", self.timeline, *old_state, *new_state) + command.undo() + self.undo_stack.push(command) else: self.status_label.setText("Failed to split clip.") - def on_split_region(self, region): - start_sec, end_sec = region; - clips = list(self.timeline.clips) - for clip in clips: self._split_at_time(clip, end_sec) - for clip in clips: self._split_at_time(clip, start_sec) - self.timeline_widget.clear_region(region); self.timeline_widget.update(); self.status_label.setText("Region split.") - def on_split_all_regions(self, regions): - split_points = set() - for start, end in regions: split_points.add(start); split_points.add(end) - for point in sorted(list(split_points)): - group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} - new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} - - for clip in list(self.timeline.clips): - if clip.group_id in new_group_ids: - self._split_at_time(clip, point, new_group_ids[clip.group_id]) + def delete_clip(self, clip_to_delete): + old_state = self._get_current_timeline_state() - self.timeline_widget.clear_all_regions(); self.timeline_widget.update(); self.status_label.setText("All regions split.") - - def on_join_region(self, region): - start_sec, end_sec = region; duration_to_remove = end_sec - start_sec - if duration_to_remove <= 0.01: return + linked_clip = next((c for c in self.timeline.clips if c.group_id == clip_to_delete.group_id and c.id 
!= clip_to_delete.id), None) - for point in [start_sec, end_sec]: - group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} - new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} - for clip in list(self.timeline.clips): - if clip.group_id in new_group_ids: - self._split_at_time(clip, point, new_group_ids[clip.group_id]) - - clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] - for clip in clips_to_remove: self.timeline.clips.remove(clip) - for clip in self.timeline.clips: - if clip.timeline_start_sec >= end_sec: clip.timeline_start_sec -= duration_to_remove - self.timeline.clips.sort(key=lambda c: c.timeline_start_sec); self.timeline_widget.clear_region(region) - self.timeline_widget.update(); self.status_label.setText("Region joined (content removed).") + if clip_to_delete in self.timeline.clips: self.timeline.clips.remove(clip_to_delete) + if linked_clip and linked_clip in self.timeline.clips: self.timeline.clips.remove(linked_clip) + + new_state = self._get_current_timeline_state() + command = TimelineStateChangeCommand("Delete Clip", self.timeline, *old_state, *new_state) + command.undo() + self.undo_stack.push(command) self.prune_empty_tracks() - def on_join_all_regions(self, regions): - for region in sorted(regions, key=lambda r: r[0], reverse=True): - start_sec, end_sec = region; duration_to_remove = end_sec - start_sec - if duration_to_remove <= 0.01: continue + def _perform_complex_timeline_change(self, description, change_function): + """Wrapper for complex operations to make them undoable.""" + old_state = self._get_current_timeline_state() + + # Execute the provided function which will modify the timeline + change_function() + + new_state = self._get_current_timeline_state() + + # If no change occurred, don't add to undo stack + if old_state[0] == new_state[0] and old_state[1] == new_state[1] and old_state[2] == new_state[2]: + return + + command = TimelineStateChangeCommand(description, self.timeline, *old_state, *new_state) + # The change was already made, so we need to "undo" it before letting the stack "redo" it. 
+ command.undo() + self.undo_stack.push(command) - for point in [start_sec, end_sec]: + def on_split_region(self, region): + def action(): + start_sec, end_sec = region + clips = list(self.timeline.clips) # Work on a copy + for clip in clips: self._split_at_time(clip, end_sec) + for clip in clips: self._split_at_time(clip, start_sec) + self.timeline_widget.clear_region(region) + self._perform_complex_timeline_change("Split Region", action) + + def on_split_all_regions(self, regions): + def action(): + split_points = set() + for start, end in regions: + split_points.add(start) + split_points.add(end) + + for point in sorted(list(split_points)): group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} for clip in list(self.timeline.clips): if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) + self.timeline_widget.clear_all_regions() + self._perform_complex_timeline_change("Split All Regions", action) + + def on_join_region(self, region): + def action(): + start_sec, end_sec = region + duration_to_remove = end_sec - start_sec + if duration_to_remove <= 0.01: return + + # Split any clips that cross the region boundaries + for point in [start_sec, end_sec]: + group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} + new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} + for clip in list(self.timeline.clips): + if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) + + # Remove clips fully inside the region + clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] + for clip in clips_to_remove: self.timeline.clips.remove(clip) + + # Ripple delete: move subsequent clips to the left + for clip in self.timeline.clips: + if clip.timeline_start_sec >= end_sec: + clip.timeline_start_sec -= duration_to_remove + self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) + self.timeline_widget.clear_region(region) + self._perform_complex_timeline_change("Join Region", action) + + def on_join_all_regions(self, regions): + def action(): + # Process regions from end to start to avoid index issues + for region in sorted(regions, key=lambda r: r[0], reverse=True): + start_sec, end_sec = region + duration_to_remove = end_sec - start_sec + if duration_to_remove <= 0.01: continue + + # Split clips at boundaries + for point in [start_sec, end_sec]: + group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} + new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} + for clip in list(self.timeline.clips): + if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) + + # Remove clips inside region + clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] + for clip in clips_to_remove: + try: self.timeline.clips.remove(clip) + except ValueError: pass # Already removed + + # Ripple + for clip in self.timeline.clips: + if clip.timeline_start_sec >= end_sec: + clip.timeline_start_sec -= duration_to_remove + + self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) + self.timeline_widget.clear_all_regions() + self._perform_complex_timeline_change("Join All Regions", action) + + 
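The region handlers above, and the delete handlers that follow, all share the same three-step ripple edit: split any clip that crosses a region boundary, drop the clips that now fall inside the region, and shift everything after the region left by its duration. A standalone sketch of that core step, using only the clip attributes visible in this patch; `split_at` stands in for `_split_at_time`, and the group-id regeneration done at the split points is omitted for brevity:

```python
# Sketch of the ripple removal shared by on_join_region / on_delete_region.
def ripple_remove_region(clips, start_sec, end_sec, split_at):
    duration = end_sec - start_sec
    if duration <= 0.01:
        return clips
    # 1. Split clips that straddle either boundary so removal is clean.
    for point in (start_sec, end_sec):
        for clip in list(clips):
            split_at(clip, point)
    # 2. Remove clips whose start now lies inside the region.
    clips[:] = [c for c in clips if not (start_sec <= c.timeline_start_sec < end_sec)]
    # 3. Ripple: pull everything after the region left by the removed duration.
    for clip in clips:
        if clip.timeline_start_sec >= end_sec:
            clip.timeline_start_sec -= duration
    clips.sort(key=lambda c: c.timeline_start_sec)
    return clips
```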
def on_delete_region(self, region): + def action(): + start_sec, end_sec = region + duration_to_remove = end_sec - start_sec + if duration_to_remove <= 0.01: return + + # Split any clips that cross the region boundaries + for point in [start_sec, end_sec]: + group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} + new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} + for clip in list(self.timeline.clips): + if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) + + # Remove clips fully inside the region clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] - for clip in clips_to_remove: - try: self.timeline.clips.remove(clip) - except ValueError: pass + for clip in clips_to_remove: self.timeline.clips.remove(clip) + + # Ripple delete: move subsequent clips to the left for clip in self.timeline.clips: - if clip.timeline_start_sec >= end_sec: clip.timeline_start_sec -= duration_to_remove + if clip.timeline_start_sec >= end_sec: + clip.timeline_start_sec -= duration_to_remove + + self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) + self.timeline_widget.clear_region(region) + self._perform_complex_timeline_change("Delete Region", action) + + def on_delete_all_regions(self, regions): + def action(): + # Process regions from end to start to avoid index issues + for region in sorted(regions, key=lambda r: r[0], reverse=True): + start_sec, end_sec = region + duration_to_remove = end_sec - start_sec + if duration_to_remove <= 0.01: continue + + # Split clips at boundaries + for point in [start_sec, end_sec]: + group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} + new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} + for clip in list(self.timeline.clips): + if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) + + # Remove clips inside region + clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] + for clip in clips_to_remove: + try: self.timeline.clips.remove(clip) + except ValueError: pass # Already removed + + # Ripple + for clip in self.timeline.clips: + if clip.timeline_start_sec >= end_sec: + clip.timeline_start_sec -= duration_to_remove + + self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) + self.timeline_widget.clear_all_regions() + self._perform_complex_timeline_change("Delete All Regions", action) - self.timeline.clips.sort(key=lambda c: c.timeline_start_sec); self.timeline_widget.clear_all_regions() - self.timeline_widget.update(); self.status_label.setText("All regions joined.") - self.prune_empty_tracks() def export_video(self): if not self.timeline.clips: self.status_label.setText("Timeline is empty."); return @@ -1174,45 +1649,29 @@ def on_thread_finished_cleanup(self): def add_dock_widget(self, plugin_instance, widget, title, area=Qt.DockWidgetArea.RightDockWidgetArea, show_on_creation=True): widget_key = f"plugin_{plugin_instance.name}_{title}".replace(' ', '_').lower() - dock = QDockWidget(title, self) dock.setWidget(widget) self.addDockWidget(area, dock) - visibility_settings = self.settings.get("window_visibility", {}) initial_visibility = visibility_settings.get(widget_key, show_on_creation) - dock.setVisible(initial_visibility) - action = QAction(title, self, checkable=True) 
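The `add_dock_widget` hunk continuing below rewires each plugin dock's checkable menu action straight to `QDockWidget.setVisible`, with `visibilityChanged` keeping the action in sync — the same pattern the Windows menu now uses for built-in docks. A minimal, self-contained PyQt6 illustration of that two-way binding (the window, dock, and menu names here are invented for the example):

```python
# Sketch: two-way sync between a checkable QAction and a QDockWidget.
import sys
from PyQt6.QtCore import Qt
from PyQt6.QtGui import QAction
from PyQt6.QtWidgets import QApplication, QDockWidget, QLabel, QMainWindow

app = QApplication(sys.argv)
window = QMainWindow()

dock = QDockWidget("Example Dock", window)
dock.setWidget(QLabel("dock contents"))
window.addDockWidget(Qt.DockWidgetArea.RightDockWidgetArea, dock)

action = QAction("Example Dock", window, checkable=True)
action.toggled.connect(dock.setVisible)            # menu drives the dock
dock.visibilityChanged.connect(action.setChecked)  # closing the dock updates the menu
action.setChecked(dock.isVisible())
window.menuBar().addMenu("&Windows").addAction(action)

window.show()
sys.exit(app.exec())
```

Because both signals fire only on an actual state change, the binding settles immediately rather than recursing.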
- action.toggled.connect(lambda checked, k=widget_key: self.toggle_widget_visibility(k, checked)) + action.toggled.connect(dock.setVisible) dock.visibilityChanged.connect(action.setChecked) - action.setChecked(dock.isVisible()) - self.windows_menu.addAction(action) - - self.managed_widgets[widget_key] = { - 'widget': dock, - 'name': title, - 'action': action, - 'plugin': plugin_instance.name - } - + self.managed_widgets[widget_key] = {'widget': dock, 'name': title, 'action': action, 'plugin': plugin_instance.name} return dock def update_plugin_ui_visibility(self, plugin_name, is_enabled): for key, data in self.managed_widgets.items(): if data.get('plugin') == plugin_name: data['action'].setVisible(is_enabled) - if not is_enabled: - data['widget'].hide() + if not is_enabled: data['widget'].hide() def toggle_plugin(self, name, checked): - if checked: - self.plugin_manager.enable_plugin(name) - else: - self.plugin_manager.disable_plugin(name) + if checked: self.plugin_manager.enable_plugin(name) + else: self.plugin_manager.disable_plugin(name) self._save_settings() def toggle_plugin_action(self, name, checked): @@ -1228,6 +1687,27 @@ def open_manage_plugins_dialog(self): dialog.exec() def closeEvent(self, event): + if self.settings.get("confirm_on_exit", True): + msg_box = QMessageBox(self) + msg_box.setWindowTitle("Confirm Exit") + msg_box.setText("Are you sure you want to exit?") + msg_box.setInformativeText("Any unsaved changes will be lost.") + msg_box.setIcon(QMessageBox.Icon.Question) + msg_box.setStandardButtons(QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No) + msg_box.setDefaultButton(QMessageBox.StandardButton.No) + + dont_ask_cb = QCheckBox("Don't ask again") + msg_box.setCheckBox(dont_ask_cb) + + reply = msg_box.exec() + + if dont_ask_cb.isChecked(): + self.settings['confirm_on_exit'] = False + + if reply == QMessageBox.StandardButton.No: + event.ignore() + return + self.is_shutting_down = True self._save_settings() self._stop_playback_stream() diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py index 873cd5e56..56b94efe6 100644 --- a/videoeditor/plugins/ai_frame_joiner/main.py +++ b/videoeditor/plugins/ai_frame_joiner/main.py @@ -139,6 +139,7 @@ def generate(self): if response.status_code == 200: if payload['start_generation']: self.status_label.setText("Status: Parameters set. Generation sent. Polling...") + self.start_polling() else: self.status_label.setText("Status: Parameters set. Waiting for manual start.") else: @@ -154,6 +155,7 @@ def poll_for_output(self): latest_path = data.get("latest_output_path") if latest_path and latest_path != self.last_known_output: + self.stop_polling() self.last_known_output = latest_path self.output_label.setText(f"Latest Output File:\n{latest_path}") self.status_label.setText("Status: New output received! 
Inserting clip...") @@ -182,7 +184,6 @@ def enable(self): self.dock_widget.hide() self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu) - self.client_widget.start_polling() def disable(self): try: @@ -216,6 +217,8 @@ def setup_generator_for_region(self, region): self.active_region = region start_sec, end_sec = region + self.client_widget.check_server_status() + start_data, w, h = self.app.get_frame_data_at_time(start_sec) end_data, _, _ = self.app.get_frame_data_at_time(end_sec) diff --git a/videoeditor/undo.py b/videoeditor/undo.py new file mode 100644 index 000000000..b43f4fb46 --- /dev/null +++ b/videoeditor/undo.py @@ -0,0 +1,114 @@ +import copy +from PyQt6.QtCore import QObject, pyqtSignal + +class UndoCommand: + def __init__(self, description=""): + self.description = description + + def undo(self): + raise NotImplementedError + + def redo(self): + raise NotImplementedError + +class CompositeCommand(UndoCommand): + def __init__(self, description, commands): + super().__init__(description) + self.commands = commands + + def undo(self): + for cmd in reversed(self.commands): + cmd.undo() + + def redo(self): + for cmd in self.commands: + cmd.redo() + +class UndoStack(QObject): + history_changed = pyqtSignal() + timeline_changed = pyqtSignal() + + def __init__(self): + super().__init__() + self.undo_stack = [] + self.redo_stack = [] + + def push(self, command): + self.undo_stack.append(command) + self.redo_stack.clear() + command.redo() + self.history_changed.emit() + self.timeline_changed.emit() + + def undo(self): + if not self.can_undo(): + return + command = self.undo_stack.pop() + self.redo_stack.append(command) + command.undo() + self.history_changed.emit() + self.timeline_changed.emit() + + def redo(self): + if not self.can_redo(): + return + command = self.redo_stack.pop() + self.undo_stack.append(command) + command.redo() + self.history_changed.emit() + self.timeline_changed.emit() + + def can_undo(self): + return bool(self.undo_stack) + + def can_redo(self): + return bool(self.redo_stack) + + def undo_text(self): + return self.undo_stack[-1].description if self.can_undo() else "" + + def redo_text(self): + return self.redo_stack[-1].description if self.can_redo() else "" + +class TimelineStateChangeCommand(UndoCommand): + def __init__(self, description, timeline_model, old_clips, old_v_tracks, old_a_tracks, new_clips, new_v_tracks, new_a_tracks): + super().__init__(description) + self.timeline = timeline_model + self.old_clips_state = old_clips + self.old_v_tracks = old_v_tracks + self.old_a_tracks = old_a_tracks + self.new_clips_state = new_clips + self.new_v_tracks = new_v_tracks + self.new_a_tracks = new_a_tracks + + def undo(self): + self.timeline.clips = self.old_clips_state + self.timeline.num_video_tracks = self.old_v_tracks + self.timeline.num_audio_tracks = self.old_a_tracks + + def redo(self): + self.timeline.clips = self.new_clips_state + self.timeline.num_video_tracks = self.new_v_tracks + self.timeline.num_audio_tracks = self.new_a_tracks + + +class MoveClipsCommand(UndoCommand): + def __init__(self, description, timeline_model, move_data): + super().__init__(description) + self.timeline = timeline_model + self.move_data = move_data + + def _apply_state(self, state_key_prefix): + for data in self.move_data: + clip_id = data['clip_id'] + clip = next((c for c in self.timeline.clips if c.id == clip_id), None) + if clip: + clip.timeline_start_sec = data[f'{state_key_prefix}_start'] + clip.track_index = data[f'{state_key_prefix}_track'] + 
self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) + + def undo(self): + self._apply_state('old') + + def redo(self): + self._apply_state('new') \ No newline at end of file From a21356f4f5aa6fe2a68adf19b33f1929a6bdd24a Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 10 Oct 2025 08:27:09 +1100 Subject: [PATCH 092/155] add ability to drag tracks onto ghost tracks --- videoeditor/main.py | 177 +++++++++++++++++++++++++------------------- 1 file changed, 99 insertions(+), 78 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index a3f155e09..561244779 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -96,7 +96,6 @@ class TimelineWidget(QWidget): remove_track = pyqtSignal(str) operation_finished = pyqtSignal() context_menu_requested = pyqtSignal(QMenu, 'QContextMenuEvent') - clips_moved = pyqtSignal(list) def __init__(self, timeline_model, settings, parent=None): super().__init__(parent) @@ -116,8 +115,10 @@ def __init__(self, timeline_model, settings, parent=None): self.drag_original_clip_states = {} # Store {'clip_id': (start_sec, track_index)} self.selection_drag_start_sec = 0.0 self.drag_selection_start_values = None + self.drag_start_state = None self.highlighted_track_info = None + self.highlighted_ghost_track_info = None self.add_video_track_btn_rect = QRect() self.remove_video_track_btn_rect = QRect() self.add_audio_track_btn_rect = QRect() @@ -292,6 +293,18 @@ def draw_tracks_and_clips(self, painter): highlight_rect = QRect(self.HEADER_WIDTH, y, self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT) painter.fillRect(highlight_rect, QColor(255, 255, 0, 40)) + if self.highlighted_ghost_track_info: + track_type, track_index = self.highlighted_ghost_track_info + y = -1 + if track_type == 'video': + y = self.TIMESCALE_HEIGHT + elif track_type == 'audio': + y = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT + + if y != -1: + highlight_rect = QRect(self.HEADER_WIDTH, y, self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT) + painter.fillRect(highlight_rect, QColor(255, 255, 0, 40)) + # Draw clips for clip in self.timeline.clips: clip_rect = self.get_clip_rect(clip) @@ -323,17 +336,30 @@ def draw_playhead(self, painter): painter.drawLine(playhead_x, 0, playhead_x, self.height()) def y_to_track_info(self, y): + # Check for ghost video track + if self.TIMESCALE_HEIGHT <= y < self.video_tracks_y_start: + return ('video', self.timeline.num_video_tracks + 1) + + # Check for existing video tracks video_tracks_end_y = self.video_tracks_y_start + self.timeline.num_video_tracks * self.TRACK_HEIGHT if self.video_tracks_y_start <= y < video_tracks_end_y: visual_index = (y - self.video_tracks_y_start) // self.TRACK_HEIGHT track_index = self.timeline.num_video_tracks - visual_index return ('video', track_index) + # Check for existing audio tracks audio_tracks_end_y = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT if self.audio_tracks_y_start <= y < audio_tracks_end_y: visual_index = (y - self.audio_tracks_y_start) // self.TRACK_HEIGHT track_index = visual_index + 1 return ('audio', track_index) + + # Check for ghost audio track + add_audio_btn_y_start = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT + add_audio_btn_y_end = add_audio_btn_y_start + self.TRACK_HEIGHT + if add_audio_btn_y_start <= y < add_audio_btn_y_end: + return ('audio', self.timeline.num_audio_tracks + 1) + return None def get_region_at_pos(self, pos: QPoint): @@ -367,6 +393,7 @@ def 
mousePressEvent(self, event: QMouseEvent): clip_rect = self.get_clip_rect(clip) if clip_rect.contains(QPointF(event.pos())): self.dragging_clip = clip + self.drag_start_state = self.window()._get_current_timeline_state() self.drag_original_clip_states[clip.id] = (clip.timeline_start_sec, clip.track_index) self.dragging_linked_clip = next((c for c in self.timeline.clips if c.group_id == clip.group_id and c.id != clip.id), None) @@ -427,24 +454,29 @@ def mouseMoveEvent(self, event: QMouseEvent): self.update() elif self.dragging_clip: self.highlighted_track_info = None + self.highlighted_ghost_track_info = None new_track_info = self.y_to_track_info(event.pos().y()) original_start_sec, _ = self.drag_original_clip_states[self.dragging_clip.id] if new_track_info: new_track_type, new_track_index = new_track_info + + is_ghost_track = (new_track_type == 'video' and new_track_index > self.timeline.num_video_tracks) or \ + (new_track_type == 'audio' and new_track_index > self.timeline.num_audio_tracks) + + if is_ghost_track: + self.highlighted_ghost_track_info = new_track_info + else: + self.highlighted_track_info = new_track_info + if new_track_type == self.dragging_clip.track_type: self.dragging_clip.track_index = new_track_index - if self.dragging_linked_clip: - # For now, linked clips move to same-number track index - self.dragging_linked_clip.track_index = new_track_index - delta_x = event.pos().x() - self.drag_start_pos.x() time_delta = delta_x / self.PIXELS_PER_SECOND new_start_time = original_start_sec + time_delta - # Basic collision detection (can be improved) for other_clip in self.timeline.clips: if other_clip.id == self.dragging_clip.id: continue if self.dragging_linked_clip and other_clip.id == self.dragging_linked_clip.id: continue @@ -482,35 +514,27 @@ def mouseReleaseEvent(self, event: QMouseEvent): self.dragging_playhead = False if self.dragging_clip: - move_data = [] - # Check main dragged clip orig_start, orig_track = self.drag_original_clip_states[self.dragging_clip.id] - if orig_start != self.dragging_clip.timeline_start_sec or orig_track != self.dragging_clip.track_index: - move_data.append({ - 'clip_id': self.dragging_clip.id, - 'old_start': orig_start, 'new_start': self.dragging_clip.timeline_start_sec, - 'old_track': orig_track, 'new_track': self.dragging_clip.track_index - }) - # Check linked clip + moved = (orig_start != self.dragging_clip.timeline_start_sec or + orig_track != self.dragging_clip.track_index) + if self.dragging_linked_clip: orig_start_link, orig_track_link = self.drag_original_clip_states[self.dragging_linked_clip.id] - if orig_start_link != self.dragging_linked_clip.timeline_start_sec or orig_track_link != self.dragging_linked_clip.track_index: - move_data.append({ - 'clip_id': self.dragging_linked_clip.id, - 'old_start': orig_start_link, 'new_start': self.dragging_linked_clip.timeline_start_sec, - 'old_track': orig_track_link, 'new_track': self.dragging_linked_clip.track_index - }) + moved = moved or (orig_start_link != self.dragging_linked_clip.timeline_start_sec or + orig_track_link != self.dragging_linked_clip.track_index) + + if moved: + self.window().finalize_clip_drag(self.drag_start_state) - if move_data: - self.clips_moved.emit(move_data) - self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) self.highlighted_track_info = None + self.highlighted_ghost_track_info = None self.operation_finished.emit() self.dragging_clip = None self.dragging_linked_clip = None self.drag_original_clip_states.clear() + self.drag_start_state = None self.update() 
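The `y_to_track_info` change above is what makes the "ghost" rows work: the empty strip just under the timescale and the strip below the last audio track report a track index one past the current count, and the drag handlers treat such an index as a request for a track that does not exist yet. A condensed sketch of that mapping; the fixed row offsets are assumptions standing in for the widget's computed `video_tracks_y_start` / `audio_tracks_y_start`:

```python
# Sketch of ghost-track hit testing: an index beyond the current track count
# marks the hovered row as a not-yet-existing track.
TIMESCALE_HEIGHT = 30
TRACK_HEIGHT = 50

def y_to_track_info(y, num_video_tracks, num_audio_tracks):
    video_start = TIMESCALE_HEIGHT + TRACK_HEIGHT           # ghost video row sits above
    audio_start = video_start + num_video_tracks * TRACK_HEIGHT
    audio_end = audio_start + num_audio_tracks * TRACK_HEIGHT
    if TIMESCALE_HEIGHT <= y < video_start:
        return ('video', num_video_tracks + 1)              # ghost video track
    if video_start <= y < audio_start:
        visual = (y - video_start) // TRACK_HEIGHT
        return ('video', num_video_tracks - visual)         # topmost row = highest index
    if audio_start <= y < audio_end:
        return ('audio', (y - audio_start) // TRACK_HEIGHT + 1)
    if audio_end <= y < audio_end + TRACK_HEIGHT:
        return ('audio', num_audio_tracks + 1)               # ghost audio track
    return None

def is_ghost(track_info, num_video_tracks, num_audio_tracks):
    track_type, index = track_info
    limit = num_video_tracks if track_type == 'video' else num_audio_tracks
    return index > limit
```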
@@ -522,6 +546,8 @@ def dragEnterEvent(self, event): def dragLeaveEvent(self, event): self.drag_over_active = False + self.highlighted_ghost_track_info = None + self.highlighted_track_info = None self.update() def dragMoveEvent(self, event): @@ -543,6 +569,15 @@ def dragMoveEvent(self, event): if track_info: self.drag_over_active = True track_type, track_index = track_info + + is_ghost_track = (track_type == 'video' and track_index > self.timeline.num_video_tracks) or \ + (track_type == 'audio' and track_index > self.timeline.num_audio_tracks) + if is_ghost_track: + self.highlighted_ghost_track_info = track_info + self.highlighted_track_info = None + else: + self.highlighted_ghost_track_info = None + self.highlighted_track_info = track_info width = int(duration * self.PIXELS_PER_SECOND) x = self.sec_to_x(start_sec) @@ -564,11 +599,15 @@ def dragMoveEvent(self, event): self.drag_over_rect = QRectF(x, video_y, width, self.TRACK_HEIGHT) else: self.drag_over_active = False + self.highlighted_ghost_track_info = None + self.highlighted_track_info = None self.update() def dropEvent(self, event): self.drag_over_active = False + self.highlighted_ghost_track_info = None + self.highlighted_track_info = None self.update() mime_data = event.mimeData() @@ -775,7 +814,7 @@ def __init__(self, project_to_load=None): self.timeline = Timeline() self.undo_stack = UndoStack() self.media_pool = [] - self.media_properties = {} # To store duration, has_audio etc. + self.media_properties = {} self.export_thread = None self.export_worker = None self.current_project_path = None @@ -805,7 +844,6 @@ def __init__(self, project_to_load=None): if project_to_load: QTimer.singleShot(100, lambda: self._load_project_from_path(project_to_load)) def _get_current_timeline_state(self): - """Helper to capture a snapshot of the timeline for undo/redo.""" return ( copy.deepcopy(self.timeline.clips), self.timeline.num_video_tracks, @@ -886,7 +924,6 @@ def _connect_signals(self): self.timeline_widget.split_requested.connect(self.split_clip_at_playhead) self.timeline_widget.delete_clip_requested.connect(self.delete_clip) self.timeline_widget.playhead_moved.connect(self.seek_preview) - self.timeline_widget.clips_moved.connect(self.on_clips_moved) self.timeline_widget.split_region_requested.connect(self.on_split_region) self.timeline_widget.split_all_regions_requested.connect(self.on_split_all_regions) self.timeline_widget.join_region_requested.connect(self.on_join_region) @@ -911,10 +948,8 @@ def _connect_signals(self): self.undo_stack.timeline_changed.connect(self.on_timeline_changed_by_undo) def on_timeline_changed_by_undo(self): - """Called when undo/redo modifies the timeline.""" self.prune_empty_tracks() self.timeline_widget.update() - # You may want to update other UI elements here as well self.status_label.setText("Operation undone/redone.") def update_undo_redo_actions(self): @@ -924,13 +959,23 @@ def update_undo_redo_actions(self): self.redo_action.setEnabled(self.undo_stack.can_redo()) self.redo_action.setText(f"Redo {self.undo_stack.redo_text()}" if self.undo_stack.can_redo() else "Redo") - def on_clips_moved(self, move_data): - command = MoveClipsCommand("Move Clip", self.timeline, move_data) - # The change is already applied visually, so we push without redoing - self.undo_stack.undo_stack.append(command) - self.undo_stack.redo_stack.clear() - self.undo_stack.history_changed.emit() - self.status_label.setText("Clip moved.") + def finalize_clip_drag(self, old_state_tuple): + current_clips, _, _ = 
self._get_current_timeline_state() + + max_v_idx = max([c.track_index for c in current_clips if c.track_type == 'video'] + [1]) + max_a_idx = max([c.track_index for c in current_clips if c.track_type == 'audio'] + [1]) + + if max_v_idx > self.timeline.num_video_tracks: + self.timeline.num_video_tracks = max_v_idx + + if max_a_idx > self.timeline.num_audio_tracks: + self.timeline.num_audio_tracks = max_a_idx + + new_state_tuple = self._get_current_timeline_state() + + command = TimelineStateChangeCommand("Move Clip", self.timeline, *old_state_tuple, *new_state_tuple) + command.undo() + self.undo_stack.push(command) def on_add_to_timeline_at_playhead(self, file_path): media_info = self.media_properties.get(file_path) @@ -953,7 +998,6 @@ def on_add_to_timeline_at_playhead(self, file_path): def prune_empty_tracks(self): pruned_something = False - # Prune Video Tracks while self.timeline.num_video_tracks > 1: highest_track_index = self.timeline.num_video_tracks is_track_occupied = any(c for c in self.timeline.clips @@ -963,8 +1007,7 @@ def prune_empty_tracks(self): else: self.timeline.num_video_tracks -= 1 pruned_something = True - - # Prune Audio Tracks + while self.timeline.num_audio_tracks > 1: highest_track_index = self.timeline.num_audio_tracks is_track_occupied = any(c for c in self.timeline.clips @@ -988,7 +1031,6 @@ def add_track(self, track_type): command = TimelineStateChangeCommand(f"Add {track_type.capitalize()} Track", self.timeline, *old_state, *new_state) self.undo_stack.push(command) - # Manually undo the direct change, because push() will redo() it. command.undo() self.undo_stack.push(command) @@ -999,11 +1041,11 @@ def remove_track(self, track_type): elif track_type == 'audio' and self.timeline.num_audio_tracks > 1: self.timeline.num_audio_tracks -= 1 else: - return # No change made + return new_state = self._get_current_timeline_state() command = TimelineStateChangeCommand(f"Remove {track_type.capitalize()} Track", self.timeline, *old_state, *new_state) - command.undo() # Invert state because we already made the change + command.undo() self.undo_stack.push(command) @@ -1031,7 +1073,7 @@ def _create_menu_bar(self): split_action = QAction("Split Clip at Playhead", self); split_action.triggered.connect(self.split_clip_at_playhead) edit_menu.addAction(split_action) - self.update_undo_redo_actions() # Set initial state + self.update_undo_redo_actions() plugins_menu = menu_bar.addMenu("&Plugins") for name, data in self.plugin_manager.plugins.items(): @@ -1048,7 +1090,7 @@ def _create_menu_bar(self): self.windows_menu = menu_bar.addMenu("&Windows") for key, data in self.managed_widgets.items(): - if data['widget'] is self.preview_widget: continue # Preview is part of splitter, not a dock + if data['widget'] is self.preview_widget: continue action = QAction(data['name'], self, checkable=True) if hasattr(data['widget'], 'visibilityChanged'): action.toggled.connect(data['widget'].setVisible) @@ -1095,10 +1137,6 @@ def _set_project_properties_from_clip(self, source_path): return False def get_frame_data_at_time(self, time_sec): - """ - Plugin API: Extracts raw frame data at a specific time. - Returns a tuple of (bytes, width, height) or (None, 0, 0) on failure. 
- """ clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None) if not clip_at_time: return (None, 0, 0) @@ -1201,7 +1239,7 @@ def _apply_loaded_settings(self): for key, data in self.managed_widgets.items(): if data.get('plugin'): continue is_visible = visibility_settings.get(key, True) - if data['widget'] is not self.preview_widget: # preview handled by splitter state + if data['widget'] is not self.preview_widget: data['widget'].setVisible(is_visible) if data['action']: data['action'].setChecked(is_visible) splitter_state = self.settings.get("splitter_state") @@ -1250,7 +1288,7 @@ def open_project(self): def _load_project_from_path(self, path): try: with open(path, "r") as f: project_data = json.load(f) - self.new_project() # Clear existing state + self.new_project() self.media_pool = project_data.get("media_pool", []) for p in self.media_pool: self._add_media_to_pool(p) @@ -1343,16 +1381,19 @@ def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, c group_id = str(uuid.uuid4()) if video_track_index is not None: + if video_track_index > self.timeline.num_video_tracks: + self.timeline.num_video_tracks = video_track_index video_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, video_track_index, 'video', group_id) self.timeline.add_clip(video_clip) if audio_track_index is not None: + if audio_track_index > self.timeline.num_audio_tracks: + self.timeline.num_audio_tracks = audio_track_index audio_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, audio_track_index, 'audio', group_id) self.timeline.add_clip(audio_clip) new_state = self._get_current_timeline_state() command = TimelineStateChangeCommand("Add Clip", self.timeline, *old_state, *new_state) - # We already performed the action, so we need to undo it before letting the stack redo it. command.undo() self.undo_stack.push(command) @@ -1409,27 +1450,21 @@ def delete_clip(self, clip_to_delete): self.prune_empty_tracks() def _perform_complex_timeline_change(self, description, change_function): - """Wrapper for complex operations to make them undoable.""" old_state = self._get_current_timeline_state() - - # Execute the provided function which will modify the timeline change_function() new_state = self._get_current_timeline_state() - - # If no change occurred, don't add to undo stack if old_state[0] == new_state[0] and old_state[1] == new_state[1] and old_state[2] == new_state[2]: return command = TimelineStateChangeCommand(description, self.timeline, *old_state, *new_state) - # The change was already made, so we need to "undo" it before letting the stack "redo" it. 
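Track counts in the hunks above are grown on demand and trimmed afterwards: `_add_clip_to_timeline` and `finalize_clip_drag` raise `num_video_tracks` / `num_audio_tracks` to cover the highest track index actually used, while `prune_empty_tracks` later removes trailing empty tracks down to a minimum of one. A small sketch of that pair of invariants; the `timeline` argument only needs the attributes used below:

```python
# Sketch: grow the track count to cover every placed clip, then prune
# trailing empty tracks back down (never below one track per type).
def grow_tracks_to_fit(timeline):
    for track_type, attr in (('video', 'num_video_tracks'), ('audio', 'num_audio_tracks')):
        used = [c.track_index for c in timeline.clips if c.track_type == track_type]
        needed = max(used, default=1)
        if needed > getattr(timeline, attr):
            setattr(timeline, attr, needed)

def prune_empty_tracks(timeline):
    for track_type, attr in (('video', 'num_video_tracks'), ('audio', 'num_audio_tracks')):
        while getattr(timeline, attr) > 1:
            top = getattr(timeline, attr)
            occupied = any(c.track_type == track_type and c.track_index == top
                           for c in timeline.clips)
            if occupied:
                break
            setattr(timeline, attr, top - 1)
```

Keeping the two sides symmetric is what lets a drop onto a ghost row become a real track and then disappear again once its last clip is removed, without special-casing either path.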
command.undo() self.undo_stack.push(command) def on_split_region(self, region): def action(): start_sec, end_sec = region - clips = list(self.timeline.clips) # Work on a copy + clips = list(self.timeline.clips) for clip in clips: self._split_at_time(clip, end_sec) for clip in clips: self._split_at_time(clip, start_sec) self.timeline_widget.clear_region(region) @@ -1457,18 +1492,15 @@ def action(): duration_to_remove = end_sec - start_sec if duration_to_remove <= 0.01: return - # Split any clips that cross the region boundaries for point in [start_sec, end_sec]: group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} for clip in list(self.timeline.clips): if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) - # Remove clips fully inside the region clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] for clip in clips_to_remove: self.timeline.clips.remove(clip) - # Ripple delete: move subsequent clips to the left for clip in self.timeline.clips: if clip.timeline_start_sec >= end_sec: clip.timeline_start_sec -= duration_to_remove @@ -1479,26 +1511,21 @@ def action(): def on_join_all_regions(self, regions): def action(): - # Process regions from end to start to avoid index issues for region in sorted(regions, key=lambda r: r[0], reverse=True): start_sec, end_sec = region duration_to_remove = end_sec - start_sec if duration_to_remove <= 0.01: continue - - # Split clips at boundaries + for point in [start_sec, end_sec]: group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} for clip in list(self.timeline.clips): if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) - - # Remove clips inside region clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] for clip in clips_to_remove: try: self.timeline.clips.remove(clip) - except ValueError: pass # Already removed - - # Ripple + except ValueError: pass + for clip in self.timeline.clips: if clip.timeline_start_sec >= end_sec: clip.timeline_start_sec -= duration_to_remove @@ -1513,18 +1540,15 @@ def action(): duration_to_remove = end_sec - start_sec if duration_to_remove <= 0.01: return - # Split any clips that cross the region boundaries for point in [start_sec, end_sec]: group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} for clip in list(self.timeline.clips): if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) - # Remove clips fully inside the region clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] for clip in clips_to_remove: self.timeline.clips.remove(clip) - # Ripple delete: move subsequent clips to the left for clip in self.timeline.clips: if clip.timeline_start_sec >= end_sec: clip.timeline_start_sec -= duration_to_remove @@ -1535,24 +1559,21 @@ def action(): def on_delete_all_regions(self, regions): def action(): - # Process regions from end to start to avoid index issues for region in sorted(regions, key=lambda 
r: r[0], reverse=True): start_sec, end_sec = region duration_to_remove = end_sec - start_sec if duration_to_remove <= 0.01: continue - - # Split clips at boundaries + for point in [start_sec, end_sec]: group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} for clip in list(self.timeline.clips): if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) - - # Remove clips inside region + clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] for clip in clips_to_remove: try: self.timeline.clips.remove(clip) - except ValueError: pass # Already removed + except ValueError: pass # Ripple for clip in self.timeline.clips: From 354c0d1a40126733bf074a65230148d351504355 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 10 Oct 2025 08:33:02 +1100 Subject: [PATCH 093/155] fix bug dragging ghost tracks from project media --- videoeditor/main.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 561244779..34d9450c2 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -290,7 +290,7 @@ def draw_tracks_and_clips(self, painter): y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT if y != -1: - highlight_rect = QRect(self.HEADER_WIDTH, y, self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT) + highlight_rect = QRect(self.HEADER_WIDTH, int(y), self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT) painter.fillRect(highlight_rect, QColor(255, 255, 0, 40)) if self.highlighted_ghost_track_info: @@ -302,10 +302,9 @@ def draw_tracks_and_clips(self, painter): y = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT if y != -1: - highlight_rect = QRect(self.HEADER_WIDTH, y, self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT) + highlight_rect = QRect(self.HEADER_WIDTH, int(y), self.width() - self.HEADER_WIDTH, self.TRACK_HEIGHT) painter.fillRect(highlight_rect, QColor(255, 255, 0, 40)) - # Draw clips for clip in self.timeline.clips: clip_rect = self.get_clip_rect(clip) base_color = QColor("#46A") if clip.track_type == 'video' else QColor("#48C") From 97f536d3cdb05917d5d8e50deae905718df09f34 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 10 Oct 2025 01:27:39 +0200 Subject: [PATCH 094/155] more fp8 formats support --- models/wan/any2video.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/models/wan/any2video.py b/models/wan/any2video.py index 7e0592b79..d2f678060 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -59,14 +59,20 @@ def timestep_transform(t, shift=5.0, num_timesteps=1000 ): new_t = new_t * num_timesteps return new_t -def preprocess_sd(sd): +def preprocess_sd_with_dtype(dtype, sd): new_sd = {} prefix_list = ["model.diffusion_model"] + end_list = [".norm3.bias", ".norm3.weight", ".norm_q.bias", ".norm_q.weight", ".norm_k.bias", ".norm_k.weight" ] for k,v in sd.items(): for prefix in prefix_list: if k.startswith(prefix): k = k[len(prefix)+1:] - continue + break + if v.dtype in (torch.float8_e5m2, torch.float8_e4m3fn): + for endfix in end_list: + if k.endswith(endfix): + v = v.to(dtype) + break if not k.startswith("vae."): new_sd[k] = v return new_sd @@ -143,6 +149,8 @@ def __init__( source2 = model_def.get("source2", None) module_source = model_def.get("module_source", None) module_source2 = 
model_def.get("module_source2", None) + def preprocess_sd(sd): + return preprocess_sd_with_dtype(dtype, sd) kwargs= { "modelClass": WanModel,"do_quantize": quantizeTransformer and not save_quantized, "defaultConfigPath": base_config_file , "ignore_unused_weights": ignore_unused_weights, "writable_tensors": False, "default_dtype": dtype, "preprocess_sd": preprocess_sd, "forcedConfigPath": forcedConfigPath, } kwargs_light= { "modelClass": WanModel,"writable_tensors": False, "preprocess_sd": preprocess_sd , "forcedConfigPath" : base_config_file} if module_source is not None: From cad961803f499533c9cb6fa95424dbb425699bb7 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 10 Oct 2025 12:00:08 +1100 Subject: [PATCH 095/155] add dynamic scroll and dynamic timescale --- videoeditor/main.py | 166 +++++++++++++++++++++++++++++++++++++------- 1 file changed, 139 insertions(+), 27 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 34d9450c2..6aee83fd6 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -77,7 +77,6 @@ def run_export(self): self.finished.emit(f"An exception occurred during export: {e}") class TimelineWidget(QWidget): - PIXELS_PER_SECOND = 50 TIMESCALE_HEIGHT = 30 HEADER_WIDTH = 120 TRACK_HEIGHT = 50 @@ -97,11 +96,18 @@ class TimelineWidget(QWidget): operation_finished = pyqtSignal() context_menu_requested = pyqtSignal(QMenu, 'QContextMenuEvent') - def __init__(self, timeline_model, settings, parent=None): + def __init__(self, timeline_model, settings, project_fps, parent=None): super().__init__(parent) self.timeline = timeline_model self.settings = settings self.playhead_pos_sec = 0.0 + self.scroll_area = None + + self.pixels_per_second = 50.0 + self.max_pixels_per_second = 1.0 # Will be updated by set_project_fps + self.project_fps = 25.0 + self.set_project_fps(project_fps) + self.setMinimumHeight(300) self.setMouseTracking(True) self.setAcceptDrops(True) @@ -132,8 +138,15 @@ def __init__(self, timeline_model, settings, parent=None): self.drag_over_audio_rect = QRectF() self.drag_has_audio = False - def sec_to_x(self, sec): return self.HEADER_WIDTH + int(sec * self.PIXELS_PER_SECOND) - def x_to_sec(self, x): return float(x - self.HEADER_WIDTH) / self.PIXELS_PER_SECOND if x > self.HEADER_WIDTH else 0.0 + def set_project_fps(self, fps): + self.project_fps = fps if fps > 0 else 25.0 + # Set max zoom to be 10 pixels per frame + self.max_pixels_per_second = self.project_fps * 20 + self.pixels_per_second = min(self.pixels_per_second, self.max_pixels_per_second) + self.update() + + def sec_to_x(self, sec): return self.HEADER_WIDTH + int(sec * self.pixels_per_second) + def x_to_sec(self, x): return float(x - self.HEADER_WIDTH) / self.pixels_per_second if x > self.HEADER_WIDTH and self.pixels_per_second > 0 else 0.0 def paintEvent(self, event): painter = QPainter(self) @@ -156,7 +169,7 @@ def paintEvent(self, event): self.draw_playhead(painter) - total_width = self.sec_to_x(self.timeline.get_total_duration()) + 100 + total_width = self.sec_to_x(self.timeline.get_total_duration()) + 200 total_height = self.calculate_total_height() self.setMinimumSize(max(self.parent().width(), total_width), total_height) @@ -225,30 +238,82 @@ def draw_headers(self, painter): painter.drawText(self.add_audio_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "Add Track (+)") painter.restore() + + def _format_timecode(self, seconds): + if abs(seconds) < 1e-9: seconds = 0 + sign = "-" if seconds < 0 else "" + seconds = abs(seconds) + + h = int(seconds / 3600) + m = int((seconds % 
3600) / 60) + s = seconds % 60 + + if h > 0: return f"{sign}{h}h:{m:02d}m" + if m > 0 or seconds >= 59.99: return f"{sign}{m}m:{int(round(s)):02d}s" + + precision = 2 if s < 1 else 1 if s < 10 else 0 + val = f"{s:.{precision}f}" + # Remove trailing .0 or .00 + if '.' in val: val = val.rstrip('0').rstrip('.') + return f"{sign}{val}s" def draw_timescale(self, painter): painter.save() painter.setPen(QColor("#AAA")) painter.setFont(QFont("Arial", 8)) font_metrics = QFontMetrics(painter.font()) - + painter.fillRect(QRect(self.HEADER_WIDTH, 0, self.width() - self.HEADER_WIDTH, self.TIMESCALE_HEIGHT), QColor("#222")) painter.drawLine(self.HEADER_WIDTH, self.TIMESCALE_HEIGHT - 1, self.width(), self.TIMESCALE_HEIGHT - 1) - max_time_to_draw = self.x_to_sec(self.width()) - major_interval_sec = 5 - minor_interval_sec = 1 + frame_dur = 1.0 / self.project_fps + intervals = [ + frame_dur, 2*frame_dur, 5*frame_dur, 10*frame_dur, + 0.1, 0.2, 0.5, 1, 2, 5, 10, 15, 30, + 60, 120, 300, 600, 900, 1800, + 3600, 2*3600, 5*3600, 10*3600 + ] + + min_pixel_dist = 70 + major_interval = next((i for i in intervals if i * self.pixels_per_second > min_pixel_dist), intervals[-1]) - for t_sec in range(int(max_time_to_draw) + 2): + minor_interval = 0 + for divisor in [5, 4, 2]: + if (major_interval / divisor) * self.pixels_per_second > 10: + minor_interval = major_interval / divisor + break + + start_sec = self.x_to_sec(self.HEADER_WIDTH) + end_sec = self.x_to_sec(self.width()) + + def draw_ticks(interval, height): + if interval <= 1e-6: return + start_tick_num = int(start_sec / interval) + end_tick_num = int(end_sec / interval) + 1 + for i in range(start_tick_num, end_tick_num + 1): + t_sec = i * interval + x = self.sec_to_x(t_sec) + if x > self.width(): break + if x >= self.HEADER_WIDTH: + painter.drawLine(x, self.TIMESCALE_HEIGHT - height, x, self.TIMESCALE_HEIGHT) + + if frame_dur * self.pixels_per_second > 4: + draw_ticks(frame_dur, 3) + if minor_interval > 0: + draw_ticks(minor_interval, 6) + + start_major_tick = int(start_sec / major_interval) + end_major_tick = int(end_sec / major_interval) + 1 + for i in range(start_major_tick, end_major_tick + 1): + t_sec = i * major_interval x = self.sec_to_x(t_sec) - - if t_sec % major_interval_sec == 0: - painter.drawLine(x, self.TIMESCALE_HEIGHT - 10, x, self.TIMESCALE_HEIGHT - 1) - label = f"{t_sec}s" + if x > self.width() + 50: break + if x >= self.HEADER_WIDTH - 50: + painter.drawLine(x, self.TIMESCALE_HEIGHT - 12, x, self.TIMESCALE_HEIGHT) + label = self._format_timecode(t_sec) label_width = font_metrics.horizontalAdvance(label) - painter.drawText(x - label_width // 2, self.TIMESCALE_HEIGHT - 12, label) - elif t_sec % minor_interval_sec == 0: - painter.drawLine(x, self.TIMESCALE_HEIGHT - 5, x, self.TIMESCALE_HEIGHT - 1) + painter.drawText(x - label_width // 2, self.TIMESCALE_HEIGHT - 14, label) + painter.restore() def get_clip_rect(self, clip): @@ -260,7 +325,7 @@ def get_clip_rect(self, clip): y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT x = self.sec_to_x(clip.timeline_start_sec) - w = int(clip.duration_sec * self.PIXELS_PER_SECOND) + w = int(clip.duration_sec * self.pixels_per_second) clip_height = self.TRACK_HEIGHT - 10 y += (self.TRACK_HEIGHT - clip_height) / 2 return QRectF(x, y, w, clip_height) @@ -323,7 +388,7 @@ def draw_tracks_and_clips(self, painter): def draw_selections(self, painter): for start_sec, end_sec in self.selection_regions: x = self.sec_to_x(start_sec) - w = int((end_sec - start_sec) * self.PIXELS_PER_SECOND) + w = int((end_sec - 
start_sec) * self.pixels_per_second) selection_rect = QRectF(x, self.TIMESCALE_HEIGHT, w, self.height() - self.TIMESCALE_HEIGHT) painter.fillRect(selection_rect, QColor(100, 100, 255, 80)) painter.setPen(QColor(150, 150, 255, 150)) @@ -371,6 +436,46 @@ def get_region_at_pos(self, pos: QPoint): return region return None + def wheelEvent(self, event: QMouseEvent): + if not self.scroll_area or event.position().x() < self.HEADER_WIDTH: + event.ignore() + return + + scrollbar = self.scroll_area.horizontalScrollBar() + mouse_x_abs = event.position().x() + + time_at_cursor = self.x_to_sec(mouse_x_abs) + + delta = event.angleDelta().y() + zoom_factor = 1.15 + old_pps = self.pixels_per_second + + if delta > 0: new_pps = old_pps * zoom_factor + else: new_pps = old_pps / zoom_factor + + min_pps = 1 / (3600 * 10) # Min zoom of 1px per 10 hours + new_pps = max(min_pps, min(new_pps, self.max_pixels_per_second)) + + if abs(new_pps - old_pps) < 1e-9: + return + + self.pixels_per_second = new_pps + # This will trigger a paintEvent, which resizes the widget. + # The scroll area will update its scrollbar ranges in response. + self.update() + + # The new absolute x-coordinate for the time under the cursor + new_mouse_x_abs = self.sec_to_x(time_at_cursor) + + # The amount the content "shifted" at the cursor's location + shift_amount = new_mouse_x_abs - mouse_x_abs + + # Adjust the scrollbar by this shift amount to keep the content under the cursor + new_scroll_value = scrollbar.value() + shift_amount + scrollbar.setValue(int(new_scroll_value)) + + event.accept() + def mousePressEvent(self, event: QMouseEvent): if event.button() == Qt.MouseButton.LeftButton: if event.pos().x() < self.HEADER_WIDTH: @@ -435,7 +540,7 @@ def mouseMoveEvent(self, event: QMouseEvent): if self.dragging_selection_region: delta_x = event.pos().x() - self.drag_start_pos.x() - time_delta = delta_x / self.PIXELS_PER_SECOND + time_delta = delta_x / self.pixels_per_second original_start, original_end = self.drag_selection_start_values duration = original_end - original_start @@ -473,7 +578,7 @@ def mouseMoveEvent(self, event: QMouseEvent): self.dragging_clip.track_index = new_track_index delta_x = event.pos().x() - self.drag_start_pos.x() - time_delta = delta_x / self.PIXELS_PER_SECOND + time_delta = delta_x / self.pixels_per_second new_start_time = original_start_sec + time_delta for other_clip in self.timeline.clips: @@ -504,7 +609,7 @@ def mouseReleaseEvent(self, event: QMouseEvent): self.creating_selection_region = False if self.selection_regions: start, end = self.selection_regions[-1] - if (end - start) * self.PIXELS_PER_SECOND < 2: + if (end - start) * self.pixels_per_second < 2: self.selection_regions.pop() if self.dragging_selection_region: @@ -578,7 +683,7 @@ def dragMoveEvent(self, event): self.highlighted_ghost_track_info = None self.highlighted_track_info = track_info - width = int(duration * self.PIXELS_PER_SECOND) + width = int(duration * self.pixels_per_second) x = self.sec_to_x(start_sec) if track_type == 'video': @@ -864,9 +969,11 @@ def _setup_ui(self): self.preview_widget.setStyleSheet("background-color: black; color: white;") self.splitter.addWidget(self.preview_widget) - self.timeline_widget = TimelineWidget(self.timeline, self.settings, self) + self.timeline_widget = TimelineWidget(self.timeline, self.settings, self.project_fps, self) self.timeline_scroll_area = QScrollArea() - self.timeline_scroll_area.setWidgetResizable(True) + self.timeline_widget.scroll_area = self.timeline_scroll_area + 
self.timeline_scroll_area.setWidgetResizable(False) + self.timeline_widget.setMinimumWidth(2000) self.timeline_scroll_area.setWidget(self.timeline_widget) self.timeline_scroll_area.setFrameShape(QFrame.Shape.NoFrame) self.timeline_scroll_area.setMinimumHeight(250) @@ -1129,7 +1236,9 @@ def _set_project_properties_from_clip(self, source_path): self.project_width = int(video_stream['width']); self.project_height = int(video_stream['height']) if 'r_frame_rate' in video_stream and video_stream['r_frame_rate'] != '0/0': num, den = map(int, video_stream['r_frame_rate'].split('/')) - if den > 0: self.project_fps = num / den + if den > 0: + self.project_fps = num / den + self.timeline_widget.set_project_fps(self.project_fps) print(f"Project properties set: {self.project_width}x{self.project_height} @ {self.project_fps:.2f} FPS") return True except Exception as e: print(f"Could not probe for project properties: {e}") @@ -1260,7 +1369,10 @@ def open_settings_dialog(self): def new_project(self): self.timeline.clips.clear(); self.timeline.num_video_tracks = 1; self.timeline.num_audio_tracks = 1 self.media_pool.clear(); self.media_properties.clear(); self.project_media_widget.clear_list() - self.current_project_path = None; self.stop_playback(); self.timeline_widget.update() + self.current_project_path = None; self.stop_playback() + self.project_fps = 25.0 + self.timeline_widget.set_project_fps(self.project_fps) + self.timeline_widget.update() self.undo_stack = UndoStack() self.undo_stack.history_changed.connect(self.update_undo_redo_actions) self.update_undo_redo_actions() From 8b3552f9d85f9ecb2eaf243310762126bcf3277f Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 10 Oct 2025 12:11:29 +0200 Subject: [PATCH 096/155] fixed bugs and added Depth control for Qwen --- models/qwen/qwen_handler.py | 2 +- models/wan/wan_handler.py | 2 +- wgp.py | 6 ++++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/models/qwen/qwen_handler.py b/models/qwen/qwen_handler.py index 2a3c8487a..0505a5768 100644 --- a/models/qwen/qwen_handler.py +++ b/models/qwen/qwen_handler.py @@ -46,7 +46,7 @@ def query_model_def(base_model_type, model_def): if base_model_type in ["qwen_image_edit_plus_20B"]: extra_model_def["guide_preprocessing"] = { - "selection": ["", "PV", "SV", "CV"], + "selection": ["", "PV", "DV", "SV", "CV"], } extra_model_def["mask_preprocessing"] = { diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index bb6745189..f70e376b4 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -396,7 +396,7 @@ def query_model_def(base_model_type, model_def): "visible" : False, } - if vace_class or base_model_type in ["animate"]: + if vace_class or base_model_type in ["animate", "t2v", "t2v_2_2"] : image_prompt_types_allowed = "TVL" elif base_model_type in ["infinitetalk"]: image_prompt_types_allowed = "TSVL" diff --git a/wgp.py b/wgp.py index 3bc3e54d2..43e61bad5 100644 --- a/wgp.py +++ b/wgp.py @@ -2646,8 +2646,9 @@ def computeList(filename): preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True) for url in preload_URLs: - filename = fl.get_download_location(url.split("/")[-1]) - if not os.path.isfile(filename ): + filename = fl.locate_file(os.path.basename(url)) + if filename is None: + filename = fl.get_download_location(os.path.basename(url)) if not url.startswith("http"): raise Exception(f"File '{filename}' to preload was not found locally and no URL was provided to download it. 
Please add an URL in the model definition file.") try: @@ -5206,6 +5207,7 @@ def remove_temp_filenames(temp_filenames_list): if any_mask: save_video( src_mask2, "masks2.mp4", fps, value_range=(0, 1)) if video_guide is not None: preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame) + preview_frame_no = min(src_video.shape[1] -1, preview_frame_no) refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no) if src_video2 is not None: refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)] From e3e0c0325513cc6dba3fb6667486cbf25fba2011 Mon Sep 17 00:00:00 2001 From: DeepBeepMeep Date: Fri, 10 Oct 2025 13:02:43 +0200 Subject: [PATCH 097/155] fixed ref images ignored with multiple samples --- models/wan/modules/model.py | 1 - wgp.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index 8c33a17a7..4f6d63e36 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -1339,7 +1339,6 @@ def forward( ): # patch_dtype = self.patch_embedding.weight.dtype modulation_dtype = self.time_projection[1].weight.dtype - if self.model_type == 'i2v': assert clip_fea is not None and y is not None # params diff --git a/wgp.py b/wgp.py index 43e61bad5..74af14e4a 100644 --- a/wgp.py +++ b/wgp.py @@ -4846,6 +4846,7 @@ def remove_temp_filenames(temp_filenames_list): trans.cache = skip_steps_cache if trans2 is not None: trans2.cache = skip_steps_cache face_arc_embeds = None + src_ref_images = src_ref_masks = None output_new_audio_data = None output_new_audio_filepath = None original_audio_guide = audio_guide @@ -4940,7 +4941,7 @@ def remove_temp_filenames(temp_filenames_list): first_window_video_length = current_video_length original_prompts = prompts.copy() - gen["sliding_window"] = sliding_window + gen["sliding_window"] = sliding_window while not abort: extra_generation += gen.get("extra_orders",0) gen["extra_orders"] = 0 @@ -4950,7 +4951,7 @@ def remove_temp_filenames(temp_filenames_list): if repeat_no >= total_generation: break repeat_no +=1 gen["repeat_no"] = repeat_no - src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = sparse_video_image = None + src_video = src_video2 = src_mask = src_mask2 = src_faces = sparse_video_image = None prefix_video = pre_video_frame = None source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before ) From eacb49814b7b13845598377c41e785489670fbd9 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 10 Oct 2025 23:21:10 +1100 Subject: [PATCH 098/155] fix ai joiner to start polling for outputs during manual generate, stopped selector sliding around --- videoeditor/main.py | 51 +++++++++++---------- videoeditor/plugins/ai_frame_joiner/main.py | 12 ++--- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 6aee83fd6..0c08e7b9a 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -153,7 +153,11 @@ def paintEvent(self, event): painter.setRenderHint(QPainter.RenderHint.Antialiasing) painter.fillRect(self.rect(), QColor("#333")) - self.draw_headers(painter) + h_offset = 0 + if self.scroll_area 
and self.scroll_area.horizontalScrollBar(): + h_offset = self.scroll_area.horizontalScrollBar().value() + + self.draw_headers(painter, h_offset) self.draw_timescale(painter) self.draw_tracks_and_clips(painter) self.draw_selections(painter) @@ -178,7 +182,7 @@ def calculate_total_height(self): audio_tracks_height = (self.timeline.num_audio_tracks + 1) * self.TRACK_HEIGHT return self.TIMESCALE_HEIGHT + video_tracks_height + self.AUDIO_TRACKS_SEPARATOR_Y + audio_tracks_height + 20 - def draw_headers(self, painter): + def draw_headers(self, painter, h_offset): painter.save() painter.setPen(QColor("#AAA")) header_font = QFont("Arial", 9, QFont.Weight.Bold) @@ -187,9 +191,9 @@ def draw_headers(self, painter): y_cursor = self.TIMESCALE_HEIGHT rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT) - painter.fillRect(rect, QColor("#3a3a3a")) - painter.drawRect(rect) - self.add_video_track_btn_rect = QRect(rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22) + painter.fillRect(rect.translated(h_offset, 0), QColor("#3a3a3a")) + painter.drawRect(rect.translated(h_offset, 0)) + self.add_video_track_btn_rect = QRect(h_offset + rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22) painter.setFont(button_font) painter.fillRect(self.add_video_track_btn_rect, QColor("#454")) painter.drawText(self.add_video_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "Add Track (+)") @@ -199,13 +203,13 @@ def draw_headers(self, painter): for i in range(self.timeline.num_video_tracks): track_number = self.timeline.num_video_tracks - i rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT) - painter.fillRect(rect, QColor("#444")) - painter.drawRect(rect) + painter.fillRect(rect.translated(h_offset, 0), QColor("#444")) + painter.drawRect(rect.translated(h_offset, 0)) painter.setFont(header_font) - painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, f"Video {track_number}") + painter.drawText(rect.translated(h_offset, 0), Qt.AlignmentFlag.AlignCenter, f"Video {track_number}") if track_number == self.timeline.num_video_tracks and self.timeline.num_video_tracks > 1: - self.remove_video_track_btn_rect = QRect(rect.right() - 25, rect.top() + 5, 20, 20) + self.remove_video_track_btn_rect = QRect(h_offset + rect.right() - 25, rect.top() + 5, 20, 20) painter.setFont(button_font) painter.fillRect(self.remove_video_track_btn_rect, QColor("#833")) painter.drawText(self.remove_video_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "-") @@ -217,22 +221,22 @@ def draw_headers(self, painter): for i in range(self.timeline.num_audio_tracks): track_number = i + 1 rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT) - painter.fillRect(rect, QColor("#444")) - painter.drawRect(rect) + painter.fillRect(rect.translated(h_offset, 0), QColor("#444")) + painter.drawRect(rect.translated(h_offset, 0)) painter.setFont(header_font) - painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, f"Audio {track_number}") + painter.drawText(rect.translated(h_offset, 0), Qt.AlignmentFlag.AlignCenter, f"Audio {track_number}") if track_number == self.timeline.num_audio_tracks and self.timeline.num_audio_tracks > 1: - self.remove_audio_track_btn_rect = QRect(rect.right() - 25, rect.top() + 5, 20, 20) + self.remove_audio_track_btn_rect = QRect(h_offset + rect.right() - 25, rect.top() + 5, 20, 20) painter.setFont(button_font) painter.fillRect(self.remove_audio_track_btn_rect, QColor("#833")) painter.drawText(self.remove_audio_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "-") y_cursor 
+= self.TRACK_HEIGHT rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT) - painter.fillRect(rect, QColor("#3a3a3a")) - painter.drawRect(rect) - self.add_audio_track_btn_rect = QRect(rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22) + painter.fillRect(rect.translated(h_offset, 0), QColor("#3a3a3a")) + painter.drawRect(rect.translated(h_offset, 0)) + self.add_audio_track_btn_rect = QRect(h_offset + rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22) painter.setFont(button_font) painter.fillRect(self.add_audio_track_btn_rect, QColor("#454")) painter.drawText(self.add_audio_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "Add Track (+)") @@ -477,15 +481,14 @@ def wheelEvent(self, event: QMouseEvent): event.accept() def mousePressEvent(self, event: QMouseEvent): - if event.button() == Qt.MouseButton.LeftButton: - if event.pos().x() < self.HEADER_WIDTH: - # Click in headers area - if self.add_video_track_btn_rect.contains(event.pos()): self.add_track.emit('video') - elif self.remove_video_track_btn_rect.contains(event.pos()): self.remove_track.emit('video') - elif self.add_audio_track_btn_rect.contains(event.pos()): self.add_track.emit('audio') - elif self.remove_audio_track_btn_rect.contains(event.pos()): self.remove_track.emit('audio') - return + if event.pos().x() < self.HEADER_WIDTH + self.scroll_area.horizontalScrollBar().value(): + if self.add_video_track_btn_rect.contains(event.pos()): self.add_track.emit('video') + elif self.remove_video_track_btn_rect.contains(event.pos()): self.remove_track.emit('video') + elif self.add_audio_track_btn_rect.contains(event.pos()): self.add_track.emit('audio') + elif self.remove_audio_track_btn_rect.contains(event.pos()): self.remove_track.emit('audio') + return + if event.button() == Qt.MouseButton.LeftButton: self.dragging_clip = None self.dragging_linked_clip = None self.dragging_playhead = False diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py index 56b94efe6..a7c7cd648 100644 --- a/videoeditor/plugins/ai_frame_joiner/main.py +++ b/videoeditor/plugins/ai_frame_joiner/main.py @@ -94,7 +94,6 @@ def set_previews(self, start_pixmap, end_pixmap): def start_polling(self): self.poll_timer.start() - self.check_server_status() def stop_polling(self): self.poll_timer.stop() @@ -111,9 +110,11 @@ def check_server_status(self): try: requests.get(f"{API_BASE_URL}/api/latest_output", timeout=1) self.status_label.setText("Status: Connected to WanGP server.") + return True except requests.exceptions.ConnectionError: self.status_label.setText("Status: Error - Cannot connect to WanGP server.") QMessageBox.critical(self, "Connection Error", f"Could not connect to the WanGP API at {API_BASE_URL}.\n\nPlease ensure wgptool.py is running.") + return False def generate(self): @@ -139,9 +140,9 @@ def generate(self): if response.status_code == 200: if payload['start_generation']: self.status_label.setText("Status: Parameters set. Generation sent. Polling...") - self.start_polling() else: - self.status_label.setText("Status: Parameters set. Waiting for manual start.") + self.status_label.setText("Status: Parameters set. 
Polling for manually started generation...") + self.start_polling() else: self.handle_api_error(response, "setting parameters") except requests.exceptions.RequestException as e: @@ -216,9 +217,8 @@ def setup_generator_for_region(self, region): self._reset_state() self.active_region = region start_sec, end_sec = region - - self.client_widget.check_server_status() - + if not self.client_widget.check_server_status(): + return start_data, w, h = self.app.get_frame_data_at_time(start_sec) end_data, _, _ = self.app.get_frame_data_at_time(end_sec) From 14fbfaf37b1d221407219aab14c75c0e1110d1c9 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 00:10:08 +1100 Subject: [PATCH 099/155] fix FPS conversions --- main.py | 24 +++++---------------- videoeditor/plugins/ai_frame_joiner/main.py | 8 ++++--- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/main.py b/main.py index 06292b315..0dff699e6 100644 --- a/main.py +++ b/main.py @@ -355,7 +355,6 @@ def _api_generate(self, start_frame, end_frame, duration_sec, model_type, start_ if model_type: self._api_set_model(model_type) - # 2. Set frame inputs if start_frame: self.widgets['mode_s'].setChecked(True) self.widgets['image_start'].setText(start_frame) @@ -363,37 +362,24 @@ def _api_generate(self, start_frame, end_frame, duration_sec, model_type, start_ if end_frame: self.widgets['image_end_checkbox'].setChecked(True) self.widgets['image_end'].setText(end_frame) - - # 3. Calculate video length in frames based on duration and model FPS + if duration_sec is not None: try: duration = float(duration_sec) - - # Get base FPS by parsing the "Force FPS" dropdown's default text (e.g., "Model Default (16 fps)") - base_fps = 16 # Fallback + base_fps = 16 fps_text = self.widgets['force_fps'].itemText(0) match = re.search(r'\((\d+)\s*fps\)', fps_text) if match: base_fps = int(match.group(1)) - # Temporal upsampling creates more frames in post-processing, so we must account for it here. - upsample_setting = self.widgets['temporal_upsampling'].currentData() - multiplier = 1.0 - if upsample_setting == "rife2": - multiplier = 2.0 - elif upsample_setting == "rife4": - multiplier = 4.0 - - # The number of frames the model needs to generate - video_length_frames = int(duration * base_fps * multiplier) + video_length_frames = int(duration * base_fps) self.widgets['video_length'].setValue(video_length_frames) - print(f"API: Calculated video length: {video_length_frames} frames for {duration:.2f}s @ {base_fps*multiplier:.0f} effective FPS.") + print(f"API: Calculated video length: {video_length_frames} frames for {duration:.2f}s @ {base_fps} effective FPS.") except (ValueError, TypeError) as e: print(f"API Error: Invalid duration_sec '{duration_sec}': {e}") - - # 4. 
Conditionally start generation + if start_generation: self.generate_btn.click() print("API: Generation started.") diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py index a7c7cd648..1b20a3b36 100644 --- a/videoeditor/plugins/ai_frame_joiner/main.py +++ b/videoeditor/plugins/ai_frame_joiner/main.py @@ -272,10 +272,12 @@ def insert_generated_clip(self, video_path): return start_sec, end_sec = self.active_region - duration = end_sec - start_sec self.app.status_label.setText(f"Inserting AI clip: {os.path.basename(video_path)}") try: + probe = ffmpeg.probe(video_path) + actual_duration = float(probe['format']['duration']) + for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec) for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec) @@ -291,7 +293,7 @@ def insert_generated_clip(self, video_path): source_path=video_path, timeline_start_sec=start_sec, clip_start_sec=0, - duration_sec=duration, + duration_sec=actual_duration, video_track_index=1, audio_track_index=None ) @@ -301,7 +303,7 @@ def insert_generated_clip(self, video_path): self.app.prune_empty_tracks() except Exception as e: - error_message = f"Error during clip insertion: {e}" + error_message = f"Error during clip insertion/probing: {e}" self.app.status_label.setText(error_message) self.client_widget.status_label.setText("Status: Failed to insert clip.") print(error_message) From e776c35a7b97b7a6ac93eda633f59c31f2eb2c4e Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 00:15:46 +1100 Subject: [PATCH 100/155] fix missing import --- videoeditor/plugins/ai_frame_joiner/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py index 1b20a3b36..c81b98fa3 100644 --- a/videoeditor/plugins/ai_frame_joiner/main.py +++ b/videoeditor/plugins/ai_frame_joiner/main.py @@ -4,6 +4,7 @@ import shutil import requests from pathlib import Path +import ffmpeg from PyQt6.QtWidgets import ( QWidget, QVBoxLayout, QHBoxLayout, QFormLayout, From 77ca09d3eec6efe522244e0876626c1aa1f13a2f Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 00:50:18 +1100 Subject: [PATCH 101/155] add ability to generate into new track, move statuses to main editor --- videoeditor/plugins/ai_frame_joiner/main.py | 117 +++++++++++--------- 1 file changed, 66 insertions(+), 51 deletions(-) diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py index c81b98fa3..3c8ef6157 100644 --- a/videoeditor/plugins/ai_frame_joiner/main.py +++ b/videoeditor/plugins/ai_frame_joiner/main.py @@ -21,6 +21,7 @@ class WgpClientWidget(QWidget): generation_complete = pyqtSignal(str) + status_updated = pyqtSignal(str) def __init__(self): super().__init__() @@ -60,10 +61,6 @@ def __init__(self): self.autostart_checkbox.setChecked(True) self.generate_button = QPushButton("Generate") - - self.status_label = QLabel("Status: Idle") - self.output_label = QLabel("Latest Output File: None") - self.output_label.setWordWrap(True) layout.addLayout(previews_layout) form_layout.addRow("Model Type:", self.model_input) @@ -71,8 +68,7 @@ def __init__(self): form_layout.addRow(self.generate_button) layout.addLayout(form_layout) - layout.addWidget(self.status_label) - layout.addWidget(self.output_label) + layout.addStretch() self.generate_button.clicked.connect(self.generate) @@ -104,16 +100,16 @@ def handle_api_error(self, response, action="performing 
action"): error_msg = response.json().get("error", "Unknown error") except requests.exceptions.JSONDecodeError: error_msg = response.text - self.status_label.setText(f"Status: Error {action}: {error_msg}") + self.status_updated.emit(f"AI Joiner Error: {error_msg}") QMessageBox.warning(self, "API Error", f"Failed while {action}.\n\nServer response:\n{error_msg}") def check_server_status(self): try: requests.get(f"{API_BASE_URL}/api/latest_output", timeout=1) - self.status_label.setText("Status: Connected to WanGP server.") + self.status_updated.emit("AI Joiner: Connected to WanGP server.") return True except requests.exceptions.ConnectionError: - self.status_label.setText("Status: Error - Cannot connect to WanGP server.") + self.status_updated.emit("AI Joiner Error: Cannot connect to WanGP server.") QMessageBox.critical(self, "Connection Error", f"Could not connect to the WanGP API at {API_BASE_URL}.\n\nPlease ensure wgptool.py is running.") return False @@ -135,19 +131,19 @@ def generate(self): payload['start_generation'] = self.autostart_checkbox.isChecked() - self.status_label.setText("Status: Sending parameters...") + self.status_updated.emit("AI Joiner: Sending parameters...") try: response = requests.post(f"{API_BASE_URL}/api/generate", json=payload) if response.status_code == 200: if payload['start_generation']: - self.status_label.setText("Status: Parameters set. Generation sent. Polling...") + self.status_updated.emit("AI Joiner: Generation sent to server. Polling for output...") else: - self.status_label.setText("Status: Parameters set. Polling for manually started generation...") + self.status_updated.emit("AI Joiner: Parameters set. Polling for manually started generation...") self.start_polling() else: self.handle_api_error(response, "setting parameters") except requests.exceptions.RequestException as e: - self.status_label.setText(f"Status: Connection error: {e}") + self.status_updated.emit(f"AI Joiner Error: Connection error: {e}") def poll_for_output(self): try: @@ -159,15 +155,12 @@ def poll_for_output(self): if latest_path and latest_path != self.last_known_output: self.stop_polling() self.last_known_output = latest_path - self.output_label.setText(f"Latest Output File:\n{latest_path}") - self.status_label.setText("Status: New output received! Inserting clip...") + self.status_updated.emit(f"AI Joiner: New output received! Inserting clip...") self.generation_complete.emit(latest_path) else: - if "Error" not in self.status_label.text() and "waiting" not in self.status_label.text().lower(): - self.status_label.setText("Status: Polling for output...") + self.status_updated.emit("AI Joiner: Polling for output...") except requests.exceptions.RequestException: - if "Error" not in self.status_label.text(): - self.status_label.setText("Status: Polling... (Connection issue)") + self.status_updated.emit("AI Joiner: Polling... 
(Connection issue)") class Plugin(VideoEditorPlugin): def initialize(self): @@ -177,7 +170,9 @@ def initialize(self): self.dock_widget = None self.active_region = None self.temp_dir = None + self.insert_on_new_track = False self.client_widget.generation_complete.connect(self.insert_generated_clip) + self.client_widget.status_updated.connect(self.update_main_status) def enable(self): if not self.dock_widget: @@ -195,6 +190,9 @@ def disable(self): self._cleanup_temp_dir() self.client_widget.stop_polling() + def update_main_status(self, message): + self.app.status_label.setText(message) + def _cleanup_temp_dir(self): if self.temp_dir and os.path.exists(self.temp_dir): shutil.rmtree(self.temp_dir) @@ -202,9 +200,9 @@ def _cleanup_temp_dir(self): def _reset_state(self): self.active_region = None + self.insert_on_new_track = False self._cleanup_temp_dir() - self.client_widget.status_label.setText("Status: Idle") - self.client_widget.output_label.setText("Latest Output File: None") + self.update_main_status("AI Joiner: Idle") self.client_widget.set_previews(None, None) def on_timeline_context_menu(self, menu, event): @@ -212,14 +210,20 @@ def on_timeline_context_menu(self, menu, event): if region: menu.addSeparator() action = menu.addAction("Join Frames With AI") - action.triggered.connect(lambda: self.setup_generator_for_region(region)) + action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False)) + + action_new_track = menu.addAction("Join Frames With AI (New Track)") + action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True)) - def setup_generator_for_region(self, region): + def setup_generator_for_region(self, region, on_new_track=False): self._reset_state() self.active_region = region + self.insert_on_new_track = on_new_track + start_sec, end_sec = region if not self.client_widget.check_server_status(): return + start_data, w, h = self.app.get_frame_data_at_time(start_sec) end_data, _, _ = self.app.get_frame_data_at_time(end_sec) @@ -258,55 +262,66 @@ def setup_generator_for_region(self, region): self.client_widget.start_frame_input.setText(start_img_path) self.client_widget.end_frame_input.setText(end_img_path) - self.client_widget.status_label.setText(f"Status: Ready for region {start_sec:.2f}s - {end_sec:.2f}s") + self.update_main_status(f"AI Joiner: Ready for region {start_sec:.2f}s - {end_sec:.2f}s") self.dock_widget.show() self.dock_widget.raise_() def insert_generated_clip(self, video_path): if not self.active_region: - self.client_widget.status_label.setText("Status: Error - No active region to insert into.") + self.update_main_status("AI Joiner Error: No active region to insert into.") return if not os.path.exists(video_path): - self.client_widget.status_label.setText(f"Status: Error - Output file not found: {video_path}") + self.update_main_status(f"AI Joiner Error: Output file not found: {video_path}") return start_sec, end_sec = self.active_region - self.app.status_label.setText(f"Inserting AI clip: {os.path.basename(video_path)}") + self.update_main_status(f"AI Joiner: Inserting clip {os.path.basename(video_path)}") try: probe = ffmpeg.probe(video_path) actual_duration = float(probe['format']['duration']) - for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec) - for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec) + if self.insert_on_new_track: + self.app.add_track('video') + new_track_index = self.app.timeline.num_video_tracks + 
self.app._add_clip_to_timeline( + source_path=video_path, + timeline_start_sec=start_sec, + clip_start_sec=0, + duration_sec=actual_duration, + video_track_index=new_track_index, + audio_track_index=None + ) + self.update_main_status("AI clip inserted successfully on new track.") + else: + for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec) + for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec) + + clips_to_remove = [ + c for c in self.app.timeline.clips + if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec + ] + for clip in clips_to_remove: + if clip in self.app.timeline.clips: + self.app.timeline.clips.remove(clip) + + self.app._add_clip_to_timeline( + source_path=video_path, + timeline_start_sec=start_sec, + clip_start_sec=0, + duration_sec=actual_duration, + video_track_index=1, + audio_track_index=None + ) + self.update_main_status("AI clip inserted successfully.") - clips_to_remove = [ - c for c in self.app.timeline.clips - if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec - ] - for clip in clips_to_remove: - if clip in self.app.timeline.clips: - self.app.timeline.clips.remove(clip) - - self.app._add_clip_to_timeline( - source_path=video_path, - timeline_start_sec=start_sec, - clip_start_sec=0, - duration_sec=actual_duration, - video_track_index=1, - audio_track_index=None - ) - - self.app.status_label.setText("AI clip inserted successfully.") - self.client_widget.status_label.setText("Status: Success! Clip inserted.") self.app.prune_empty_tracks() except Exception as e: - error_message = f"Error during clip insertion/probing: {e}" - self.app.status_label.setText(error_message) - self.client_widget.status_label.setText("Status: Failed to insert clip.") + error_message = f"AI Joiner Error during clip insertion/probing: {e}" + self.update_main_status(error_message) print(error_message) finally: From 3438e8fd833392c0b6e13770685ac511f1a45b94 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 01:30:07 +1100 Subject: [PATCH 102/155] add image and audio media types --- videoeditor/main.py | 291 ++++++++++++++------ videoeditor/plugins/ai_frame_joiner/main.py | 2 + 2 files changed, 209 insertions(+), 84 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 0c08e7b9a..49ac27b13 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -19,7 +19,7 @@ from undo import UndoStack, TimelineStateChangeCommand, MoveClipsCommand class TimelineClip: - def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec, track_index, track_type, group_id): + def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec, track_index, track_type, media_type, group_id): self.id = str(uuid.uuid4()) self.source_path = source_path self.timeline_start_sec = timeline_start_sec @@ -27,6 +27,7 @@ def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec self.duration_sec = duration_sec self.track_index = track_index self.track_type = track_type + self.media_type = media_type self.group_id = group_id @property @@ -136,7 +137,6 @@ def __init__(self, timeline_model, settings, project_fps, parent=None): self.drag_over_active = False self.drag_over_rect = QRectF() self.drag_over_audio_rect = QRectF() - self.drag_has_audio = False def set_project_fps(self, fps): self.project_fps = fps if fps > 0 else 25.0 @@ -163,12 +163,12 @@ def paintEvent(self, event): self.draw_selections(painter) if self.drag_over_active: - 
painter.fillRect(self.drag_over_rect, QColor(0, 255, 0, 80)) - if self.drag_has_audio: - painter.fillRect(self.drag_over_audio_rect, QColor(0, 255, 0, 80)) painter.setPen(QColor(0, 255, 0, 150)) - painter.drawRect(self.drag_over_rect) - if self.drag_has_audio: + if not self.drag_over_rect.isNull(): + painter.fillRect(self.drag_over_rect, QColor(0, 255, 0, 80)) + painter.drawRect(self.drag_over_rect) + if not self.drag_over_audio_rect.isNull(): + painter.fillRect(self.drag_over_audio_rect, QColor(0, 255, 0, 80)) painter.drawRect(self.drag_over_audio_rect) self.draw_playhead(painter) @@ -376,7 +376,12 @@ def draw_tracks_and_clips(self, painter): for clip in self.timeline.clips: clip_rect = self.get_clip_rect(clip) - base_color = QColor("#46A") if clip.track_type == 'video' else QColor("#48C") + base_color = QColor("#46A") # Default video + if clip.media_type == 'image': + base_color = QColor("#4A6") # Greenish for images + elif clip.track_type == 'audio': + base_color = QColor("#48C") # Bluish for audio + color = QColor("#5A9") if self.dragging_clip and self.dragging_clip.id == clip.id else base_color painter.fillRect(clip_rect, color) painter.setPen(QPen(QColor("#FFF"), 1)) @@ -667,47 +672,55 @@ def dragMoveEvent(self, event): json_data = json.loads(mime_data.data('application/x-vnd.video.filepath').data().decode('utf-8')) duration = json_data['duration'] - self.drag_has_audio = json_data['has_audio'] + media_type = json_data['media_type'] + has_audio = json_data['has_audio'] pos = event.position() - track_info = self.y_to_track_info(pos.y()) start_sec = self.x_to_sec(pos.x()) + track_info = self.y_to_track_info(pos.y()) + + self.drag_over_rect = QRectF() + self.drag_over_audio_rect = QRectF() + self.drag_over_active = False + self.highlighted_ghost_track_info = None + self.highlighted_track_info = None if track_info: self.drag_over_active = True track_type, track_index = track_info - + is_ghost_track = (track_type == 'video' and track_index > self.timeline.num_video_tracks) or \ (track_type == 'audio' and track_index > self.timeline.num_audio_tracks) if is_ghost_track: self.highlighted_ghost_track_info = track_info - self.highlighted_track_info = None else: - self.highlighted_ghost_track_info = None self.highlighted_track_info = track_info width = int(duration * self.pixels_per_second) x = self.sec_to_x(start_sec) - if track_type == 'video': - visual_index = self.timeline.num_video_tracks - track_index - y = self.video_tracks_y_start + visual_index * self.TRACK_HEIGHT - self.drag_over_rect = QRectF(x, y, width, self.TRACK_HEIGHT) - if self.drag_has_audio: - audio_y = self.audio_tracks_y_start # Default to track 1 - self.drag_over_audio_rect = QRectF(x, audio_y, width, self.TRACK_HEIGHT) + video_y, audio_y = -1, -1 + + if media_type in ['video', 'image']: + if track_type == 'video': + visual_index = self.timeline.num_video_tracks - track_index + video_y = self.video_tracks_y_start + visual_index * self.TRACK_HEIGHT + if has_audio: + audio_y = self.audio_tracks_y_start + elif track_type == 'audio' and has_audio: + visual_index = track_index - 1 + audio_y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT + video_y = self.video_tracks_y_start + (self.timeline.num_video_tracks - 1) * self.TRACK_HEIGHT + + elif media_type == 'audio': + if track_type == 'audio': + visual_index = track_index - 1 + audio_y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT - elif track_type == 'audio': - visual_index = track_index - 1 - y = self.audio_tracks_y_start + visual_index * 
self.TRACK_HEIGHT - self.drag_over_audio_rect = QRectF(x, y, width, self.TRACK_HEIGHT) - # For now, just show video on track 1 if dragging onto audio - video_y = self.video_tracks_y_start + (self.timeline.num_video_tracks - 1) * self.TRACK_HEIGHT + if video_y != -1: self.drag_over_rect = QRectF(x, video_y, width, self.TRACK_HEIGHT) - else: - self.drag_over_active = False - self.highlighted_ghost_track_info = None - self.highlighted_track_info = None + if audio_y != -1: + self.drag_over_audio_rect = QRectF(x, audio_y, width, self.TRACK_HEIGHT) self.update() @@ -725,6 +738,7 @@ def dropEvent(self, event): file_path = json_data['path'] duration = json_data['duration'] has_audio = json_data['has_audio'] + media_type = json_data['media_type'] pos = event.position() start_sec = self.x_to_sec(pos.x()) @@ -734,19 +748,32 @@ def dropEvent(self, event): return drop_track_type, drop_track_index = track_info - video_track_idx = 1 - audio_track_idx = 1 if has_audio else None - - if drop_track_type == 'video': - video_track_idx = drop_track_index - elif drop_track_type == 'audio': - audio_track_idx = drop_track_index + video_track_idx = None + audio_track_idx = None + + if media_type == 'image': + if drop_track_type == 'video': + video_track_idx = drop_track_index + elif media_type == 'audio': + if drop_track_type == 'audio': + audio_track_idx = drop_track_index + elif media_type == 'video': + if drop_track_type == 'video': + video_track_idx = drop_track_index + if has_audio: audio_track_idx = 1 + elif drop_track_type == 'audio' and has_audio: + audio_track_idx = drop_track_index + video_track_idx = 1 + + if video_track_idx is None and audio_track_idx is None: + return main_window = self.window() main_window._add_clip_to_timeline( source_path=file_path, timeline_start_sec=start_sec, duration_sec=duration, + media_type=media_type, clip_start_sec=0, video_track_index=video_track_idx, audio_track_index=audio_track_idx @@ -847,7 +874,8 @@ def startDrag(self, supportedActions): payload = { "path": path, "duration": media_info['duration'], - "has_audio": media_info['has_audio'] + "has_audio": media_info['has_audio'], + "media_type": media_info['media_type'] } mime_data.setData('application/x-vnd.video.filepath', QByteArray(json.dumps(payload).encode('utf-8'))) @@ -1049,7 +1077,7 @@ def _connect_signals(self): self.frame_forward_button.clicked.connect(lambda: self.step_frame(1)) self.playback_timer.timeout.connect(self.advance_playback_frame) - self.project_media_widget.add_media_requested.connect(self.add_video_clip) + self.project_media_widget.add_media_requested.connect(self.add_media_files) self.project_media_widget.media_removed.connect(self.on_media_removed_from_pool) self.project_media_widget.add_to_timeline_requested.connect(self.on_add_to_timeline_at_playhead) @@ -1095,14 +1123,19 @@ def on_add_to_timeline_at_playhead(self, file_path): playhead_pos = self.timeline_widget.playhead_pos_sec duration = media_info['duration'] has_audio = media_info['has_audio'] + media_type = media_info['media_type'] + + video_track = 1 if media_type in ['video', 'image'] else None + audio_track = 1 if has_audio else None self._add_clip_to_timeline( source_path=file_path, timeline_start_sec=playhead_pos, duration_sec=duration, + media_type=media_type, clip_start_sec=0.0, - video_track_index=1, - audio_track_index=1 if has_audio else None + video_track_index=video_track, + audio_track_index=audio_track ) def prune_empty_tracks(self): @@ -1165,13 +1198,13 @@ def _create_menu_bar(self): open_action = QAction("&Open 
Project...", self); open_action.triggered.connect(self.open_project) self.recent_menu = file_menu.addMenu("Recent") save_action = QAction("&Save Project As...", self); save_action.triggered.connect(self.save_project_as) - add_video_action = QAction("&Add Media to Project...", self); add_video_action.triggered.connect(self.add_video_clip) + add_media_action = QAction("&Add Media to Project...", self); add_media_action.triggered.connect(self.add_media_files) export_action = QAction("&Export Video...", self); export_action.triggered.connect(self.export_video) settings_action = QAction("Se&ttings...", self); settings_action.triggered.connect(self.open_settings_dialog) exit_action = QAction("E&xit", self); exit_action.triggered.connect(self.close) file_menu.addAction(new_action); file_menu.addAction(open_action); file_menu.addSeparator() file_menu.addAction(save_action) - file_menu.addSeparator(); file_menu.addAction(add_video_action); file_menu.addAction(export_action) + file_menu.addSeparator(); file_menu.addAction(add_media_action); file_menu.addAction(export_action) file_menu.addSeparator(); file_menu.addAction(settings_action); file_menu.addSeparator(); file_menu.addAction(exit_action) self._update_recent_files_menu() @@ -1252,13 +1285,24 @@ def get_frame_data_at_time(self, time_sec): if not clip_at_time: return (None, 0, 0) try: - clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec - out, _ = ( - ffmpeg - .input(clip_at_time.source_path, ss=clip_time) - .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') - .run(capture_stdout=True, quiet=True) - ) + w, h = self.project_width, self.project_height + if clip_at_time.media_type == 'image': + out, _ = ( + ffmpeg + .input(clip_at_time.source_path) + .filter('scale', w, h, force_original_aspect_ratio='decrease') + .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') + .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') + .run(capture_stdout=True, quiet=True) + ) + else: + clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec + out, _ = ( + ffmpeg + .input(clip_at_time.source_path, ss=clip_time) + .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') + .run(capture_stdout=True, quiet=True) + ) return (out, self.project_width, self.project_height) except ffmpeg.Error as e: print(f"Error extracting frame data: {e.stderr}") @@ -1269,8 +1313,19 @@ def get_frame_at_time(self, time_sec): black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black")) if not clip_at_time: return black_pixmap try: - clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec - out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time).output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24').run(capture_stdout=True, quiet=True)) + w, h = self.project_width, self.project_height + if clip_at_time.media_type == 'image': + out, _ = ( + ffmpeg.input(clip_at_time.source_path) + .filter('scale', w, h, force_original_aspect_ratio='decrease') + .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') + .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') + .run(capture_stdout=True, quiet=True) + ) + else: + clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec + out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time).output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24').run(capture_stdout=True, quiet=True)) + image = QImage(out, self.project_width, 
self.project_height, QImage.Format.Format_RGB888) return QPixmap.fromImage(image) except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return black_pixmap @@ -1303,14 +1358,32 @@ def advance_playback_frame(self): frame_duration = 1.0 / self.project_fps new_time = self.timeline_widget.playhead_pos_sec + frame_duration if new_time > self.timeline.get_total_duration(): self.stop_playback(); return - self.timeline_widget.playhead_pos_sec = new_time; self.timeline_widget.update() + + self.timeline_widget.playhead_pos_sec = new_time + self.timeline_widget.update() + clip_at_new_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= new_time < c.timeline_end_sec), None) + if not clip_at_new_time: self._stop_playback_stream() black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black")) scaled_pixmap = black_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) - self.preview_widget.setPixmap(scaled_pixmap); return - if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: self._start_playback_stream_at(new_time) + self.preview_widget.setPixmap(scaled_pixmap) + return + + if clip_at_new_time.media_type == 'image': + if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: + self._stop_playback_stream() + self.playback_clip = clip_at_new_time + frame_pixmap = self.get_frame_at_time(new_time) + if frame_pixmap: + scaled_pixmap = frame_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) + self.preview_widget.setPixmap(scaled_pixmap) + return + + if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: + self._start_playback_stream_at(new_time) + if self.playback_process: frame_size = self.project_width * self.project_height * 3 frame_bytes = self.playback_process.stdout.read(frame_size) @@ -1319,7 +1392,8 @@ def advance_playback_frame(self): pixmap = QPixmap.fromImage(image) scaled_pixmap = pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) self.preview_widget.setPixmap(scaled_pixmap) - else: self._stop_playback_stream() + else: + self._stop_playback_stream() def _load_settings(self): self.settings_file_was_loaded = False @@ -1386,7 +1460,7 @@ def save_project_as(self): if not path: return project_data = { "media_pool": self.media_pool, - "clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "group_id": c.group_id} for c in self.timeline.clips], + "clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "media_type": c.media_type, "group_id": c.group_id} for c in self.timeline.clips], "settings": {"num_video_tracks": self.timeline.num_video_tracks, "num_audio_tracks": self.timeline.num_audio_tracks} } try: @@ -1410,6 +1484,14 @@ def _load_project_from_path(self, path): for clip_data in project_data["clips"]: if not os.path.exists(clip_data["source_path"]): self.status_label.setText(f"Error: Missing media file {clip_data['source_path']}"); self.new_project(); return + + if 'media_type' not in clip_data: + ext = 
os.path.splitext(clip_data['source_path'])[1].lower() + if ext in ['.mp3', '.wav', '.m4a', '.aac']: + clip_data['media_type'] = 'audio' + else: + clip_data['media_type'] = 'video' + self.timeline.add_clip(TimelineClip(**clip_data)) project_settings = project_data.get("settings", {}) @@ -1417,7 +1499,9 @@ def _load_project_from_path(self, path): self.timeline.num_audio_tracks = project_settings.get("num_audio_tracks", 1) self.current_project_path = path - if self.timeline.clips: self._set_project_properties_from_clip(self.timeline.clips[0].source_path) + video_clips = [c for c in self.timeline.clips if c.media_type == 'video'] + if video_clips: + self._set_project_properties_from_clip(video_clips[0].source_path) self.prune_empty_tracks() self.timeline_widget.update(); self.stop_playback() self.status_label.setText(f"Project '{os.path.basename(path)}' loaded.") @@ -1445,14 +1529,31 @@ def _add_media_to_pool(self, file_path): if file_path in self.media_pool: return True try: self.status_label.setText(f"Probing {os.path.basename(file_path)}..."); QApplication.processEvents() - probe = ffmpeg.probe(file_path) - video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None) - if not video_stream: raise ValueError("No video stream found.") - duration = float(video_stream.get('duration', probe['format'].get('duration', 0))) - has_audio = any(s['codec_type'] == 'audio' for s in probe['streams']) + file_ext = os.path.splitext(file_path)[1].lower() + media_info = {} - self.media_properties[file_path] = {'duration': duration, 'has_audio': has_audio} + if file_ext in ['.png', '.jpg', '.jpeg']: + media_info['media_type'] = 'image' + media_info['duration'] = 5.0 + media_info['has_audio'] = False + else: + probe = ffmpeg.probe(file_path) + video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None) + audio_stream = next((s for s in probe['streams'] if s['codec_type'] == 'audio'), None) + + if video_stream: + media_info['media_type'] = 'video' + media_info['duration'] = float(video_stream.get('duration', probe['format'].get('duration', 0))) + media_info['has_audio'] = audio_stream is not None + elif audio_stream: + media_info['media_type'] = 'audio' + media_info['duration'] = float(audio_stream.get('duration', probe['format'].get('duration', 0))) + media_info['has_audio'] = True + else: + raise ValueError("No video or audio stream found.") + + self.media_properties[file_path] = media_info self.media_pool.append(file_path) self.project_media_widget.add_media_item(file_path) @@ -1477,33 +1578,42 @@ def on_media_removed_from_pool(self, file_path): self.undo_stack.push(command) - def add_video_clip(self): - file_path, _ = QFileDialog.getOpenFileName(self, "Open Video File", "", "Video Files (*.mp4 *.mov *.avi)") - if not file_path: return + def add_media_files(self): + file_paths, _ = QFileDialog.getOpenFileNames(self, "Open Media Files", "", "All Supported Files (*.mp4 *.mov *.avi *.png *.jpg *.jpeg *.mp3 *.wav);;Video Files (*.mp4 *.mov *.avi);;Image Files (*.png *.jpg *.jpeg);;Audio Files (*.mp3 *.wav)") + if not file_paths: return + self.media_dock.show() - if not self.timeline.clips and not self.media_pool: - if not self._set_project_properties_from_clip(file_path): - self.status_label.setText("Error: Could not determine video properties from file."); return - - if self._add_media_to_pool(file_path): - self.status_label.setText(f"Added {os.path.basename(file_path)} to project media.") - else: - self.status_label.setText(f"Failed to add 
{os.path.basename(file_path)}.") + first_video_added = any(c.media_type == 'video' for c in self.timeline.clips) + + for file_path in file_paths: + if not self.timeline.clips and not self.media_pool and not first_video_added: + ext = os.path.splitext(file_path)[1].lower() + if ext not in ['.png', '.jpg', '.jpeg', '.mp3', '.wav']: + if self._set_project_properties_from_clip(file_path): + first_video_added = True + else: + self.status_label.setText("Error: Could not determine video properties from file."); + continue + + if self._add_media_to_pool(file_path): + self.status_label.setText(f"Added {os.path.basename(file_path)} to project media.") + else: + self.status_label.setText(f"Failed to add {os.path.basename(file_path)}.") - def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, clip_start_sec=0.0, video_track_index=None, audio_track_index=None): + def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, media_type, clip_start_sec=0.0, video_track_index=None, audio_track_index=None): old_state = self._get_current_timeline_state() group_id = str(uuid.uuid4()) if video_track_index is not None: if video_track_index > self.timeline.num_video_tracks: self.timeline.num_video_tracks = video_track_index - video_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, video_track_index, 'video', group_id) + video_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, video_track_index, 'video', media_type, group_id) self.timeline.add_clip(video_clip) if audio_track_index is not None: if audio_track_index > self.timeline.num_audio_tracks: self.timeline.num_audio_tracks = audio_track_index - audio_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, audio_track_index, 'audio', group_id) + audio_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, audio_track_index, 'audio', media_type, group_id) self.timeline.add_clip(audio_clip) new_state = self._get_current_timeline_state() @@ -1517,7 +1627,7 @@ def _split_at_time(self, clip_to_split, time_sec, new_group_id=None): orig_dur = clip_to_split.duration_sec group_id_for_new_clip = new_group_id if new_group_id is not None else clip_to_split.group_id - new_clip = TimelineClip(clip_to_split.source_path, time_sec, clip_to_split.clip_start_sec + split_point, orig_dur - split_point, clip_to_split.track_index, clip_to_split.track_type, group_id_for_new_clip) + new_clip = TimelineClip(clip_to_split.source_path, time_sec, clip_to_split.clip_start_sec + split_point, orig_dur - split_point, clip_to_split.track_index, clip_to_split.track_type, clip_to_split.media_type, group_id_for_new_clip) clip_to_split.duration_sec = split_point self.timeline.add_clip(new_clip) return True @@ -1707,8 +1817,8 @@ def export_video(self): w, h, fr_str, total_dur = self.project_width, self.project_height, str(self.project_fps), self.timeline.get_total_duration() sample_rate, channel_layout = '44100', 'stereo' - input_files = {clip.source_path: ffmpeg.input(clip.source_path) for clip in self.timeline.clips} - + input_streams = {} + last_video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur}', f='lavfi') for i in range(self.timeline.num_video_tracks): track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'video' and c.track_index == i + 1], key=lambda c: c.timeline_start_sec) @@ -1722,8 +1832,18 @@ def export_video(self): if gap > 0.01: 
track_segments.append(ffmpeg.input(f'color=c=black@0.0:s={w}x{h}:r={fr_str}:d={gap}', f='lavfi').filter('format', pix_fmts='rgba')) - v_seg = (input_files[clip.source_path].video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS') - .filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black').filter('format', pix_fmts='rgba')) + if clip.source_path not in input_streams: + if clip.media_type == 'image': + input_streams[clip.source_path] = ffmpeg.input(clip.source_path, loop=1, framerate=self.project_fps) + else: + input_streams[clip.source_path] = ffmpeg.input(clip.source_path) + + if clip.media_type == 'image': + v_seg = (input_streams[clip.source_path].video.trim(duration=clip.duration_sec).setpts('PTS-STARTPTS')) + else: # video + v_seg = (input_streams[clip.source_path].video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS')) + + v_seg = (v_seg.filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black').filter('format', pix_fmts='rgba')) track_segments.append(v_seg) last_end = clip.timeline_end_sec @@ -1740,11 +1860,14 @@ def export_video(self): last_end = track_clips[0].timeline_start_sec for clip in track_clips: + if clip.source_path not in input_streams: + input_streams[clip.source_path] = ffmpeg.input(clip.source_path) + gap = clip.timeline_start_sec - last_end if gap > 0.01: track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap}', f='lavfi')) - a_seg = input_files[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS') + a_seg = input_streams[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS') track_segments.append(a_seg) last_end = clip.timeline_end_sec diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py index 3c8ef6157..d66dda353 100644 --- a/videoeditor/plugins/ai_frame_joiner/main.py +++ b/videoeditor/plugins/ai_frame_joiner/main.py @@ -291,6 +291,7 @@ def insert_generated_clip(self, video_path): timeline_start_sec=start_sec, clip_start_sec=0, duration_sec=actual_duration, + media_type='video', video_track_index=new_track_index, audio_track_index=None ) @@ -312,6 +313,7 @@ def insert_generated_clip(self, video_path): timeline_start_sec=start_sec, clip_start_sec=0, duration_sec=actual_duration, + media_type='video', video_track_index=1, audio_track_index=None ) From a5b51dfb1207ebd64f04f96e61d569029aa23294 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 02:10:56 +1100 Subject: [PATCH 103/155] add ability to resize clips --- videoeditor/main.py | 123 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 120 insertions(+), 3 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 49ac27b13..3ce1bccde 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -11,7 +11,7 @@ QScrollArea, QFrame, QProgressBar, QDialog, QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget, QListWidget, QListWidgetItem, QMessageBox) from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction, - QPixmap, QImage, QDrag) + QPixmap, QImage, QDrag, QCursor) from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread, pyqtSignal, QTimer, QByteArray, QMimeData) @@ -82,6 +82,7 @@ class 
TimelineWidget(QWidget): HEADER_WIDTH = 120 TRACK_HEIGHT = 50 AUDIO_TRACKS_SEPARATOR_Y = 15 + RESIZE_HANDLE_WIDTH = 8 split_requested = pyqtSignal(object) delete_clip_requested = pyqtSignal(object) @@ -124,6 +125,10 @@ def __init__(self, timeline_model, settings, project_fps, parent=None): self.drag_selection_start_values = None self.drag_start_state = None + self.resizing_clip = None + self.resize_edge = None # 'left' or 'right' + self.resize_start_pos = QPoint() + self.highlighted_track_info = None self.highlighted_ghost_track_info = None self.add_video_track_btn_rect = QRect() @@ -499,8 +504,28 @@ def mousePressEvent(self, event: QMouseEvent): self.dragging_playhead = False self.dragging_selection_region = None self.creating_selection_region = False + self.resizing_clip = None + self.resize_edge = None self.drag_original_clip_states.clear() + # Check for resize handles first + for clip in reversed(self.timeline.clips): + clip_rect = self.get_clip_rect(clip) + if abs(event.pos().x() - clip_rect.left()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.left(), event.pos().y())): + self.resizing_clip = clip + self.resize_edge = 'left' + break + elif abs(event.pos().x() - clip_rect.right()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.right(), event.pos().y())): + self.resizing_clip = clip + self.resize_edge = 'right' + break + + if self.resizing_clip: + self.drag_start_state = self.window()._get_current_timeline_state() + self.resize_start_pos = event.pos() + self.update() + return + for clip in reversed(self.timeline.clips): clip_rect = self.get_clip_rect(clip) if clip_rect.contains(QPointF(event.pos())): @@ -538,6 +563,77 @@ def mousePressEvent(self, event: QMouseEvent): self.update() def mouseMoveEvent(self, event: QMouseEvent): + if self.resizing_clip: + delta_x = event.pos().x() - self.resize_start_pos.x() + time_delta = delta_x / self.pixels_per_second + min_duration = 1.0 / self.project_fps + + media_props = self.window().media_properties.get(self.resizing_clip.source_path) + source_duration = media_props['duration'] if media_props else float('inf') + + if self.resize_edge == 'left': + original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_sec + original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec + original_clip_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].clip_start_sec + + new_start_sec = original_start + time_delta + + # Clamp to not move past the end of the clip + if new_start_sec > original_start + original_duration - min_duration: + new_start_sec = original_start + original_duration - min_duration + + # Clamp to zero + new_start_sec = max(0, new_start_sec) + + # For video/audio, don't allow extending beyond the source start + if self.resizing_clip.media_type != 'image': + if new_start_sec < original_start - original_clip_start: + new_start_sec = original_start - original_clip_start + + new_duration = (original_start + original_duration) - new_start_sec + # FIX: Correctly calculate the new clip start by ADDING the time shift + new_clip_start = original_clip_start + (new_start_sec - original_start) + + if new_duration < min_duration: + new_duration = min_duration + new_start_sec = (original_start + original_duration) - new_duration + # FIX: Apply the same corrected logic here after clamping duration + new_clip_start = 
original_clip_start + (new_start_sec - original_start) + + self.resizing_clip.timeline_start_sec = new_start_sec + self.resizing_clip.duration_sec = new_duration + self.resizing_clip.clip_start_sec = new_clip_start + + elif self.resize_edge == 'right': + original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec + new_duration = original_duration + time_delta + + if new_duration < min_duration: + new_duration = min_duration + + # For video/audio, don't allow extending beyond the source end + if self.resizing_clip.media_type != 'image': + if self.resizing_clip.clip_start_sec + new_duration > source_duration: + new_duration = source_duration - self.resizing_clip.clip_start_sec + + self.resizing_clip.duration_sec = new_duration + + self.update() + return + + # Update cursor for resizing + if not self.dragging_clip and not self.dragging_playhead and not self.creating_selection_region: + cursor_set = False + for clip in self.timeline.clips: + clip_rect = self.get_clip_rect(clip) + if (abs(event.pos().x() - clip_rect.left()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.left(), event.pos().y()))) or \ + (abs(event.pos().x() - clip_rect.right()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.right(), event.pos().y()))): + self.setCursor(Qt.CursorShape.SizeHorCursor) + cursor_set = True + break + if not cursor_set: + self.unsetCursor() + if self.creating_selection_region: current_sec = self.x_to_sec(event.pos().x()) start = min(self.selection_drag_start_sec, current_sec) @@ -613,6 +709,17 @@ def mouseMoveEvent(self, event: QMouseEvent): def mouseReleaseEvent(self, event: QMouseEvent): if event.button() == Qt.MouseButton.LeftButton: + if self.resizing_clip: + new_state = self.window()._get_current_timeline_state() + command = TimelineStateChangeCommand("Resize Clip", self.timeline, *self.drag_start_state, *new_state) + command.undo() + self.window().undo_stack.push(command) + self.resizing_clip = None + self.resize_edge = None + self.drag_start_state = None + self.update() + return + if self.creating_selection_region: self.creating_selection_region = False if self.selection_regions: @@ -1249,8 +1356,12 @@ def _start_playback_stream_at(self, time_sec): if not clip: return self.playback_clip = clip clip_time = time_sec - clip.timeline_start_sec + clip.clip_start_sec + w, h = self.project_width, self.project_height try: - args = (ffmpeg.input(self.playback_clip.source_path, ss=clip_time).output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile()) + args = (ffmpeg.input(self.playback_clip.source_path, ss=clip_time) + .filter('scale', w, h, force_original_aspect_ratio='decrease') + .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') + .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile()) self.playback_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) except Exception as e: print(f"Failed to start playback stream: {e}"); self._stop_playback_stream() @@ -1300,6 +1411,8 @@ def get_frame_data_at_time(self, time_sec): out, _ = ( ffmpeg .input(clip_at_time.source_path, ss=clip_time) + .filter('scale', w, h, force_original_aspect_ratio='decrease') + .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') .run(capture_stdout=True, quiet=True) ) @@ -1324,7 +1437,11 @@ def get_frame_at_time(self, time_sec): ) else: clip_time = time_sec - 
clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec - out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time).output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24').run(capture_stdout=True, quiet=True)) + out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time) + .filter('scale', w, h, force_original_aspect_ratio='decrease') + .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') + .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') + .run(capture_stdout=True, quiet=True)) image = QImage(out, self.project_width, self.project_height, QImage.Format.Format_RGB888) return QPixmap.fromImage(image) From 25bcc33827d44f00070ab945fac9715e550cf228 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 02:37:49 +1100 Subject: [PATCH 104/155] add playhead snapping for all functions --- videoeditor/main.py | 87 +++++++++++++++++++++++++++++++-------------- 1 file changed, 61 insertions(+), 26 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 3ce1bccde..b3d3828e4 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -83,6 +83,7 @@ class TimelineWidget(QWidget): TRACK_HEIGHT = 50 AUDIO_TRACKS_SEPARATOR_Y = 15 RESIZE_HANDLE_WIDTH = 8 + SNAP_THRESHOLD_PIXELS = 8 split_requested = pyqtSignal(object) delete_clip_requested = pyqtSignal(object) @@ -553,7 +554,11 @@ def mousePressEvent(self, event: QMouseEvent): if is_in_track_area: self.creating_selection_region = True - self.selection_drag_start_sec = self.x_to_sec(event.pos().x()) + playhead_x = self.sec_to_x(self.playhead_pos_sec) + if abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS: + self.selection_drag_start_sec = self.playhead_pos_sec + else: + self.selection_drag_start_sec = self.x_to_sec(event.pos().x()) self.selection_regions.append([self.selection_drag_start_sec, self.selection_drag_start_sec]) elif is_on_timescale: self.playhead_pos_sec = max(0, self.x_to_sec(event.pos().x())) @@ -567,6 +572,8 @@ def mouseMoveEvent(self, event: QMouseEvent): delta_x = event.pos().x() - self.resize_start_pos.x() time_delta = delta_x / self.pixels_per_second min_duration = 1.0 / self.project_fps + playhead_time = self.playhead_pos_sec + snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_second media_props = self.window().media_properties.get(self.resizing_clip.source_path) source_duration = media_props['duration'] if media_props else float('inf') @@ -575,29 +582,27 @@ def mouseMoveEvent(self, event: QMouseEvent): original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_sec original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec original_clip_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].clip_start_sec - - new_start_sec = original_start + time_delta - - # Clamp to not move past the end of the clip + true_new_start_sec = original_start + time_delta + if abs(true_new_start_sec - playhead_time) < snap_time_delta: + new_start_sec = playhead_time + else: + new_start_sec = true_new_start_sec + if new_start_sec > original_start + original_duration - min_duration: new_start_sec = original_start + original_duration - min_duration - - # Clamp to zero + new_start_sec = max(0, new_start_sec) - # For video/audio, don't allow extending beyond the source start if self.resizing_clip.media_type != 'image': if new_start_sec < original_start - original_clip_start: 
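# Left-edge limit for video/audio clips: dragging the handle further left would
# require source media before clip_start_sec == 0, so the edge is clamped to the
# timeline position at which the clip's in-point reaches the start of its source.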
new_start_sec = original_start - original_clip_start new_duration = (original_start + original_duration) - new_start_sec - # FIX: Correctly calculate the new clip start by ADDING the time shift new_clip_start = original_clip_start + (new_start_sec - original_start) if new_duration < min_duration: new_duration = min_duration new_start_sec = (original_start + original_duration) - new_duration - # FIX: Apply the same corrected logic here after clamping duration new_clip_start = original_clip_start + (new_start_sec - original_start) self.resizing_clip.timeline_start_sec = new_start_sec @@ -605,13 +610,20 @@ def mouseMoveEvent(self, event: QMouseEvent): self.resizing_clip.clip_start_sec = new_clip_start elif self.resize_edge == 'right': + original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_sec original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec - new_duration = original_duration + time_delta + + true_new_duration = original_duration + time_delta + true_new_end_time = original_start + true_new_duration + + if abs(true_new_end_time - playhead_time) < snap_time_delta: + new_duration = playhead_time - original_start + else: + new_duration = true_new_duration if new_duration < min_duration: new_duration = min_duration - - # For video/audio, don't allow extending beyond the source end + if self.resizing_clip.media_type != 'image': if self.resizing_clip.clip_start_sec + new_duration > source_duration: new_duration = source_duration - self.resizing_clip.clip_start_sec @@ -620,17 +632,26 @@ def mouseMoveEvent(self, event: QMouseEvent): self.update() return - - # Update cursor for resizing + if not self.dragging_clip and not self.dragging_playhead and not self.creating_selection_region: cursor_set = False - for clip in self.timeline.clips: - clip_rect = self.get_clip_rect(clip) - if (abs(event.pos().x() - clip_rect.left()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.left(), event.pos().y()))) or \ - (abs(event.pos().x() - clip_rect.right()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.right(), event.pos().y()))): - self.setCursor(Qt.CursorShape.SizeHorCursor) - cursor_set = True - break + ### ADD BEGIN ### + # Check for playhead proximity for selection snapping + playhead_x = self.sec_to_x(self.playhead_pos_sec) + is_in_track_area = event.pos().y() > self.TIMESCALE_HEIGHT and event.pos().x() > self.HEADER_WIDTH + if is_in_track_area and abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS: + self.setCursor(Qt.CursorShape.SizeHorCursor) + cursor_set = True + + if not cursor_set: + ### ADD END ### + for clip in self.timeline.clips: + clip_rect = self.get_clip_rect(clip) + if (abs(event.pos().x() - clip_rect.left()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.left(), event.pos().y()))) or \ + (abs(event.pos().x() - clip_rect.right()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.right(), event.pos().y()))): + self.setCursor(Qt.CursorShape.SizeHorCursor) + cursor_set = True + break if not cursor_set: self.unsetCursor() @@ -683,7 +704,18 @@ def mouseMoveEvent(self, event: QMouseEvent): delta_x = event.pos().x() - self.drag_start_pos.x() time_delta = delta_x / self.pixels_per_second - new_start_time = original_start_sec + time_delta + true_new_start_time = original_start_sec + time_delta + + playhead_time = self.playhead_pos_sec + snap_time_delta = 
self.SNAP_THRESHOLD_PIXELS / self.pixels_per_second + + new_start_time = true_new_start_time + true_new_end_time = true_new_start_time + self.dragging_clip.duration_sec + + if abs(true_new_start_time - playhead_time) < snap_time_delta: + new_start_time = playhead_time + elif abs(true_new_end_time - playhead_time) < snap_time_delta: + new_start_time = playhead_time - self.dragging_clip.duration_sec for other_clip in self.timeline.clips: if other_clip.id == self.dragging_clip.id: continue @@ -696,8 +728,11 @@ def mouseMoveEvent(self, event: QMouseEvent): new_start_time + self.dragging_clip.duration_sec > other_clip.timeline_start_sec) if is_overlapping: - if time_delta > 0: new_start_time = other_clip.timeline_start_sec - self.dragging_clip.duration_sec - else: new_start_time = other_clip.timeline_end_sec + movement_direction = true_new_start_time - original_start_sec + if movement_direction > 0: + new_start_time = other_clip.timeline_start_sec - self.dragging_clip.duration_sec + else: + new_start_time = other_clip.timeline_end_sec break final_start_time = max(0, new_start_time) @@ -725,7 +760,7 @@ def mouseReleaseEvent(self, event: QMouseEvent): if self.selection_regions: start, end = self.selection_regions[-1] if (end - start) * self.pixels_per_second < 2: - self.selection_regions.pop() + self.clear_all_regions() if self.dragging_selection_region: self.dragging_selection_region = None From dc68e8682ad27a1d4e728dee4f34d76ef9096cc5 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 03:11:58 +1100 Subject: [PATCH 105/155] added feature to unlink and relink audio tracks --- videoeditor/main.py | 88 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 85 insertions(+), 3 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index b3d3828e4..8d9e27d30 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -569,6 +569,7 @@ def mousePressEvent(self, event: QMouseEvent): def mouseMoveEvent(self, event: QMouseEvent): if self.resizing_clip: + linked_clip = next((c for c in self.timeline.clips if c.group_id == self.resizing_clip.group_id and c.id != self.resizing_clip.id), None) delta_x = event.pos().x() - self.resize_start_pos.x() time_delta = delta_x / self.pixels_per_second min_duration = 1.0 / self.project_fps @@ -608,6 +609,10 @@ def mouseMoveEvent(self, event: QMouseEvent): self.resizing_clip.timeline_start_sec = new_start_sec self.resizing_clip.duration_sec = new_duration self.resizing_clip.clip_start_sec = new_clip_start + if linked_clip: + linked_clip.timeline_start_sec = new_start_sec + linked_clip.duration_sec = new_duration + linked_clip.clip_start_sec = new_clip_start elif self.resize_edge == 'right': original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_sec @@ -629,14 +634,14 @@ def mouseMoveEvent(self, event: QMouseEvent): new_duration = source_duration - self.resizing_clip.clip_start_sec self.resizing_clip.duration_sec = new_duration + if linked_clip: + linked_clip.duration_sec = new_duration self.update() return if not self.dragging_clip and not self.dragging_playhead and not self.creating_selection_region: cursor_set = False - ### ADD BEGIN ### - # Check for playhead proximity for selection snapping playhead_x = self.sec_to_x(self.playhead_pos_sec) is_in_track_area = event.pos().y() > self.TIMESCALE_HEIGHT and event.pos().x() > self.HEADER_WIDTH if is_in_track_area and abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS: @@ -644,7 +649,6 @@ def 
mouseMoveEvent(self, event: QMouseEvent): cursor_set = True if not cursor_set: - ### ADD END ### for clip in self.timeline.clips: clip_rect = self.get_clip_rect(clip) if (abs(event.pos().x() - clip_rect.left()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.left(), event.pos().y()))) or \ @@ -952,6 +956,18 @@ def contextMenuEvent(self, event: 'QContextMenuEvent'): if clip_at_pos: if not menu.isEmpty(): menu.addSeparator() + + linked_clip = next((c for c in self.timeline.clips if c.group_id == clip_at_pos.group_id and c.id != clip_at_pos.id), None) + if linked_clip: + unlink_action = menu.addAction("Unlink Audio Track") + unlink_action.triggered.connect(lambda: self.window().unlink_clip_pair(clip_at_pos)) + else: + media_info = self.window().media_properties.get(clip_at_pos.source_path) + if (clip_at_pos.track_type == 'video' and + media_info and media_info.get('has_audio')): + relink_action = menu.addAction("Relink Audio Track") + relink_action.triggered.connect(lambda: self.window().relink_clip_audio(clip_at_pos)) + split_action = menu.addAction("Split Clip") delete_action = menu.addAction("Delete Clip") playhead_time = self.playhead_pos_sec @@ -1825,6 +1841,72 @@ def delete_clip(self, clip_to_delete): self.undo_stack.push(command) self.prune_empty_tracks() + def unlink_clip_pair(self, clip_to_unlink): + old_state = self._get_current_timeline_state() + + linked_clip = next((c for c in self.timeline.clips if c.group_id == clip_to_unlink.group_id and c.id != clip_to_unlink.id), None) + + if linked_clip: + clip_to_unlink.group_id = str(uuid.uuid4()) + linked_clip.group_id = str(uuid.uuid4()) + + new_state = self._get_current_timeline_state() + command = TimelineStateChangeCommand("Unlink Clips", self.timeline, *old_state, *new_state) + command.undo() + self.undo_stack.push(command) + self.status_label.setText("Clips unlinked.") + else: + self.status_label.setText("Could not find a clip to unlink.") + + def relink_clip_audio(self, video_clip): + def action(): + media_info = self.media_properties.get(video_clip.source_path) + if not media_info or not media_info.get('has_audio'): + self.status_label.setText("Source media has no audio to relink.") + return + + target_audio_track = 1 + new_audio_start = video_clip.timeline_start_sec + new_audio_end = video_clip.timeline_end_sec + + conflicting_clips = [ + c for c in self.timeline.clips + if c.track_type == 'audio' and c.track_index == target_audio_track and + c.timeline_start_sec < new_audio_end and c.timeline_end_sec > new_audio_start + ] + + for conflict_clip in conflicting_clips: + found_spot = False + for check_track_idx in range(target_audio_track + 1, self.timeline.num_audio_tracks + 2): + is_occupied = any( + other.timeline_start_sec < conflict_clip.timeline_end_sec and other.timeline_end_sec > conflict_clip.timeline_start_sec + for other in self.timeline.clips + if other.id != conflict_clip.id and other.track_type == 'audio' and other.track_index == check_track_idx + ) + + if not is_occupied: + if check_track_idx > self.timeline.num_audio_tracks: + self.timeline.num_audio_tracks = check_track_idx + + conflict_clip.track_index = check_track_idx + found_spot = True + break + + new_audio_clip = TimelineClip( + source_path=video_clip.source_path, + timeline_start_sec=video_clip.timeline_start_sec, + clip_start_sec=video_clip.clip_start_sec, + duration_sec=video_clip.duration_sec, + track_index=target_audio_track, + track_type='audio', + media_type=video_clip.media_type, + group_id=video_clip.group_id + ) + 
self.timeline.add_clip(new_audio_clip) + self.status_label.setText("Audio relinked.") + + self._perform_complex_timeline_change("Relink Audio", action) + def _perform_complex_timeline_change(self, description, change_function): old_state = self._get_current_timeline_state() change_function() From bea0f48c70276e6472ba0a57fbb3953d9854e869 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 03:43:35 +1100 Subject: [PATCH 106/155] add ability to add media directly to timeline --- videoeditor/main.py | 134 +++++++++++++++++++++++++++++++------------- 1 file changed, 96 insertions(+), 38 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 8d9e27d30..c4d3e9268 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -107,7 +107,7 @@ def __init__(self, timeline_model, settings, project_fps, parent=None): self.scroll_area = None self.pixels_per_second = 50.0 - self.max_pixels_per_second = 1.0 # Will be updated by set_project_fps + self.max_pixels_per_second = 1.0 self.project_fps = 25.0 self.set_project_fps(project_fps) @@ -121,13 +121,13 @@ def __init__(self, timeline_model, settings, project_fps, parent=None): self.creating_selection_region = False self.dragging_selection_region = None self.drag_start_pos = QPoint() - self.drag_original_clip_states = {} # Store {'clip_id': (start_sec, track_index)} + self.drag_original_clip_states = {} self.selection_drag_start_sec = 0.0 self.drag_selection_start_values = None self.drag_start_state = None self.resizing_clip = None - self.resize_edge = None # 'left' or 'right' + self.resize_edge = None self.resize_start_pos = QPoint() self.highlighted_track_info = None @@ -146,7 +146,6 @@ def __init__(self, timeline_model, settings, project_fps, parent=None): def set_project_fps(self, fps): self.project_fps = fps if fps > 0 else 25.0 - # Set max zoom to be 10 pixels per frame self.max_pixels_per_second = self.project_fps * 20 self.pixels_per_second = min(self.pixels_per_second, self.max_pixels_per_second) self.update() @@ -263,7 +262,6 @@ def _format_timecode(self, seconds): precision = 2 if s < 1 else 1 if s < 10 else 0 val = f"{s:.{precision}f}" - # Remove trailing .0 or .00 if '.' 
in val: val = val.rstrip('0').rstrip('.') return f"{sign}{val}s" @@ -382,11 +380,11 @@ def draw_tracks_and_clips(self, painter): for clip in self.timeline.clips: clip_rect = self.get_clip_rect(clip) - base_color = QColor("#46A") # Default video + base_color = QColor("#46A") if clip.media_type == 'image': - base_color = QColor("#4A6") # Greenish for images + base_color = QColor("#4A6") elif clip.track_type == 'audio': - base_color = QColor("#48C") # Bluish for audio + base_color = QColor("#48C") color = QColor("#5A9") if self.dragging_clip and self.dragging_clip.id == clip.id else base_color painter.fillRect(clip_rect, color) @@ -415,25 +413,21 @@ def draw_playhead(self, painter): painter.drawLine(playhead_x, 0, playhead_x, self.height()) def y_to_track_info(self, y): - # Check for ghost video track if self.TIMESCALE_HEIGHT <= y < self.video_tracks_y_start: return ('video', self.timeline.num_video_tracks + 1) - # Check for existing video tracks video_tracks_end_y = self.video_tracks_y_start + self.timeline.num_video_tracks * self.TRACK_HEIGHT if self.video_tracks_y_start <= y < video_tracks_end_y: visual_index = (y - self.video_tracks_y_start) // self.TRACK_HEIGHT track_index = self.timeline.num_video_tracks - visual_index return ('video', track_index) - - # Check for existing audio tracks + audio_tracks_end_y = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT if self.audio_tracks_y_start <= y < audio_tracks_end_y: visual_index = (y - self.audio_tracks_y_start) // self.TRACK_HEIGHT track_index = visual_index + 1 return ('audio', track_index) - - # Check for ghost audio track + add_audio_btn_y_start = self.audio_tracks_y_start + self.timeline.num_audio_tracks * self.TRACK_HEIGHT add_audio_btn_y_end = add_audio_btn_y_start + self.TRACK_HEIGHT if add_audio_btn_y_start <= y < add_audio_btn_y_end: @@ -468,27 +462,19 @@ def wheelEvent(self, event: QMouseEvent): if delta > 0: new_pps = old_pps * zoom_factor else: new_pps = old_pps / zoom_factor - min_pps = 1 / (3600 * 10) # Min zoom of 1px per 10 hours + min_pps = 1 / (3600 * 10) new_pps = max(min_pps, min(new_pps, self.max_pixels_per_second)) if abs(new_pps - old_pps) < 1e-9: return self.pixels_per_second = new_pps - # This will trigger a paintEvent, which resizes the widget. - # The scroll area will update its scrollbar ranges in response. 
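# Cursor-anchored zoom: once pixels_per_second has changed, the timestamp that was
# under the mouse maps to a new absolute x, and the scrollbar is shifted by exactly
# that difference a few lines below, so the content under the cursor stays visually
# in place instead of sliding as the scale changes.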
self.update() - - # The new absolute x-coordinate for the time under the cursor + new_mouse_x_abs = self.sec_to_x(time_at_cursor) - - # The amount the content "shifted" at the cursor's location shift_amount = new_mouse_x_abs - mouse_x_abs - - # Adjust the scrollbar by this shift amount to keep the content under the cursor new_scroll_value = scrollbar.value() + shift_amount scrollbar.setValue(int(new_scroll_value)) - event.accept() def mousePressEvent(self, event: QMouseEvent): @@ -509,7 +495,6 @@ def mousePressEvent(self, event: QMouseEvent): self.resize_edge = None self.drag_original_clip_states.clear() - # Check for resize handles first for clip in reversed(self.timeline.clips): clip_rect = self.get_clip_rect(clip) if abs(event.pos().x() - clip_rect.left()) < self.RESIZE_HANDLE_WIDTH and clip_rect.contains(QPointF(clip_rect.left(), event.pos().y())): @@ -1356,13 +1341,18 @@ def _create_menu_bar(self): open_action = QAction("&Open Project...", self); open_action.triggered.connect(self.open_project) self.recent_menu = file_menu.addMenu("Recent") save_action = QAction("&Save Project As...", self); save_action.triggered.connect(self.save_project_as) + add_media_to_timeline_action = QAction("Add Media to &Timeline...", self) + add_media_to_timeline_action.triggered.connect(self.add_media_to_timeline) add_media_action = QAction("&Add Media to Project...", self); add_media_action.triggered.connect(self.add_media_files) export_action = QAction("&Export Video...", self); export_action.triggered.connect(self.export_video) settings_action = QAction("Se&ttings...", self); settings_action.triggered.connect(self.open_settings_dialog) exit_action = QAction("E&xit", self); exit_action.triggered.connect(self.close) file_menu.addAction(new_action); file_menu.addAction(open_action); file_menu.addSeparator() file_menu.addAction(save_action) - file_menu.addSeparator(); file_menu.addAction(add_media_action); file_menu.addAction(export_action) + file_menu.addSeparator() + file_menu.addAction(add_media_to_timeline_action) + file_menu.addAction(add_media_action) + file_menu.addAction(export_action) file_menu.addSeparator(); file_menu.addAction(settings_action); file_menu.addSeparator(); file_menu.addAction(exit_action) self._update_recent_files_menu() @@ -1565,7 +1555,7 @@ def advance_playback_frame(self): def _load_settings(self): self.settings_file_was_loaded = False - defaults = {"window_visibility": {"project_media": True}, "splitter_state": None, "enabled_plugins": [], "recent_files": [], "confirm_on_exit": True} + defaults = {"window_visibility": {"project_media": False}, "splitter_state": None, "enabled_plugins": [], "recent_files": [], "confirm_on_exit": True} if os.path.exists(self.settings_file): try: with open(self.settings_file, "r") as f: self.settings = json.load(f) @@ -1591,7 +1581,7 @@ def _apply_loaded_settings(self): visibility_settings = self.settings.get("window_visibility", {}) for key, data in self.managed_widgets.items(): if data.get('plugin'): continue - is_visible = visibility_settings.get(key, True) + is_visible = visibility_settings.get(key, False if key == 'project_media' else True) if data['widget'] is not self.preview_widget: data['widget'].setVisible(is_visible) if data['action']: data['action'].setChecked(is_visible) @@ -1745,12 +1735,13 @@ def on_media_removed_from_pool(self, file_path): command.undo() self.undo_stack.push(command) + def _add_media_files_to_project(self, file_paths): + if not file_paths: + return [] - def add_media_files(self): - file_paths, _ = 
QFileDialog.getOpenFileNames(self, "Open Media Files", "", "All Supported Files (*.mp4 *.mov *.avi *.png *.jpg *.jpeg *.mp3 *.wav);;Video Files (*.mp4 *.mov *.avi);;Image Files (*.png *.jpg *.jpeg);;Audio Files (*.mp3 *.wav)") - if not file_paths: return self.media_dock.show() + added_files = [] first_video_added = any(c.media_type == 'video' for c in self.timeline.clips) for file_path in file_paths: @@ -1760,13 +1751,81 @@ def add_media_files(self): if self._set_project_properties_from_clip(file_path): first_video_added = True else: - self.status_label.setText("Error: Could not determine video properties from file."); + self.status_label.setText("Error: Could not determine video properties from file.") continue if self._add_media_to_pool(file_path): - self.status_label.setText(f"Added {os.path.basename(file_path)} to project media.") - else: - self.status_label.setText(f"Failed to add {os.path.basename(file_path)}.") + added_files.append(file_path) + + return added_files + + def add_media_to_timeline(self): + file_paths, _ = QFileDialog.getOpenFileNames(self, "Add Media to Timeline", "", "All Supported Files (*.mp4 *.mov *.avi *.png *.jpg *.jpeg *.mp3 *.wav);;Video Files (*.mp4 *.mov *.avi);;Image Files (*.png *.jpg *.jpeg);;Audio Files (*.mp3 *.wav)") + if not file_paths: + return + + added_files = self._add_media_files_to_project(file_paths) + if not added_files: + return + + playhead_pos = self.timeline_widget.playhead_pos_sec + + def add_clips_action(): + for file_path in added_files: + media_info = self.media_properties.get(file_path) + if not media_info: continue + + duration = media_info['duration'] + media_type = media_info['media_type'] + has_audio = media_info['has_audio'] + + clip_start_time = playhead_pos + clip_end_time = playhead_pos + duration + + video_track_index = None + audio_track_index = None + + if media_type in ['video', 'image']: + for i in range(1, self.timeline.num_video_tracks + 2): + is_occupied = any( + c.timeline_start_sec < clip_end_time and c.timeline_end_sec > clip_start_time + for c in self.timeline.clips if c.track_type == 'video' and c.track_index == i + ) + if not is_occupied: + video_track_index = i + break + + if has_audio: + for i in range(1, self.timeline.num_audio_tracks + 2): + is_occupied = any( + c.timeline_start_sec < clip_end_time and c.timeline_end_sec > clip_start_time + for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i + ) + if not is_occupied: + audio_track_index = i + break + + group_id = str(uuid.uuid4()) + if video_track_index is not None: + if video_track_index > self.timeline.num_video_tracks: + self.timeline.num_video_tracks = video_track_index + video_clip = TimelineClip(file_path, clip_start_time, 0.0, duration, video_track_index, 'video', media_type, group_id) + self.timeline.add_clip(video_clip) + + if audio_track_index is not None: + if audio_track_index > self.timeline.num_audio_tracks: + self.timeline.num_audio_tracks = audio_track_index + audio_clip = TimelineClip(file_path, clip_start_time, 0.0, duration, audio_track_index, 'audio', media_type, group_id) + self.timeline.add_clip(audio_clip) + + self.status_label.setText(f"Added {len(added_files)} file(s) to timeline.") + + self._perform_complex_timeline_change("Add Media to Timeline", add_clips_action) + + def add_media_files(self): + file_paths, _ = QFileDialog.getOpenFileNames(self, "Open Media Files", "", "All Supported Files (*.mp4 *.mov *.avi *.png *.jpg *.jpeg *.mp3 *.wav);;Video Files (*.mp4 *.mov *.avi);;Image Files (*.png *.jpg 
*.jpeg);;Audio Files (*.mp3 *.wav)") + if file_paths: + self._add_media_files_to_project(file_paths) def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, media_type, clip_start_sec=0.0, video_track_index=None, audio_track_index=None): old_state = self._get_current_timeline_state() @@ -2032,8 +2091,7 @@ def action(): for clip in clips_to_remove: try: self.timeline.clips.remove(clip) except ValueError: pass - - # Ripple + for clip in self.timeline.clips: if clip.timeline_start_sec >= end_sec: clip.timeline_start_sec -= duration_to_remove @@ -2074,7 +2132,7 @@ def export_video(self): if clip.media_type == 'image': v_seg = (input_streams[clip.source_path].video.trim(duration=clip.duration_sec).setpts('PTS-STARTPTS')) - else: # video + else: v_seg = (input_streams[clip.source_path].video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS')) v_seg = (v_seg.filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black').filter('format', pix_fmts='rgba')) From afc567a09cd0efb6fdf79e6bc935339880741674 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 04:04:40 +1100 Subject: [PATCH 107/155] add ability to drag and drop from OS into timeline or project media --- videoeditor/main.py | 170 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 162 insertions(+), 8 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index c4d3e9268..0ff19f672 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -143,6 +143,7 @@ def __init__(self, timeline_model, settings, project_fps, parent=None): self.drag_over_active = False self.drag_over_rect = QRectF() self.drag_over_audio_rect = QRectF() + self.drag_url_cache = {} def set_project_fps(self, fps): self.project_fps = fps if fps > 0 else 25.0 @@ -782,7 +783,9 @@ def mouseReleaseEvent(self, event: QMouseEvent): self.update() def dragEnterEvent(self, event): - if event.mimeData().hasFormat('application/x-vnd.video.filepath'): + if event.mimeData().hasFormat('application/x-vnd.video.filepath') or event.mimeData().hasUrls(): + if event.mimeData().hasUrls(): + self.drag_url_cache.clear() event.acceptProposedAction() else: event.ignore() @@ -791,20 +794,65 @@ def dragLeaveEvent(self, event): self.drag_over_active = False self.highlighted_ghost_track_info = None self.highlighted_track_info = None + self.drag_url_cache.clear() self.update() def dragMoveEvent(self, event): mime_data = event.mimeData() - if not mime_data.hasFormat('application/x-vnd.video.filepath'): - event.ignore() - return + media_props = None + if mime_data.hasUrls(): + urls = mime_data.urls() + if not urls: + event.ignore() + return + + file_path = urls[0].toLocalFile() + + if file_path in self.drag_url_cache: + media_props = self.drag_url_cache[file_path] + else: + probed_props = self.window()._probe_for_drag(file_path) + if probed_props: + self.drag_url_cache[file_path] = probed_props + media_props = probed_props + + elif mime_data.hasFormat('application/x-vnd.video.filepath'): + json_data_bytes = mime_data.data('application/x-vnd.video.filepath').data() + media_props = json.loads(json_data_bytes.decode('utf-8')) + + if not media_props: + if mime_data.hasUrls(): + event.acceptProposedAction() + pos = event.position() + track_info = self.y_to_track_info(pos.y()) + + self.drag_over_rect = QRectF() + self.drag_over_audio_rect = QRectF() + self.drag_over_active = False + self.highlighted_ghost_track_info = None + self.highlighted_track_info = None + + 
if track_info: + self.drag_over_active = True + track_type, track_index = track_info + is_ghost_track = (track_type == 'video' and track_index > self.timeline.num_video_tracks) or \ + (track_type == 'audio' and track_index > self.timeline.num_audio_tracks) + if is_ghost_track: + self.highlighted_ghost_track_info = track_info + else: + self.highlighted_track_info = track_info + + self.update() + else: + event.ignore() + return + event.acceptProposedAction() - json_data = json.loads(mime_data.data('application/x-vnd.video.filepath').data().decode('utf-8')) - duration = json_data['duration'] - media_type = json_data['media_type'] - has_audio = json_data['has_audio'] + duration = media_props['duration'] + media_type = media_props['media_type'] + has_audio = media_props['has_audio'] pos = event.position() start_sec = self.x_to_sec(pos.x()) @@ -859,9 +907,68 @@ def dropEvent(self, event): self.drag_over_active = False self.highlighted_ghost_track_info = None self.highlighted_track_info = None + self.drag_url_cache.clear() self.update() mime_data = event.mimeData() + if mime_data.hasUrls(): + file_paths = [url.toLocalFile() for url in mime_data.urls()] + main_window = self.window() + added_files = main_window._add_media_files_to_project(file_paths) + if not added_files: + event.ignore() + return + + pos = event.position() + start_sec = self.x_to_sec(pos.x()) + track_info = self.y_to_track_info(pos.y()) + if not track_info: + event.ignore() + return + + current_timeline_pos = start_sec + + for file_path in added_files: + media_info = main_window.media_properties.get(file_path) + if not media_info: continue + + duration = media_info['duration'] + has_audio = media_info['has_audio'] + media_type = media_info['media_type'] + + drop_track_type, drop_track_index = track_info + video_track_idx = None + audio_track_idx = None + + if media_type == 'image': + if drop_track_type == 'video': video_track_idx = drop_track_index + elif media_type == 'audio': + if drop_track_type == 'audio': audio_track_idx = drop_track_index + elif media_type == 'video': + if drop_track_type == 'video': + video_track_idx = drop_track_index + if has_audio: audio_track_idx = 1 + elif drop_track_type == 'audio' and has_audio: + audio_track_idx = drop_track_index + video_track_idx = 1 + + if video_track_idx is None and audio_track_idx is None: + continue + + main_window._add_clip_to_timeline( + source_path=file_path, + timeline_start_sec=current_timeline_pos, + duration_sec=duration, + media_type=media_type, + clip_start_sec=0, + video_track_index=video_track_idx, + audio_track_index=audio_track_idx + ) + current_timeline_pos += duration + + event.acceptProposedAction() + return + if not mime_data.hasFormat('application/x-vnd.video.filepath'): return @@ -1033,6 +1140,7 @@ class ProjectMediaWidget(QWidget): def __init__(self, parent=None): super().__init__(parent) self.main_window = parent + self.setAcceptDrops(True) layout = QVBoxLayout(self) layout.setContentsMargins(2, 2, 2, 2) @@ -1050,6 +1158,21 @@ def __init__(self, parent=None): add_button.clicked.connect(self.add_media_requested.emit) remove_button.clicked.connect(self.remove_selected_media) + + def dragEnterEvent(self, event): + if event.mimeData().hasUrls(): + event.acceptProposedAction() + else: + event.ignore() + + def dropEvent(self, event): + if event.mimeData().hasUrls(): + file_paths = [url.toLocalFile() for url in event.mimeData().urls()] + self.main_window._add_media_files_to_project(file_paths) + event.acceptProposedAction() + else: + event.ignore() + def 
show_context_menu(self, pos): item = self.media_list.itemAt(pos) if not item: @@ -1432,6 +1555,37 @@ def _set_project_properties_from_clip(self, source_path): except Exception as e: print(f"Could not probe for project properties: {e}") return False + def _probe_for_drag(self, file_path): + if file_path in self.media_properties: + return self.media_properties[file_path] + try: + file_ext = os.path.splitext(file_path)[1].lower() + media_info = {} + + if file_ext in ['.png', '.jpg', '.jpeg']: + media_info['media_type'] = 'image' + media_info['duration'] = 5.0 + media_info['has_audio'] = False + else: + probe = ffmpeg.probe(file_path) + video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None) + audio_stream = next((s for s in probe['streams'] if s['codec_type'] == 'audio'), None) + + if video_stream: + media_info['media_type'] = 'video' + media_info['duration'] = float(video_stream.get('duration', probe['format'].get('duration', 0))) + media_info['has_audio'] = audio_stream is not None + elif audio_stream: + media_info['media_type'] = 'audio' + media_info['duration'] = float(audio_stream.get('duration', probe['format'].get('duration', 0))) + media_info['has_audio'] = True + else: + return None + return media_info + except Exception as e: + print(f"Failed to probe dragged file {os.path.basename(file_path)}: {e}") + return None + def get_frame_data_at_time(self, time_sec): clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None) if not clip_at_time: From 17567341dc0e5b4b4c77573ccb0176a7ef00b16a Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 04:31:33 +1100 Subject: [PATCH 108/155] fix ai joiner not inserting on new track with new stream structure --- videoeditor/main.py | 82 +++++++++++++++++++-------------------------- 1 file changed, 34 insertions(+), 48 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 0ff19f672..e9981dc7f 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -1539,25 +1539,11 @@ def _stop_playback_stream(self): self.playback_process = None self.playback_clip = None - def _set_project_properties_from_clip(self, source_path): - try: - probe = ffmpeg.probe(source_path) - video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None) - if video_stream: - self.project_width = int(video_stream['width']); self.project_height = int(video_stream['height']) - if 'r_frame_rate' in video_stream and video_stream['r_frame_rate'] != '0/0': - num, den = map(int, video_stream['r_frame_rate'].split('/')) - if den > 0: - self.project_fps = num / den - self.timeline_widget.set_project_fps(self.project_fps) - print(f"Project properties set: {self.project_width}x{self.project_height} @ {self.project_fps:.2f} FPS") - return True - except Exception as e: print(f"Could not probe for project properties: {e}") - return False - - def _probe_for_drag(self, file_path): + def _get_media_properties(self, file_path): + """Probes a file to get its media properties. 
Returns a dict or None.""" if file_path in self.media_properties: return self.media_properties[file_path] + try: file_ext = os.path.splitext(file_path)[1].lower() media_info = {} @@ -1581,11 +1567,31 @@ def _probe_for_drag(self, file_path): media_info['has_audio'] = True else: return None + return media_info except Exception as e: - print(f"Failed to probe dragged file {os.path.basename(file_path)}: {e}") + print(f"Failed to probe file {os.path.basename(file_path)}: {e}") return None + def _set_project_properties_from_clip(self, source_path): + try: + probe = ffmpeg.probe(source_path) + video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None) + if video_stream: + self.project_width = int(video_stream['width']); self.project_height = int(video_stream['height']) + if 'r_frame_rate' in video_stream and video_stream['r_frame_rate'] != '0/0': + num, den = map(int, video_stream['r_frame_rate'].split('/')) + if den > 0: + self.project_fps = num / den + self.timeline_widget.set_project_fps(self.project_fps) + print(f"Project properties set: {self.project_width}x{self.project_height} @ {self.project_fps:.2f} FPS") + return True + except Exception as e: print(f"Could not probe for project properties: {e}") + return False + + def _probe_for_drag(self, file_path): + return self._get_media_properties(file_path) + def get_frame_data_at_time(self, time_sec): clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None) if not clip_at_time: @@ -1838,41 +1844,21 @@ def _update_recent_files_menu(self): self.recent_menu.addAction(action) def _add_media_to_pool(self, file_path): - if file_path in self.media_pool: return True - try: - self.status_label.setText(f"Probing {os.path.basename(file_path)}..."); QApplication.processEvents() - - file_ext = os.path.splitext(file_path)[1].lower() - media_info = {} - - if file_ext in ['.png', '.jpg', '.jpeg']: - media_info['media_type'] = 'image' - media_info['duration'] = 5.0 - media_info['has_audio'] = False - else: - probe = ffmpeg.probe(file_path) - video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None) - audio_stream = next((s for s in probe['streams'] if s['codec_type'] == 'audio'), None) - - if video_stream: - media_info['media_type'] = 'video' - media_info['duration'] = float(video_stream.get('duration', probe['format'].get('duration', 0))) - media_info['has_audio'] = audio_stream is not None - elif audio_stream: - media_info['media_type'] = 'audio' - media_info['duration'] = float(audio_stream.get('duration', probe['format'].get('duration', 0))) - media_info['has_audio'] = True - else: - raise ValueError("No video or audio stream found.") - + if file_path in self.media_pool: + return True + + self.status_label.setText(f"Probing {os.path.basename(file_path)}..."); QApplication.processEvents() + + media_info = self._get_media_properties(file_path) + + if media_info: self.media_properties[file_path] = media_info self.media_pool.append(file_path) self.project_media_widget.add_media_item(file_path) - self.status_label.setText(f"Added {os.path.basename(file_path)} to project.") return True - except Exception as e: - self.status_label.setText(f"Error probing file: {e}") + else: + self.status_label.setText(f"Error probing file: {os.path.basename(file_path)}") return False def on_media_removed_from_pool(self, file_path): From a223c1e8df4d0104a84965db09ae789c35123da1 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 
2025 04:32:01 +1100 Subject: [PATCH 109/155] fix ai joiner not inserting on new track with new stream structure (2) --- videoeditor/plugins/ai_frame_joiner/main.py | 34 ++++++++++++++------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py index d66dda353..85fd7a4d3 100644 --- a/videoeditor/plugins/ai_frame_joiner/main.py +++ b/videoeditor/plugins/ai_frame_joiner/main.py @@ -5,6 +5,7 @@ import requests from pathlib import Path import ffmpeg +import uuid from PyQt6.QtWidgets import ( QWidget, QVBoxLayout, QHBoxLayout, QFormLayout, @@ -268,6 +269,8 @@ def setup_generator_for_region(self, region, on_new_track=False): self.dock_widget.raise_() def insert_generated_clip(self, video_path): + from main import TimelineClip + if not self.active_region: self.update_main_status("AI Joiner Error: No active region to insert into.") return @@ -277,25 +280,29 @@ def insert_generated_clip(self, video_path): return start_sec, end_sec = self.active_region + is_new_track_mode = self.insert_on_new_track + self.update_main_status(f"AI Joiner: Inserting clip {os.path.basename(video_path)}") - try: + def complex_insertion_action(): probe = ffmpeg.probe(video_path) actual_duration = float(probe['format']['duration']) - if self.insert_on_new_track: - self.app.add_track('video') + if is_new_track_mode: + self.app.timeline.num_video_tracks += 1 new_track_index = self.app.timeline.num_video_tracks - self.app._add_clip_to_timeline( + + new_clip = TimelineClip( source_path=video_path, timeline_start_sec=start_sec, clip_start_sec=0, duration_sec=actual_duration, + track_index=new_track_index, + track_type='video', media_type='video', - video_track_index=new_track_index, - audio_track_index=None + group_id=str(uuid.uuid4()) ) - self.update_main_status("AI clip inserted successfully on new track.") + self.app.timeline.add_clip(new_clip) else: for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec) for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec) @@ -308,18 +315,23 @@ def insert_generated_clip(self, video_path): if clip in self.app.timeline.clips: self.app.timeline.clips.remove(clip) - self.app._add_clip_to_timeline( + new_clip = TimelineClip( source_path=video_path, timeline_start_sec=start_sec, clip_start_sec=0, duration_sec=actual_duration, + track_index=1, + track_type='video', media_type='video', - video_track_index=1, - audio_track_index=None + group_id=str(uuid.uuid4()) ) - self.update_main_status("AI clip inserted successfully.") + self.app.timeline.add_clip(new_clip) + try: + self.app._perform_complex_timeline_change("Insert AI Clip", complex_insertion_action) + self.app.prune_empty_tracks() + self.update_main_status("AI clip inserted successfully.") except Exception as e: error_message = f"AI Joiner Error during clip insertion/probing: {e}" From 8472a8a92a1b7cff5168f5df9f0873dca2a3703e Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 06:05:18 +1100 Subject: [PATCH 110/155] add Create mode for AI plusin --- main.py | 46 ++-- videoeditor/main.py | 17 +- videoeditor/plugins/ai_frame_joiner/main.py | 239 +++++++++++++++++--- 3 files changed, 238 insertions(+), 64 deletions(-) diff --git a/main.py b/main.py index 0dff699e6..773bcc9fa 100644 --- a/main.py +++ b/main.py @@ -312,56 +312,49 @@ def __init__(self): self.apply_initial_config() self.connect_signals() self.init_wgp_state() - - # --- CHANGE: Removed setModelSignal as it's no longer 
needed --- + self.api_bridge.generateSignal.connect(self._api_generate) @pyqtSlot(str) def _api_set_model(self, model_type): """This slot is executed in the main GUI thread.""" if not model_type or not wgp: return - - # 1. Check if the model is valid by looking it up in the master model definition dictionary. + if model_type not in wgp.models_def: print(f"API Error: Model type '{model_type}' is not a valid model.") return - # 2. Check if already selected to avoid unnecessary UI refreshes. if self.state.get('model_type') == model_type: print(f"API: Model is already set to {model_type}.") return - # 3. Redraw all model dropdowns to ensure the correct hierarchy is displayed - # and the target model is selected. This function handles finding the - # correct Family and Base model for the given finetune model_type. self.update_model_dropdowns(model_type) - - # 4. Manually trigger the logic that normally runs when the user selects a model. - # This is necessary because update_model_dropdowns blocks signals. + self._on_model_changed() - # 5. Final check to see if the model was actually set. if self.state.get('model_type') == model_type: print(f"API: Successfully set model to {model_type}.") else: - # This could happen if update_model_dropdowns silently fails to find the model. print(f"API Error: Failed to set model to '{model_type}'. The model might be hidden by your current configuration.") - - # --- CHANGE: Slot now accepts model_type and duration_sec, and calculates frame count --- + @pyqtSlot(object, object, object, object, bool) def _api_generate(self, start_frame, end_frame, duration_sec, model_type, start_generation): """This slot is executed in the main GUI thread.""" - # 1. Set model if a new one is provided if model_type: self._api_set_model(model_type) - if start_frame: self.widgets['mode_s'].setChecked(True) self.widgets['image_start'].setText(start_frame) + else: + self.widgets['mode_t'].setChecked(True) + self.widgets['image_start'].clear() if end_frame: self.widgets['image_end_checkbox'].setChecked(True) self.widgets['image_end'].setText(end_frame) + else: + self.widgets['image_end_checkbox'].setChecked(False) + self.widgets['image_end'].clear() if duration_sec is not None: try: @@ -1707,7 +1700,7 @@ def _add_task_to_queue_and_update_ui(self): def _add_task_to_queue(self): queue_size_before = len(self.state["gen"]["queue"]) all_inputs = self.collect_inputs() - keys_to_remove = ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality'] + keys_to_remove = ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality', 'base_model_type'] for key in keys_to_remove: all_inputs.pop(key, None) @@ -1938,7 +1931,6 @@ def run_api_server(): api_server.run(port=5100, host='127.0.0.1', debug=False) if FLASK_AVAILABLE: - # --- CHANGE: Removed /api/set_model endpoint --- @api_server.route('/api/generate', methods=['POST']) def generate(): @@ -1964,13 +1956,13 @@ def generate(): return jsonify({"message": "Parameters set without starting generation."}) - @api_server.route('/api/latest_output', methods=['GET']) - def get_latest_output(): + @api_server.route('/api/outputs', methods=['GET']) + def get_outputs(): if not main_window_instance: return jsonify({"error": "Application not ready"}), 503 - - path = main_window_instance.latest_output_path - return jsonify({"latest_output_path": path}) + + file_list = main_window_instance.state.get('gen', {}).get('file_list', []) + return jsonify({"outputs": file_list}) # 
===================================================================== # --- END OF API SERVER ADDITION --- @@ -1978,13 +1970,11 @@ def get_latest_output(): if __name__ == '__main__': app = QApplication(sys.argv) - - # Create and show the main window + window = MainWindow() - main_window_instance = window # Assign to global for API access + main_window_instance = window window.show() - # Start the Flask API server in a separate thread if FLASK_AVAILABLE: api_thread = threading.Thread(target=run_api_server, daemon=True) api_thread.start() diff --git a/videoeditor/main.py b/videoeditor/main.py index e9981dc7f..d1bd5ce8d 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -139,12 +139,18 @@ def __init__(self, timeline_model, settings, project_fps, parent=None): self.video_tracks_y_start = 0 self.audio_tracks_y_start = 0 - + self.hover_preview_rect = None + self.hover_preview_audio_rect = None self.drag_over_active = False self.drag_over_rect = QRectF() self.drag_over_audio_rect = QRectF() self.drag_url_cache = {} + def set_hover_preview_rects(self, video_rect, audio_rect): + self.hover_preview_rect = video_rect + self.hover_preview_audio_rect = audio_rect + self.update() + def set_project_fps(self, fps): self.project_fps = fps if fps > 0 else 25.0 self.max_pixels_per_second = self.project_fps * 20 @@ -177,6 +183,15 @@ def paintEvent(self, event): painter.fillRect(self.drag_over_audio_rect, QColor(0, 255, 0, 80)) painter.drawRect(self.drag_over_audio_rect) + if self.hover_preview_rect: + painter.setPen(QPen(QColor(0, 255, 255, 180), 2, Qt.PenStyle.DashLine)) + painter.fillRect(self.hover_preview_rect, QColor(0, 255, 255, 60)) + painter.drawRect(self.hover_preview_rect) + if self.hover_preview_audio_rect: + painter.setPen(QPen(QColor(0, 255, 255, 180), 2, Qt.PenStyle.DashLine)) + painter.fillRect(self.hover_preview_audio_rect, QColor(0, 255, 255, 60)) + painter.drawRect(self.hover_preview_audio_rect) + self.draw_playhead(painter) total_width = self.sec_to_x(self.timeline.get_total_duration()) + 200 diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py index 85fd7a4d3..410333d28 100644 --- a/videoeditor/plugins/ai_frame_joiner/main.py +++ b/videoeditor/plugins/ai_frame_joiner/main.py @@ -9,33 +9,134 @@ from PyQt6.QtWidgets import ( QWidget, QVBoxLayout, QHBoxLayout, QFormLayout, - QLineEdit, QPushButton, QLabel, QMessageBox, QCheckBox + QLineEdit, QPushButton, QLabel, QMessageBox, QCheckBox, QListWidget, + QListWidgetItem, QGroupBox ) from PyQt6.QtGui import QImage, QPixmap -from PyQt6.QtCore import QTimer, pyqtSignal, Qt, QSize +from PyQt6.QtCore import pyqtSignal, Qt, QSize, QRectF, QUrl, QTimer +from PyQt6.QtMultimedia import QMediaPlayer +from PyQt6.QtMultimediaWidgets import QVideoWidget sys.path.append(str(Path(__file__).parent.parent.parent)) from plugins import VideoEditorPlugin - API_BASE_URL = "http://127.0.0.1:5100" +class VideoResultItemWidget(QWidget): + def __init__(self, video_path, plugin, parent=None): + super().__init__(parent) + self.video_path = video_path + self.plugin = plugin + self.app = plugin.app + self.duration = 0.0 + self.has_audio = False + + self.setMinimumSize(200, 180) + self.setMaximumHeight(190) + + layout = QVBoxLayout(self) + layout.setContentsMargins(5, 5, 5, 5) + layout.setSpacing(5) + + self.media_player = QMediaPlayer() + self.video_widget = QVideoWidget() + self.video_widget.setFixedSize(160, 90) + self.media_player.setVideoOutput(self.video_widget) + 
self.media_player.setSource(QUrl.fromLocalFile(self.video_path)) + self.media_player.setLoops(QMediaPlayer.Loops.Infinite) + + self.info_label = QLabel(os.path.basename(video_path)) + self.info_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.info_label.setWordWrap(True) + + self.insert_button = QPushButton("Insert into Timeline") + self.insert_button.clicked.connect(self.on_insert) + + h_layout = QHBoxLayout() + h_layout.addStretch() + h_layout.addWidget(self.video_widget) + h_layout.addStretch() + + layout.addLayout(h_layout) + layout.addWidget(self.info_label) + layout.addWidget(self.insert_button) + + self.probe_video() + + def probe_video(self): + try: + probe = ffmpeg.probe(self.video_path) + self.duration = float(probe['format']['duration']) + self.has_audio = any(s['codec_type'] == 'audio' for s in probe.get('streams', [])) + + self.info_label.setText(f"{os.path.basename(self.video_path)}\n({self.duration:.2f}s)") + + except Exception as e: + self.info_label.setText(f"Error probing:\n{os.path.basename(self.video_path)}") + print(f"Error probing video {self.video_path}: {e}") + + def enterEvent(self, event): + super().enterEvent(event) + self.media_player.play() + + if not self.plugin.active_region or self.duration == 0: + return + + start_sec, _ = self.plugin.active_region + timeline = self.app.timeline_widget + + video_rect, audio_rect = None, None + + x = timeline.sec_to_x(start_sec) + w = int(self.duration * timeline.pixels_per_second) + + if self.plugin.insert_on_new_track: + video_y = timeline.TIMESCALE_HEIGHT + video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT) + if self.has_audio: + audio_y = timeline.audio_tracks_y_start + self.app.timeline.num_audio_tracks * timeline.TRACK_HEIGHT + audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT) + else: + v_track_idx = 1 + visual_v_idx = self.app.timeline.num_video_tracks - v_track_idx + video_y = timeline.video_tracks_y_start + visual_v_idx * timeline.TRACK_HEIGHT + video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT) + if self.has_audio: + a_track_idx = 1 + visual_a_idx = a_track_idx - 1 + audio_y = timeline.audio_tracks_y_start + visual_a_idx * timeline.TRACK_HEIGHT + audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT) + + timeline.set_hover_preview_rects(video_rect, audio_rect) + + def leaveEvent(self, event): + super().leaveEvent(event) + self.media_player.pause() + self.media_player.setPosition(0) + self.app.timeline_widget.set_hover_preview_rects(None, None) + + def on_insert(self): + self.media_player.stop() + self.media_player.setSource(QUrl()) + self.media_player.setVideoOutput(None) + self.app.timeline_widget.set_hover_preview_rects(None, None) + self.plugin.insert_generated_clip(self.video_path) + class WgpClientWidget(QWidget): - generation_complete = pyqtSignal(str) status_updated = pyqtSignal(str) - def __init__(self): + def __init__(self, plugin): super().__init__() - - self.last_known_output = None + self.plugin = plugin + self.processed_files = set() layout = QVBoxLayout(self) form_layout = QFormLayout() - self.model_input = QLineEdit() - self.model_input.setPlaceholderText("Optional (default: i2v_2_2)") - - previews_layout = QHBoxLayout() + self.previews_widget = QWidget() + previews_layout = QHBoxLayout(self.previews_widget) + previews_layout.setContentsMargins(0, 0, 0, 0) + start_preview_layout = QVBoxLayout() start_preview_layout.addWidget(QLabel("Start Frame")) self.start_frame_preview = QLabel("N/A") @@ -63,12 +164,24 @@ def __init__(self): self.generate_button = 
QPushButton("Generate") - layout.addLayout(previews_layout) - form_layout.addRow("Model Type:", self.model_input) + layout.addWidget(self.previews_widget) form_layout.addRow(self.autostart_checkbox) form_layout.addRow(self.generate_button) layout.addLayout(form_layout) + + # --- Results Area --- + results_group = QGroupBox("Generated Clips (Hover to play, Click button to insert)") + results_layout = QVBoxLayout() + self.results_list = QListWidget() + self.results_list.setFlow(QListWidget.Flow.LeftToRight) + self.results_list.setWrapping(True) + self.results_list.setResizeMode(QListWidget.ResizeMode.Adjust) + self.results_list.setSpacing(10) + results_layout.addWidget(self.results_list) + results_group.setLayout(results_layout) + layout.addWidget(results_group) + layout.addStretch() self.generate_button.clicked.connect(self.generate) @@ -106,7 +219,7 @@ def handle_api_error(self, response, action="performing action"): def check_server_status(self): try: - requests.get(f"{API_BASE_URL}/api/latest_output", timeout=1) + requests.get(f"{API_BASE_URL}/api/outputs", timeout=1) self.status_updated.emit("AI Joiner: Connected to WanGP server.") return True except requests.exceptions.ConnectionError: @@ -114,11 +227,10 @@ def check_server_status(self): QMessageBox.critical(self, "Connection Error", f"Could not connect to the WanGP API at {API_BASE_URL}.\n\nPlease ensure wgptool.py is running.") return False - def generate(self): payload = {} - model_type = self.model_input.text().strip() + model_type = self.model_type_to_use if model_type: payload['model_type'] = model_type @@ -148,31 +260,44 @@ def generate(self): def poll_for_output(self): try: - response = requests.get(f"{API_BASE_URL}/api/latest_output") + response = requests.get(f"{API_BASE_URL}/api/outputs") if response.status_code == 200: data = response.json() - latest_path = data.get("latest_output_path") + output_files = data.get("outputs", []) - if latest_path and latest_path != self.last_known_output: - self.stop_polling() - self.last_known_output = latest_path - self.status_updated.emit(f"AI Joiner: New output received! Inserting clip...") - self.generation_complete.emit(latest_path) + new_files = set(output_files) - self.processed_files + if new_files: + self.status_updated.emit(f"AI Joiner: Received {len(new_files)} new clip(s).") + for file_path in sorted(list(new_files)): + self.add_result_item(file_path) + self.processed_files.add(file_path) + else: + self.status_updated.emit("AI Joiner: Polling for output...") else: self.status_updated.emit("AI Joiner: Polling for output...") except requests.exceptions.RequestException: self.status_updated.emit("AI Joiner: Polling... (Connection issue)") + + def add_result_item(self, video_path): + item_widget = VideoResultItemWidget(video_path, self.plugin) + list_item = QListWidgetItem(self.results_list) + list_item.setSizeHint(item_widget.sizeHint()) + self.results_list.addItem(list_item) + self.results_list.setItemWidget(list_item, item_widget) + + def clear_results(self): + self.results_list.clear() + self.processed_files.clear() class Plugin(VideoEditorPlugin): def initialize(self): self.name = "AI Frame Joiner" self.description = "Uses a local AI server to generate a video between two frames." 
- self.client_widget = WgpClientWidget() + self.client_widget = WgpClientWidget(self) self.dock_widget = None self.active_region = None self.temp_dir = None self.insert_on_new_track = False - self.client_widget.generation_complete.connect(self.insert_generated_clip) self.client_widget.status_updated.connect(self.update_main_status) def enable(self): @@ -202,24 +327,41 @@ def _cleanup_temp_dir(self): def _reset_state(self): self.active_region = None self.insert_on_new_track = False + self.client_widget.stop_polling() + self.client_widget.clear_results() self._cleanup_temp_dir() self.update_main_status("AI Joiner: Idle") self.client_widget.set_previews(None, None) + self.client_widget.model_type_to_use = None def on_timeline_context_menu(self, menu, event): region = self.app.timeline_widget.get_region_at_pos(event.pos()) if region: menu.addSeparator() - action = menu.addAction("Join Frames With AI") - action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False)) - - action_new_track = menu.addAction("Join Frames With AI (New Track)") - action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True)) + + start_sec, end_sec = region + start_data, _, _ = self.app.get_frame_data_at_time(start_sec) + end_data, _, _ = self.app.get_frame_data_at_time(end_sec) + + if start_data and end_data: + join_action = menu.addAction("Join Frames With AI") + join_action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False)) + + join_action_new_track = menu.addAction("Join Frames With AI (New Track)") + join_action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True)) + + create_action = menu.addAction("Create Frames With AI") + create_action.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=False)) + + create_action_new_track = menu.addAction("Create Frames With AI (New Track)") + create_action_new_track.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=True)) def setup_generator_for_region(self, region, on_new_track=False): self._reset_state() + self.client_widget.model_type_to_use = "i2v_2_2" self.active_region = region self.insert_on_new_track = on_new_track + self.client_widget.previews_widget.setVisible(True) start_sec, end_sec = region if not self.client_widget.check_server_status(): @@ -268,6 +410,30 @@ def setup_generator_for_region(self, region, on_new_track=False): self.dock_widget.show() self.dock_widget.raise_() + def setup_creator_for_region(self, region, on_new_track=False): + self._reset_state() + self.client_widget.model_type_to_use = "t2v_2_2" + self.active_region = region + self.insert_on_new_track = on_new_track + self.client_widget.previews_widget.setVisible(False) + + start_sec, end_sec = region + if not self.client_widget.check_server_status(): + return + + self.client_widget.set_previews(None, None) + + duration_sec = end_sec - start_sec + self.client_widget.duration_input.setText(str(duration_sec)) + + self.client_widget.start_frame_input.clear() + self.client_widget.end_frame_input.clear() + + self.update_main_status(f"AI Creator: Ready for region {start_sec:.2f}s - {end_sec:.2f}s") + + self.dock_widget.show() + self.dock_widget.raise_() + def insert_generated_clip(self, video_path): from main import TimelineClip @@ -333,11 +499,14 @@ def complex_insertion_action(): self.app.prune_empty_tracks() self.update_main_status("AI clip inserted successfully.") + for i in 
range(self.client_widget.results_list.count()): + item = self.client_widget.results_list.item(i) + widget = self.client_widget.results_list.itemWidget(item) + if widget and widget.video_path == video_path: + self.client_widget.results_list.takeItem(i) + break + except Exception as e: error_message = f"AI Joiner Error during clip insertion/probing: {e}" self.update_main_status(error_message) - print(error_message) - - finally: - self._cleanup_temp_dir() - self.active_region = None \ No newline at end of file + print(error_message) \ No newline at end of file From deef0107e600743d060abb2a42b01f445eba2f81 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 06:38:04 +1100 Subject: [PATCH 111/155] add snapping between clips for resize --- videoeditor/main.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index d1bd5ce8d..8a5418ba8 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -574,9 +574,15 @@ def mouseMoveEvent(self, event: QMouseEvent): delta_x = event.pos().x() - self.resize_start_pos.x() time_delta = delta_x / self.pixels_per_second min_duration = 1.0 / self.project_fps - playhead_time = self.playhead_pos_sec snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_second + snap_points = [self.playhead_pos_sec] + for clip in self.timeline.clips: + if clip.id == self.resizing_clip.id: continue + if linked_clip and clip.id == linked_clip.id: continue + snap_points.append(clip.timeline_start_sec) + snap_points.append(clip.timeline_end_sec) + media_props = self.window().media_properties.get(self.resizing_clip.source_path) source_duration = media_props['duration'] if media_props else float('inf') @@ -585,10 +591,12 @@ def mouseMoveEvent(self, event: QMouseEvent): original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec original_clip_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].clip_start_sec true_new_start_sec = original_start + time_delta - if abs(true_new_start_sec - playhead_time) < snap_time_delta: - new_start_sec = playhead_time - else: - new_start_sec = true_new_start_sec + + new_start_sec = true_new_start_sec + for snap_point in snap_points: + if abs(true_new_start_sec - snap_point) < snap_time_delta: + new_start_sec = snap_point + break if new_start_sec > original_start + original_duration - min_duration: new_start_sec = original_start + original_duration - min_duration @@ -621,11 +629,14 @@ def mouseMoveEvent(self, event: QMouseEvent): true_new_duration = original_duration + time_delta true_new_end_time = original_start + true_new_duration - - if abs(true_new_end_time - playhead_time) < snap_time_delta: - new_duration = playhead_time - original_start - else: - new_duration = true_new_duration + + new_end_time = true_new_end_time + for snap_point in snap_points: + if abs(true_new_end_time - snap_point) < snap_time_delta: + new_end_time = snap_point + break + + new_duration = new_end_time - original_start if new_duration < min_duration: new_duration = min_duration From da346301cbdfb94ff690c7b1c4ea80abb29781ae Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 07:28:48 +1100 Subject: [PATCH 112/155] fixed minimize closing all windows --- videoeditor/main.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 8a5418ba8..90e36fad5 100644 --- 
a/videoeditor/main.py +++ b/videoeditor/main.py @@ -1482,6 +1482,10 @@ def remove_track(self, track_type): command.undo() self.undo_stack.push(command) + def on_dock_visibility_changed(self, action, visible): + if self.isMinimized(): + return + action.setChecked(visible) def _create_menu_bar(self): menu_bar = self.menuBar() @@ -1533,7 +1537,7 @@ def _create_menu_bar(self): action = QAction(data['name'], self, checkable=True) if hasattr(data['widget'], 'visibilityChanged'): action.toggled.connect(data['widget'].setVisible) - data['widget'].visibilityChanged.connect(action.setChecked) + data['widget'].visibilityChanged.connect(lambda visible, a=action: self.on_dock_visibility_changed(a, visible)) else: action.toggled.connect(lambda checked, k=key: self.toggle_widget_visibility(k, checked)) @@ -1766,7 +1770,6 @@ def _save_settings(self): def _apply_loaded_settings(self): visibility_settings = self.settings.get("window_visibility", {}) for key, data in self.managed_widgets.items(): - if data.get('plugin'): continue is_visible = visibility_settings.get(key, False if key == 'project_media' else True) if data['widget'] is not self.preview_widget: data['widget'].setVisible(is_visible) @@ -2373,7 +2376,7 @@ def add_dock_widget(self, plugin_instance, widget, title, area=Qt.DockWidgetArea dock.setVisible(initial_visibility) action = QAction(title, self, checkable=True) action.toggled.connect(dock.setVisible) - dock.visibilityChanged.connect(action.setChecked) + dock.visibilityChanged.connect(lambda visible, a=action: self.on_dock_visibility_changed(a, visible)) action.setChecked(dock.isVisible()) self.windows_menu.addAction(action) self.managed_widgets[widget_key] = {'widget': dock, 'name': title, 'action': action, 'plugin': plugin_instance.name} From 8b51fc48d8d26d3a2ab6c80ec7119e5a37a7e503 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 07:33:49 +1100 Subject: [PATCH 113/155] fixed export with images and blank canvas --- videoeditor/main.py | 75 ++++++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 32 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 90e36fad5..d1cda9fa9 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -2277,40 +2277,51 @@ def export_video(self): w, h, fr_str, total_dur = self.project_width, self.project_height, str(self.project_fps), self.timeline.get_total_duration() sample_rate, channel_layout = '44100', 'stereo' - - input_streams = {} - last_video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur}', f='lavfi') - for i in range(self.timeline.num_video_tracks): - track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'video' and c.track_index == i + 1], key=lambda c: c.timeline_start_sec) - if not track_clips: continue - - track_segments = [] - last_end = track_clips[0].timeline_start_sec - - for clip in track_clips: - gap = clip.timeline_start_sec - last_end - if gap > 0.01: - track_segments.append(ffmpeg.input(f'color=c=black@0.0:s={w}x{h}:r={fr_str}:d={gap}', f='lavfi').filter('format', pix_fmts='rgba')) - - if clip.source_path not in input_streams: - if clip.media_type == 'image': - input_streams[clip.source_path] = ffmpeg.input(clip.source_path, loop=1, framerate=self.project_fps) - else: - input_streams[clip.source_path] = ffmpeg.input(clip.source_path) + video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur}', f='lavfi') + + all_video_clips = sorted( + [c for c in self.timeline.clips if c.track_type == 'video'], + key=lambda c: 
c.track_index + ) + input_nodes = {} + + for clip in all_video_clips: + if clip.source_path not in input_nodes: if clip.media_type == 'image': - v_seg = (input_streams[clip.source_path].video.trim(duration=clip.duration_sec).setpts('PTS-STARTPTS')) + input_nodes[clip.source_path] = ffmpeg.input(clip.source_path, loop=1, framerate=self.project_fps) else: - v_seg = (input_streams[clip.source_path].video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS')) - - v_seg = (v_seg.filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black').filter('format', pix_fmts='rgba')) - track_segments.append(v_seg) - last_end = clip.timeline_end_sec + input_nodes[clip.source_path] = ffmpeg.input(clip.source_path) - track_stream = ffmpeg.concat(*track_segments, v=1, a=0).filter('setpts', f'PTS-STARTPTS+{track_clips[0].timeline_start_sec}/TB') - last_video_stream = ffmpeg.overlay(last_video_stream, track_stream) - final_video = last_video_stream.filter('format', pix_fmts='yuv420p').filter('fps', fps=self.project_fps) + clip_source_node = input_nodes[clip.source_path] + + if clip.media_type == 'image': + segment_stream = ( + clip_source_node.video + .trim(duration=clip.duration_sec) + .setpts('PTS-STARTPTS') + ) + else: + segment_stream = ( + clip_source_node.video + .trim(start=clip.clip_start_sec, duration=clip.duration_sec) + .setpts('PTS-STARTPTS') + ) + + processed_segment = ( + segment_stream + .filter('scale', w, h, force_original_aspect_ratio='decrease') + .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') + ) + + video_stream = ffmpeg.overlay( + video_stream, + processed_segment, + enable=f'between(t,{clip.timeline_start_sec},{clip.timeline_end_sec})' + ) + + final_video = video_stream.filter('format', pix_fmts='yuv420p').filter('fps', fps=self.project_fps) track_audio_streams = [] for i in range(self.timeline.num_audio_tracks): @@ -2321,14 +2332,14 @@ def export_video(self): last_end = track_clips[0].timeline_start_sec for clip in track_clips: - if clip.source_path not in input_streams: - input_streams[clip.source_path] = ffmpeg.input(clip.source_path) + if clip.source_path not in input_nodes: + input_nodes[clip.source_path] = ffmpeg.input(clip.source_path) gap = clip.timeline_start_sec - last_end if gap > 0.01: track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap}', f='lavfi')) - a_seg = input_streams[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS') + a_seg = input_nodes[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS') track_segments.append(a_seg) last_end = clip.timeline_end_sec From 836a172ff28b1397b3f42f39cf7fc0469058fb93 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 08:03:38 +1100 Subject: [PATCH 114/155] fix multi-track preview rendering --- videoeditor/main.py | 179 ++++++++++++++++++++++++++++++-------------- 1 file changed, 124 insertions(+), 55 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index d1cda9fa9..9cfca35b0 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -1258,7 +1258,6 @@ def __init__(self, project_to_load=None): self.project_height = 720 self.playback_timer = QTimer(self) self.playback_process = None - self.playback_clip = None self._setup_ui() self._connect_signals() @@ -1546,19 +1545,64 @@ def _create_menu_bar(self): def 
_start_playback_stream_at(self, time_sec): self._stop_playback_stream() - clip = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None) - if not clip: return - self.playback_clip = clip - clip_time = time_sec - clip.timeline_start_sec + clip.clip_start_sec - w, h = self.project_width, self.project_height + if not self.timeline.clips: + return + + w, h, fr = self.project_width, self.project_height, self.project_fps + total_dur = self.timeline.get_total_duration() + + video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr}:d={total_dur}', f='lavfi') + + all_video_clips = sorted( + [c for c in self.timeline.clips if c.track_type == 'video'], + key=lambda c: c.track_index + ) + + input_nodes = {} + for clip in all_video_clips: + if clip.source_path not in input_nodes: + if clip.media_type == 'image': + input_nodes[clip.source_path] = ffmpeg.input(clip.source_path, loop=1, framerate=fr) + else: + input_nodes[clip.source_path] = ffmpeg.input(clip.source_path) + + clip_source_node = input_nodes[clip.source_path] + + if clip.media_type == 'image': + segment_stream = ( + clip_source_node.video + .trim(duration=clip.duration_sec) + .setpts('PTS-STARTPTS') + ) + else: + segment_stream = ( + clip_source_node.video + .trim(start=clip.clip_start_sec, duration=clip.duration_sec) + .setpts('PTS-STARTPTS') + ) + + processed_segment = ( + segment_stream + .filter('scale', w, h, force_original_aspect_ratio='decrease') + .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') + ) + + video_stream = ffmpeg.overlay( + video_stream, + processed_segment, + enable=f'between(t,{clip.timeline_start_sec},{clip.timeline_end_sec})' + ) + try: - args = (ffmpeg.input(self.playback_clip.source_path, ss=clip_time) - .filter('scale', w, h, force_original_aspect_ratio='decrease') - .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') - .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile()) + args = ( + video_stream + .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=fr, ss=time_sec) + .compile() + ) self.playback_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) except Exception as e: - print(f"Failed to start playback stream: {e}"); self._stop_playback_stream() + print(f"Failed to start playback stream: {e}") + self._stop_playback_stream() def _stop_playback_stream(self): if self.playback_process: @@ -1567,7 +1611,6 @@ def _stop_playback_stream(self): try: self.playback_process.wait(timeout=0.5) except subprocess.TimeoutExpired: self.playback_process.kill(); self.playback_process.wait() self.playback_process = None - self.playback_clip = None def _get_media_properties(self, file_path): """Probes a file to get its media properties. 
Returns a dict or None.""" @@ -1623,25 +1666,33 @@ def _probe_for_drag(self, file_path): return self._get_media_properties(file_path) def get_frame_data_at_time(self, time_sec): - clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None) - if not clip_at_time: + clips_at_time = [ + c for c in self.timeline.clips + if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec + ] + + if not clips_at_time: return (None, 0, 0) + + clips_at_time.sort(key=lambda c: c.track_index, reverse=True) + clip_to_render = clips_at_time[0] + try: w, h = self.project_width, self.project_height - if clip_at_time.media_type == 'image': + if clip_to_render.media_type == 'image': out, _ = ( ffmpeg - .input(clip_at_time.source_path) + .input(clip_to_render.source_path) .filter('scale', w, h, force_original_aspect_ratio='decrease') .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') .run(capture_stdout=True, quiet=True) ) else: - clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec + clip_time = time_sec - clip_to_render.timeline_start_sec + clip_to_render.clip_start_sec out, _ = ( ffmpeg - .input(clip_at_time.source_path, ss=clip_time) + .input(clip_to_render.source_path, ss=clip_time) .filter('scale', w, h, force_original_aspect_ratio='decrease') .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') @@ -1653,22 +1704,33 @@ def get_frame_data_at_time(self, time_sec): return (None, 0, 0) def get_frame_at_time(self, time_sec): - clip_at_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec), None) - black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black")) - if not clip_at_time: return black_pixmap + clips_at_time = [ + c for c in self.timeline.clips + if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec + ] + + black_pixmap = QPixmap(self.project_width, self.project_height) + black_pixmap.fill(QColor("black")) + + if not clips_at_time: + return black_pixmap + + clips_at_time.sort(key=lambda c: c.track_index, reverse=True) + clip_to_render = clips_at_time[0] + try: w, h = self.project_width, self.project_height - if clip_at_time.media_type == 'image': + if clip_to_render.media_type == 'image': out, _ = ( - ffmpeg.input(clip_at_time.source_path) + ffmpeg.input(clip_to_render.source_path) .filter('scale', w, h, force_original_aspect_ratio='decrease') .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') .run(capture_stdout=True, quiet=True) ) else: - clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec - out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time) + clip_time = time_sec - clip_to_render.timeline_start_sec + clip_to_render.clip_start_sec + out, _ = (ffmpeg.input(clip_to_render.source_path, ss=clip_time) .filter('scale', w, h, force_original_aspect_ratio='decrease') .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') @@ -1676,10 +1738,16 @@ def get_frame_at_time(self, time_sec): image = QImage(out, self.project_width, self.project_height, QImage.Format.Format_RGB888) return QPixmap.fromImage(image) - except 
ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return black_pixmap + except ffmpeg.Error as e: + print(f"Error extracting frame: {e.stderr}") + return black_pixmap def seek_preview(self, time_sec): - self._stop_playback_stream() + if self.playback_timer.isActive(): + self.playback_timer.stop() + self._stop_playback_stream() + self.play_pause_button.setText("Play") + self.timeline_widget.playhead_pos_sec = time_sec self.timeline_widget.update() frame_pixmap = self.get_frame_at_time(time_sec) @@ -1688,16 +1756,35 @@ def seek_preview(self, time_sec): self.preview_widget.setPixmap(scaled_pixmap) def toggle_playback(self): - if self.playback_timer.isActive(): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play") + if self.playback_timer.isActive(): + self.playback_timer.stop() + self._stop_playback_stream() + self.play_pause_button.setText("Play") else: if not self.timeline.clips: return - if self.timeline_widget.playhead_pos_sec >= self.timeline.get_total_duration(): self.timeline_widget.playhead_pos_sec = 0.0 - self.playback_timer.start(int(1000 / self.project_fps)); self.play_pause_button.setText("Pause") + + playhead_time = self.timeline_widget.playhead_pos_sec + if playhead_time >= self.timeline.get_total_duration(): + playhead_time = 0.0 + self.timeline_widget.playhead_pos_sec = 0.0 + + self._start_playback_stream_at(playhead_time) + + if self.playback_process: + self.playback_timer.start(int(1000 / self.project_fps)) + self.play_pause_button.setText("Pause") + + def stop_playback(self): + self.playback_timer.stop() + self._stop_playback_stream() + self.play_pause_button.setText("Play") + self.seek_preview(0.0) - def stop_playback(self): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play"); self.seek_preview(0.0) def step_frame(self, direction): if not self.timeline.clips: return - self.playback_timer.stop(); self.play_pause_button.setText("Play"); self._stop_playback_stream() + self.playback_timer.stop() + self._stop_playback_stream() + self.play_pause_button.setText("Play") frame_duration = 1.0 / self.project_fps new_time = self.timeline_widget.playhead_pos_sec + (direction * frame_duration) self.seek_preview(max(0, min(new_time, self.timeline.get_total_duration()))) @@ -1705,33 +1792,13 @@ def step_frame(self, direction): def advance_playback_frame(self): frame_duration = 1.0 / self.project_fps new_time = self.timeline_widget.playhead_pos_sec + frame_duration - if new_time > self.timeline.get_total_duration(): self.stop_playback(); return + if new_time > self.timeline.get_total_duration(): + self.stop_playback() + return self.timeline_widget.playhead_pos_sec = new_time self.timeline_widget.update() - clip_at_new_time = next((c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= new_time < c.timeline_end_sec), None) - - if not clip_at_new_time: - self._stop_playback_stream() - black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black")) - scaled_pixmap = black_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) - self.preview_widget.setPixmap(scaled_pixmap) - return - - if clip_at_new_time.media_type == 'image': - if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: - self._stop_playback_stream() - self.playback_clip = clip_at_new_time - frame_pixmap = self.get_frame_at_time(new_time) - if frame_pixmap: - scaled_pixmap 
= frame_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) - self.preview_widget.setPixmap(scaled_pixmap) - return - - if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: - self._start_playback_stream_at(new_time) - if self.playback_process: frame_size = self.project_width * self.project_height * 3 frame_bytes = self.playback_process.stdout.read(frame_size) @@ -1742,6 +1809,8 @@ def advance_playback_frame(self): self.preview_widget.setPixmap(scaled_pixmap) else: self._stop_playback_stream() + if self.playback_timer.isActive(): + self.stop_playback() def _load_settings(self): self.settings_file_was_loaded = False From 109a2aa0468f9e332bbaccb7ad3675cbddcfd10c Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 08:51:51 +1100 Subject: [PATCH 115/155] fix multi-track preview rendering (2) --- videoeditor/main.py | 178 +++++++++++++++----------------------------- 1 file changed, 59 insertions(+), 119 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 9cfca35b0..49f97c8cc 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -1258,6 +1258,7 @@ def __init__(self, project_to_load=None): self.project_height = 720 self.playback_timer = QTimer(self) self.playback_process = None + self.playback_clip = None self._setup_ui() self._connect_signals() @@ -1545,64 +1546,22 @@ def _create_menu_bar(self): def _start_playback_stream_at(self, time_sec): self._stop_playback_stream() - if not self.timeline.clips: + clips_at_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec] + if not clips_at_time: return + clip = sorted(clips_at_time, key=lambda c: c.track_index, reverse=True)[0] - w, h, fr = self.project_width, self.project_height, self.project_fps - total_dur = self.timeline.get_total_duration() - - video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr}:d={total_dur}', f='lavfi') - - all_video_clips = sorted( - [c for c in self.timeline.clips if c.track_type == 'video'], - key=lambda c: c.track_index - ) - - input_nodes = {} - for clip in all_video_clips: - if clip.source_path not in input_nodes: - if clip.media_type == 'image': - input_nodes[clip.source_path] = ffmpeg.input(clip.source_path, loop=1, framerate=fr) - else: - input_nodes[clip.source_path] = ffmpeg.input(clip.source_path) - - clip_source_node = input_nodes[clip.source_path] - - if clip.media_type == 'image': - segment_stream = ( - clip_source_node.video - .trim(duration=clip.duration_sec) - .setpts('PTS-STARTPTS') - ) - else: - segment_stream = ( - clip_source_node.video - .trim(start=clip.clip_start_sec, duration=clip.duration_sec) - .setpts('PTS-STARTPTS') - ) - - processed_segment = ( - segment_stream - .filter('scale', w, h, force_original_aspect_ratio='decrease') - .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') - ) - - video_stream = ffmpeg.overlay( - video_stream, - processed_segment, - enable=f'between(t,{clip.timeline_start_sec},{clip.timeline_end_sec})' - ) - + self.playback_clip = clip + clip_time = time_sec - clip.timeline_start_sec + clip.clip_start_sec + w, h = self.project_width, self.project_height try: - args = ( - video_stream - .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=fr, ss=time_sec) - .compile() - ) + args = (ffmpeg.input(self.playback_clip.source_path, ss=clip_time) + .filter('scale', w, h, force_original_aspect_ratio='decrease') + .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 
'black') + .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile()) self.playback_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) except Exception as e: - print(f"Failed to start playback stream: {e}") - self._stop_playback_stream() + print(f"Failed to start playback stream: {e}"); self._stop_playback_stream() def _stop_playback_stream(self): if self.playback_process: @@ -1611,6 +1570,7 @@ def _stop_playback_stream(self): try: self.playback_process.wait(timeout=0.5) except subprocess.TimeoutExpired: self.playback_process.kill(); self.playback_process.wait() self.playback_process = None + self.playback_clip = None def _get_media_properties(self, file_path): """Probes a file to get its media properties. Returns a dict or None.""" @@ -1666,33 +1626,26 @@ def _probe_for_drag(self, file_path): return self._get_media_properties(file_path) def get_frame_data_at_time(self, time_sec): - clips_at_time = [ - c for c in self.timeline.clips - if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec - ] - + clips_at_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec] if not clips_at_time: return (None, 0, 0) - - clips_at_time.sort(key=lambda c: c.track_index, reverse=True) - clip_to_render = clips_at_time[0] - + clip_at_time = sorted(clips_at_time, key=lambda c: c.track_index, reverse=True)[0] try: w, h = self.project_width, self.project_height - if clip_to_render.media_type == 'image': + if clip_at_time.media_type == 'image': out, _ = ( ffmpeg - .input(clip_to_render.source_path) + .input(clip_at_time.source_path) .filter('scale', w, h, force_original_aspect_ratio='decrease') .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') .run(capture_stdout=True, quiet=True) ) else: - clip_time = time_sec - clip_to_render.timeline_start_sec + clip_to_render.clip_start_sec + clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec out, _ = ( ffmpeg - .input(clip_to_render.source_path, ss=clip_time) + .input(clip_at_time.source_path, ss=clip_time) .filter('scale', w, h, force_original_aspect_ratio='decrease') .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') @@ -1704,33 +1657,24 @@ def get_frame_data_at_time(self, time_sec): return (None, 0, 0) def get_frame_at_time(self, time_sec): - clips_at_time = [ - c for c in self.timeline.clips - if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec - ] - - black_pixmap = QPixmap(self.project_width, self.project_height) - black_pixmap.fill(QColor("black")) - + clips_at_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec] + black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black")) if not clips_at_time: return black_pixmap - - clips_at_time.sort(key=lambda c: c.track_index, reverse=True) - clip_to_render = clips_at_time[0] - + clip_at_time = sorted(clips_at_time, key=lambda c: c.track_index, reverse=True)[0] try: w, h = self.project_width, self.project_height - if clip_to_render.media_type == 'image': + if clip_at_time.media_type == 'image': out, _ = ( - ffmpeg.input(clip_to_render.source_path) + ffmpeg.input(clip_at_time.source_path) .filter('scale', w, h, force_original_aspect_ratio='decrease') .filter('pad', 
w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') .run(capture_stdout=True, quiet=True) ) else: - clip_time = time_sec - clip_to_render.timeline_start_sec + clip_to_render.clip_start_sec - out, _ = (ffmpeg.input(clip_to_render.source_path, ss=clip_time) + clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec + out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time) .filter('scale', w, h, force_original_aspect_ratio='decrease') .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') @@ -1738,16 +1682,10 @@ def get_frame_at_time(self, time_sec): image = QImage(out, self.project_width, self.project_height, QImage.Format.Format_RGB888) return QPixmap.fromImage(image) - except ffmpeg.Error as e: - print(f"Error extracting frame: {e.stderr}") - return black_pixmap + except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return black_pixmap def seek_preview(self, time_sec): - if self.playback_timer.isActive(): - self.playback_timer.stop() - self._stop_playback_stream() - self.play_pause_button.setText("Play") - + self._stop_playback_stream() self.timeline_widget.playhead_pos_sec = time_sec self.timeline_widget.update() frame_pixmap = self.get_frame_at_time(time_sec) @@ -1756,35 +1694,16 @@ def seek_preview(self, time_sec): self.preview_widget.setPixmap(scaled_pixmap) def toggle_playback(self): - if self.playback_timer.isActive(): - self.playback_timer.stop() - self._stop_playback_stream() - self.play_pause_button.setText("Play") + if self.playback_timer.isActive(): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play") else: if not self.timeline.clips: return - - playhead_time = self.timeline_widget.playhead_pos_sec - if playhead_time >= self.timeline.get_total_duration(): - playhead_time = 0.0 - self.timeline_widget.playhead_pos_sec = 0.0 - - self._start_playback_stream_at(playhead_time) - - if self.playback_process: - self.playback_timer.start(int(1000 / self.project_fps)) - self.play_pause_button.setText("Pause") - - def stop_playback(self): - self.playback_timer.stop() - self._stop_playback_stream() - self.play_pause_button.setText("Play") - self.seek_preview(0.0) + if self.timeline_widget.playhead_pos_sec >= self.timeline.get_total_duration(): self.timeline_widget.playhead_pos_sec = 0.0 + self.playback_timer.start(int(1000 / self.project_fps)); self.play_pause_button.setText("Pause") + def stop_playback(self): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play"); self.seek_preview(0.0) def step_frame(self, direction): if not self.timeline.clips: return - self.playback_timer.stop() - self._stop_playback_stream() - self.play_pause_button.setText("Play") + self.playback_timer.stop(); self.play_pause_button.setText("Play"); self._stop_playback_stream() frame_duration = 1.0 / self.project_fps new_time = self.timeline_widget.playhead_pos_sec + (direction * frame_duration) self.seek_preview(max(0, min(new_time, self.timeline.get_total_duration()))) @@ -1792,13 +1711,36 @@ def step_frame(self, direction): def advance_playback_frame(self): frame_duration = 1.0 / self.project_fps new_time = self.timeline_widget.playhead_pos_sec + frame_duration - if new_time > self.timeline.get_total_duration(): - self.stop_playback() - return + if new_time > self.timeline.get_total_duration(): self.stop_playback(); return self.timeline_widget.playhead_pos_sec = 
new_time self.timeline_widget.update() + clips_at_new_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= new_time < c.timeline_end_sec] + clip_at_new_time = None + if clips_at_new_time: + clip_at_new_time = sorted(clips_at_new_time, key=lambda c: c.track_index, reverse=True)[0] + + if not clip_at_new_time: + self._stop_playback_stream() + black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black")) + scaled_pixmap = black_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) + self.preview_widget.setPixmap(scaled_pixmap) + return + + if clip_at_new_time.media_type == 'image': + if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: + self._stop_playback_stream() + self.playback_clip = clip_at_new_time + frame_pixmap = self.get_frame_at_time(new_time) + if frame_pixmap: + scaled_pixmap = frame_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) + self.preview_widget.setPixmap(scaled_pixmap) + return + + if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: + self._start_playback_stream_at(new_time) + if self.playback_process: frame_size = self.project_width * self.project_height * 3 frame_bytes = self.playback_process.stdout.read(frame_size) @@ -1809,8 +1751,6 @@ def advance_playback_frame(self): self.preview_widget.setPixmap(scaled_pixmap) else: self._stop_playback_stream() - if self.playback_timer.isActive(): - self.stop_playback() def _load_settings(self): self.settings_file_was_loaded = False From de9ba22e1df8f140a0bacce13bfdcdeb5b2599da Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 08:59:41 +1100 Subject: [PATCH 116/155] speed up seeking --- videoeditor/main.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 49f97c8cc..996e8ba2a 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -1544,12 +1544,19 @@ def _create_menu_bar(self): data['action'] = action self.windows_menu.addAction(action) + def _get_topmost_video_clip_at(self, time_sec): + """Finds the video clip on the highest track at a specific time.""" + top_clip = None + for c in self.timeline.clips: + if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec: + if top_clip is None or c.track_index > top_clip.track_index: + top_clip = c + return top_clip + def _start_playback_stream_at(self, time_sec): self._stop_playback_stream() - clips_at_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec] - if not clips_at_time: - return - clip = sorted(clips_at_time, key=lambda c: c.track_index, reverse=True)[0] + clip = self._get_topmost_video_clip_at(time_sec) + if not clip: return self.playback_clip = clip clip_time = time_sec - clip.timeline_start_sec + clip.clip_start_sec @@ -1626,10 +1633,9 @@ def _probe_for_drag(self, file_path): return self._get_media_properties(file_path) def get_frame_data_at_time(self, time_sec): - clips_at_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec] - if not clips_at_time: + clip_at_time = self._get_topmost_video_clip_at(time_sec) + if not clip_at_time: return (None, 0, 0) - clip_at_time = sorted(clips_at_time, key=lambda c: c.track_index, 
reverse=True)[0] try: w, h = self.project_width, self.project_height if clip_at_time.media_type == 'image': @@ -1657,11 +1663,9 @@ def get_frame_data_at_time(self, time_sec): return (None, 0, 0) def get_frame_at_time(self, time_sec): - clips_at_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec] black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black")) - if not clips_at_time: - return black_pixmap - clip_at_time = sorted(clips_at_time, key=lambda c: c.track_index, reverse=True)[0] + clip_at_time = self._get_topmost_video_clip_at(time_sec) + if not clip_at_time: return black_pixmap try: w, h = self.project_width, self.project_height if clip_at_time.media_type == 'image': @@ -1716,10 +1720,7 @@ def advance_playback_frame(self): self.timeline_widget.playhead_pos_sec = new_time self.timeline_widget.update() - clips_at_new_time = [c for c in self.timeline.clips if c.track_type == 'video' and c.timeline_start_sec <= new_time < c.timeline_end_sec] - clip_at_new_time = None - if clips_at_new_time: - clip_at_new_time = sorted(clips_at_new_time, key=lambda c: c.track_index, reverse=True)[0] + clip_at_new_time = self._get_topmost_video_clip_at(new_time) if not clip_at_new_time: self._stop_playback_stream() From 153c858d4ff60fa61ebfb437792eda6b6c9fbc90 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 09:40:51 +1100 Subject: [PATCH 117/155] add select clips, delete key and left right arrow keys --- videoeditor/main.py | 82 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 66 insertions(+), 16 deletions(-) diff --git a/videoeditor/main.py b/videoeditor/main.py index 996e8ba2a..4bd3662c6 100644 --- a/videoeditor/main.py +++ b/videoeditor/main.py @@ -11,7 +11,7 @@ QScrollArea, QFrame, QProgressBar, QDialog, QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget, QListWidget, QListWidgetItem, QMessageBox) from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction, - QPixmap, QImage, QDrag, QCursor) + QPixmap, QImage, QDrag, QCursor, QKeyEvent) from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread, pyqtSignal, QTimer, QByteArray, QMimeData) @@ -87,6 +87,7 @@ class TimelineWidget(QWidget): split_requested = pyqtSignal(object) delete_clip_requested = pyqtSignal(object) + delete_clips_requested = pyqtSignal(list) playhead_moved = pyqtSignal(float) split_region_requested = pyqtSignal(list) split_all_regions_requested = pyqtSignal(list) @@ -114,7 +115,9 @@ def __init__(self, timeline_model, settings, project_fps, parent=None): self.setMinimumHeight(300) self.setMouseTracking(True) self.setAcceptDrops(True) + self.setFocusPolicy(Qt.FocusPolicy.StrongFocus) self.selection_regions = [] + self.selected_clips = set() self.dragging_clip = None self.dragging_linked_clip = None self.dragging_playhead = False @@ -404,6 +407,12 @@ def draw_tracks_and_clips(self, painter): color = QColor("#5A9") if self.dragging_clip and self.dragging_clip.id == clip.id else base_color painter.fillRect(clip_rect, color) + + if clip.id in self.selected_clips: + pen = QPen(QColor(255, 255, 0, 220), 2) + painter.setPen(pen) + painter.drawRect(clip_rect) + painter.setPen(QPen(QColor("#FFF"), 1)) font = QFont("Arial", 10) painter.setFont(font) @@ -502,11 +511,13 @@ def mousePressEvent(self, event: QMouseEvent): return if event.button() == Qt.MouseButton.LeftButton: + self.setFocus() + self.dragging_clip = None self.dragging_linked_clip 
= None self.dragging_playhead = False - self.dragging_selection_region = None self.creating_selection_region = False + self.dragging_selection_region = None self.resizing_clip = None self.resize_edge = None self.drag_original_clip_states.clear() @@ -528,22 +539,36 @@ def mousePressEvent(self, event: QMouseEvent): self.update() return + clicked_clip = None for clip in reversed(self.timeline.clips): - clip_rect = self.get_clip_rect(clip) - if clip_rect.contains(QPointF(event.pos())): - self.dragging_clip = clip + if self.get_clip_rect(clip).contains(QPointF(event.pos())): + clicked_clip = clip + break + + if clicked_clip: + is_ctrl_pressed = bool(event.modifiers() & Qt.KeyboardModifier.ControlModifier) + + if clicked_clip.id in self.selected_clips: + if is_ctrl_pressed: + self.selected_clips.remove(clicked_clip.id) + else: + if not is_ctrl_pressed: + self.selected_clips.clear() + self.selected_clips.add(clicked_clip.id) + + if clicked_clip.id in self.selected_clips: + self.dragging_clip = clicked_clip self.drag_start_state = self.window()._get_current_timeline_state() - self.drag_original_clip_states[clip.id] = (clip.timeline_start_sec, clip.track_index) + self.drag_original_clip_states[clicked_clip.id] = (clicked_clip.timeline_start_sec, clicked_clip.track_index) - self.dragging_linked_clip = next((c for c in self.timeline.clips if c.group_id == clip.group_id and c.id != clip.id), None) + self.dragging_linked_clip = next((c for c in self.timeline.clips if c.group_id == clicked_clip.group_id and c.id != clicked_clip.id), None) if self.dragging_linked_clip: self.drag_original_clip_states[self.dragging_linked_clip.id] = \ (self.dragging_linked_clip.timeline_start_sec, self.dragging_linked_clip.track_index) - self.drag_start_pos = event.pos() - break - - if not self.dragging_clip: + + else: + self.selected_clips.clear() region_to_drag = self.get_region_at_pos(event.pos()) if region_to_drag: self.dragging_selection_region = region_to_drag @@ -1108,6 +1133,19 @@ def clear_all_regions(self): self.selection_regions.clear() self.update() + def keyPressEvent(self, event: QKeyEvent): + if event.key() == Qt.Key.Key_Delete or event.key() == Qt.Key.Key_Backspace: + if self.selected_clips: + clips_to_delete = [c for c in self.timeline.clips if c.id in self.selected_clips] + if clips_to_delete: + self.delete_clips_requested.emit(clips_to_delete) + elif event.key() == Qt.Key.Key_Left: + self.window().step_frame(-1) + elif event.key() == Qt.Key.Key_Right: + self.window().step_frame(1) + else: + super().keyPressEvent(event) + class SettingsDialog(QDialog): def __init__(self, parent_settings, parent=None): @@ -1352,6 +1390,7 @@ def _connect_signals(self): self.timeline_widget.split_requested.connect(self.split_clip_at_playhead) self.timeline_widget.delete_clip_requested.connect(self.delete_clip) + self.timeline_widget.delete_clips_requested.connect(self.delete_clips) self.timeline_widget.playhead_moved.connect(self.seek_preview) self.timeline_widget.split_region_requested.connect(self.on_split_region) self.timeline_widget.split_all_regions_requested.connect(self.on_split_all_regions) @@ -2066,18 +2105,29 @@ def split_clip_at_playhead(self, clip_to_split=None): def delete_clip(self, clip_to_delete): + self.delete_clips([clip_to_delete]) + + def delete_clips(self, clips_to_delete): + if not clips_to_delete: return + old_state = self._get_current_timeline_state() - linked_clip = next((c for c in self.timeline.clips if c.group_id == clip_to_delete.group_id and c.id != clip_to_delete.id), None) + ids_to_remove = 
set() + for clip in clips_to_delete: + ids_to_remove.add(clip.id) + linked_clips = [c for c in self.timeline.clips if c.group_id == clip.group_id and c.id != clip.id] + for lc in linked_clips: + ids_to_remove.add(lc.id) + + self.timeline.clips = [c for c in self.timeline.clips if c.id not in ids_to_remove] + self.timeline_widget.selected_clips.clear() - if clip_to_delete in self.timeline.clips: self.timeline.clips.remove(clip_to_delete) - if linked_clip and linked_clip in self.timeline.clips: self.timeline.clips.remove(linked_clip) - new_state = self._get_current_timeline_state() - command = TimelineStateChangeCommand("Delete Clip", self.timeline, *old_state, *new_state) + command = TimelineStateChangeCommand(f"Delete {len(clips_to_delete)} Clip(s)", self.timeline, *old_state, *new_state) command.undo() self.undo_stack.push(command) self.prune_empty_tracks() + self.timeline_widget.update() def unlink_clip_pair(self, clip_to_unlink): old_state = self._get_current_timeline_state() From 69bb8221f7da8edf8e8c92c43ebdcc63e6bd4e1e Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 21:13:25 +1100 Subject: [PATCH 118/155] finally moved from server/client to a native plugin --- README.md | 287 +--- videoeditor/plugins.py => plugins.py | 13 +- main.py => plugins/wan2gp/main.py | 1311 +++++++------------ videoeditor/undo.py => undo.py | 0 videoeditor/main.py => videoeditor.py | 3 +- videoeditor/plugins/ai_frame_joiner/main.py | 512 -------- 6 files changed, 518 insertions(+), 1608 deletions(-) rename videoeditor/plugins.py => plugins.py (96%) rename main.py => plugins/wan2gp/main.py (69%) rename videoeditor/undo.py => undo.py (100%) rename videoeditor/main.py => videoeditor.py (99%) delete mode 100644 videoeditor/plugins/ai_frame_joiner/main.py diff --git a/README.md b/README.md index 6464a18f0..3726bd212 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. It provi ## Features -- **Inline AI Video Generation With WAN2GP**: Select a region to join and it will bring up a desktop port of WAN2GP for you to generate a video inline using the start and end frames in the selected region. +- **Inline AI Video Generation With WAN2GP**: Select a region to join and it will bring up a desktop port of WAN2GP for you to generate a video inline using the start and end frames in the selected region. You can also create frames in the selected region. - **Multi-Track Timeline**: Arrange video and audio clips on separate tracks. - **Project Management**: Create, save, and load projects in a `.json` format. - **Clip Operations**: @@ -12,6 +12,7 @@ A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. It provi - Split clips at the playhead. - Create selection regions for advanced operations. - Join/remove content within selected regions across all tracks. + - Link/Unlink audio tracks from video - **Real-time Preview**: A video preview window with playback controls (Play, Pause, Stop, Frame-by-frame stepping). - **Dynamic Track Management**: Add or remove video and audio tracks as needed. - **FFmpeg Integration**: @@ -23,286 +24,30 @@ A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. 
It provi ## Installation -**Follow the standard installation steps for WAN2GP below** - -**Run the server:** -```bash -python main.py -``` - -**Run the video editor:** ```bash -cd videoeditor -python main.py -``` - -## Screenshots -image -image -image -image - - - - - -# What follows is the standard WanGP Readme - ------ -

-WanGP by DeepBeepMeep : The best Open Source Video Generative Models Accessible to the GPU Poor -

-
-WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models with:
-- Low VRAM requirements (as low as 6 GB of VRAM is sufficient for certain models)
-- Support for old GPUs (RTX 10XX, 20xx, ...)
-- Very Fast on the latest GPUs
-- Easy to use Full Web based interface
-- Auto download of the required model adapted to your specific architecture
-- Tools integrated to facilitate Video Generation : Mask Editor, Prompt Enhancer, Temporal and Spatial Generation, MMAudio, Video Browser, Pose / Depth / Flow extractor
-- Loras Support to customize each model
-- Queuing system : make your shopping list of videos to generate and come back later
-
-**Discord Server to get Help from Other Users and show your Best Videos:** https://discord.gg/g7efUW9jGV
-
-**Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
-
-## 🔥 Latest Updates :
-### October 6 2025: WanGP v8.994 - A few last things before the Big Unknown ...
-
-This new version doesn't have any new model...
-
-...but the temptation to upgrade will be high, as it contains a few Lora-related features that may change your Life:
-- **Ready to use Loras Accelerators Profiles** per type of model that you can apply on your current *Generation Settings*. The next time I recommend a *Lora Accelerator*, it will be only one click away. And best of all, the required Loras will be downloaded automatically. When you apply an *Accelerator Profile*, input fields like the *Number of Denoising Steps*, *Activated Loras* and *Loras Multipliers* (such as "1;0 0;1" ...) will be automatically filled. However your video specific fields will be preserved, so it will be easy to switch between Profiles to experiment. With *WanGP 8.993*, the *Accelerator Loras* are now merged with *Non Accelerator Loras*. Things are getting too easy...
-
-- **Embedded Loras URL** : WanGP will now try to remember every Lora URL it sees. For instance if someone sends you some settings that contain Lora URLs, or you extract the Settings of a Video generated by a friend with Lora URLs, these URLs will be automatically added to the *WanGP URL Cache*. Conversely everything you share (Videos, Settings, Lset files) will contain the download URLs if they are known. You can also download a Lora directly in WanGP by using the *Download Lora* button at the bottom. The Lora will be immediately available and added to the WanGP Lora URL cache. This will work with *Hugging Face* as a repository. Support for CivitAi will come as soon as someone is nice enough to post a GitHub PR ...
-
-- **.lset file** supports embedded Loras URLs. It has never been easier to share a Lora with a friend. As a reminder, a .lset file can be created directly from the *WanGP Web Interface* and it contains a list of Loras and their multipliers, a Prompt and instructions on how to use these Loras (like the Lora's *Trigger*). So with embedded Lora URLs, you can send an .lset file by email or share it on discord: it is just a tiny 1 KB text file, but with it other people will be able to use gigabytes of Loras as these will be automatically downloaded.
-
-I have created the new Discord Channel **share-your-settings** where you can post your *Settings* or *Lset files*. I will be pleased to add new Loras Accelerators in the list of WanGP *Accelerator Profiles* if you post some good ones there.
-
-*With the 8.993 update*, I have added support for the **Scaled FP8 format**. As a sample case, I have created finetunes for the **Wan 2.2 PalinGenesis** Finetune which is quite popular recently.
You will find it in 3 flavors : *t2v*, *i2v* and *Lightning Accelerated for t2v*.
-
-The *Scaled FP8 format* is widely used as it is the format used by ... *ComfyUI*. So I expect a flood of Finetunes in the *share-your-finetune* channel. If not, it means this feature was useless and I will remove it 😈😈😈
-
-Not enough Space left on your SSD to download more models ? Would like to reuse Scaled FP8 files in your ComfyUI Folder without duplicating them ? Here comes *WanGP 8.994* **Multiple Checkpoints Folders** : you just need to move the files into different folders / hard drives or reuse existing folders, let WanGP know about it in the *Config Tab*, and WanGP will be able to put all the parts together.
-
-Last but not least, the Lora documentation has been updated.
-
-*update 8.991*: full power of *Vace Lynx* unleashed with new combinations such as Landscape + Face / Clothes + Face / Injected Frame (Start/End frames/...) + Face
-*update 8.992*: optimized gen with Loras, should be 10% faster when many Loras are used
-*update 8.993*: Support for *Scaled FP8* format and sample *Paligenesis* finetunes, merged Loras Accelerators and Non Accelerators
-*update 8.994*: Added custom checkpoints folders
-
-### September 30 2025: WanGP v8.9 - Combinatorics
-
-This new version of WanGP introduces **Wan 2.1 Lynx**, the best Control Net so far to transfer *Facial Identity*. You will be amazed to recognize your friends even with a completely different hair style. Congrats to the *Byte Dance team* for this achievement. Lynx works quite well with *Fusionix t2v* at 10 steps.
-
-*WanGP 8.9* also illustrates how existing WanGP features can be easily combined with new models. For instance with *Lynx* you will get out of the box *Video to Video* and *Image/Text to Image*.
-
-Another fun combination is *Vace* + *Lynx*, which works much better than *Vace StandIn*. I have added sliders to change the weight of Vace & Lynx to allow you to tune the effects.
-
-
-### September 28 2025: WanGP v8.76 - ~~Here Are Two Three New Contenders in the Vace Arena !~~ The Never Ending Release
-
-So in ~~today's~~ this release you will find two Wannabe Vaces that each cover only a subset of Vace features but offer some interesting advantages:
-- **Wan 2.2 Animate**: this model is specialized in *Body Motion* and *Facial Motion transfers*. It does that very well. You can use this model to either *Replace* a person in a Video or *Animate* the person of your choice using an existing *Pose Video* (remember *Animate Anyone* ?). By default it will keep the original soundtrack. *Wan 2.2 Animate* seems to be under the hood a derived i2v model and should support the corresponding Loras Accelerators (for instance *FusioniX t2v*). Also, as a WanGP exclusivity, you will find support for *Outpainting*.
-
-In order to use Wan 2.2 Animate you will first need to stop by the *Mat Anyone* embedded tool, to extract the *Video Mask* of the person from which you want to extract the motion.
-
-With version WanGP 8.74, there is an extra option that allows you to apply *Relighting* when Replacing a person. Also, you can now Animate a person without providing a Video Mask to target the source of the motion (with the risk that it will be less precise).
-
-For those of you who have a mask halo effect when Animating a character, I recommend trying *SDPA attention* and using the *FusioniX i2v* lora. If this issue persists (this will depend on the control video) you now have a choice of the two *Animate Mask Options* in *WanGP 8.76*.
The old masking option, which was a WanGP exclusive, has been renamed *See Through Mask* because the background behind the animated character was preserved, but this sometimes creates visual artifacts. The new option, which has the shorter name, is what you may find elsewhere online. As it internally uses a much larger mask, there is no halo. However, the immediate background behind the character is not preserved and may end up completely different. - -- **Lucy Edit**: this one claims to be a *Nano Banana* for Videos. Give it a video and ask it to change it (it is specialized in changing clothes) and voila! The nice thing about it is that it is based on the *Wan 2.2 5B* model and is therefore very fast, especially if you use the *FastWan* finetune that is also part of the package. - -Also, because I wanted to spoil you: -- **Qwen Edit Plus**: also known as the *Qwen Edit 25th September Update*, which is specialized in combining multiple Objects / People. There is also new support for *Pose transfer* & *Recolorisation*. All of this is made easy to use in WanGP. For now you will find only the quantized version, since HF crashes when uploading the unquantized version. - -- **T2V Video 2 Video Masking**: ever wanted to apply a Lora, a process (for instance Upsampling) or a Text Prompt to only a (moving) part of a Source Video? Look no further: I have added *Masked Video 2 Video* (which also works in image2image) to the *Text 2 Video* models. As usual you just need to use *Matanyone* to create the mask. - - -*Update 8.71*: fixed Fast Lucy Edit that didn't contain the Lora -*Update 8.72*: shadow drop of Qwen Edit Plus -*Update 8.73*: Qwen Preview & InfiniteTalk Start image -*Update 8.74*: Animate Relighting / Nomask mode, t2v Masked Video to Video -*Update 8.75*: REDACTED -*Update 8.76*: Alternate Animate masking that fixes the mask halo effect that some users have - -### September 15 2025: WanGP v8.6 - Attack of the Clones - -- The long awaited **Vace for Wan 2.2** is at last here, or maybe not: it has been released by the *Fun Team* of *Alibaba* and it is not official. You can play with the vanilla version (**Vace Fun**) or with the one accelerated with Loras (**Vace Fan Cocktail**) - -- **First Frame / Last Frame for Vace**: Vace models are so powerful that they could do *First frame / Last frame* since day one using the *Injected Frames* feature. However, this required computing by hand the location of each end frame, since this feature expects frame positions. I have made it easier to compute these locations by using the "L" alias (a short illustrative sketch follows below): - -For a video Gen from scratch, *"1 L L L"* means the 4 Injected Frames will be injected like this: frame no 1 at the first position, the next frame at the end of the first window, then the following frame at the end of the next window, and so on .... -If you *Continue a Video*, you just need *"L L L"* since the first frame is the last frame of the *Source Video*. In any case, remember that numeric frame positions (like "1") are aligned by default to the beginning of the source window, so low values such as 1 will be considered in the past unless you change this behaviour in *Sliding Window Tab / Control Video, Injected Frames alignment*. - -- **Qwen Edit Inpainting** now exists in two versions: the original version from the previous release and a Lora-based version. Each version has its pros and cons. For instance, the Lora version also supports **Outpainting**! However, it tends to slightly change the original image even outside the outpainted area.
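To make the "L" alias described in the *First Frame / Last Frame for Vace* note above more concrete, here is a minimal, purely illustrative sketch of how such a spec could be resolved into absolute frame positions. This is not WanGP's actual implementation; it assumes "L" means "last frame of the next sliding window" and deliberately ignores window overlap and the alignment option mentioned above.

```python
# Hypothetical helper, not WanGP code: resolve a spec such as "1 L L L" into
# 1-based frame positions, assuming non-overlapping windows.
def resolve_injected_frame_positions(spec: str, window_size: int) -> list[int]:
    positions = []
    window_end = window_size  # last frame of the first window
    for token in spec.split():
        if token.upper() == "L":
            positions.append(window_end)
            window_end += window_size  # jump to the end of the following window
        else:
            positions.append(int(token))  # explicit frame number, e.g. "1"
    return positions

# "1 L L L" with 81-frame windows -> [1, 81, 162, 243]
print(resolve_injected_frame_positions("1 L L L", 81))
```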
- -- **Better Lipsync with all the Audio to Video models**: you probably noticed that *Multitalk*, *InfiniteTalk* or *Hunyuan Avatar* had so-so lipsync when the audio provided contained some background music. The problem should now be solved thanks to automated background music removal, all done by AI. Don't worry, you will still hear the music, as it is added back into the generated Video. - -### September 11 2025: WanGP v8.5/8.55 - Wanna be a Cropper or a Painter? - -I have done some intensive internal refactoring of the generation pipeline to make it easier to support existing models or add new ones. Nothing really visible, but this makes WanGP a little more future proof. - -Otherwise in the news: -- **Cropped Input Image Prompts**: quite often the *Image Prompts* provided (*Start Image, Input Video, Reference Image, Control Video, ...*) don't exactly match your requested *Output Resolution*. In that case I used the resolution you gave either as a *Pixels Budget* or as an *Outer Canvas* for the Generated Video. However, on some occasions you really want the requested Output Resolution and nothing else. Besides, some models deliver much better Generations if you stick to one of their supported resolutions. In order to address this need I have added a new Output Resolution choice in the *Configuration Tab*: **Dimensions Correspond to the Output Width & Height as the Prompt Images will be Cropped to fit Exactly these dimensions**. In short, if needed, the *Input Prompt Images* will be cropped (center cropped for the moment). You will see this can make quite a difference for some models. - -- *Qwen Edit* now has a new sub Tab called **Inpainting** that lets you target with a brush the part of the *Image Prompt* you want to modify. This is quite convenient if you find that Qwen Edit usually modifies too many things. Of course, as there are more constraints for Qwen Edit, don't be surprised if it sometimes returns the original image unchanged. A piece of advice: describe in your *Text Prompt* where the parts that you want to modify are located (for instance *left of the man*, *top*, ...). - -The mask inpainting is fully compatible with the *Matanyone Mask generator*: first generate an *Image Mask* with Matanyone, transfer it to the current Image Generator and modify the mask with the *Paint Brush*. Talking about Matanyone, I have fixed a bug that caused mask degradation with long videos (now WanGP's Matanyone is as good as the original app and still requires 3 times less VRAM). - -- This **Inpainting Mask Editor** has also been added to *Vace Image Mode*. Vace is probably still one of the best Image Editors today. Here is a very simple & efficient workflow that does marvels with Vace: -Select *Vace Cocktail > Control Image Process = Perform Inpainting & Area Processed = Masked Area > Upload a Control Image, then draw your mask directly on top of the image & enter a text Prompt that describes the expected change > Generate > Below the Video Gallery click 'To Control Image' > Keep on doing more changes*. - -For more sophisticated things the Vace Image Editor works very well too: try Image Outpainting, Pose transfer, ... - -For the best quality I recommend setting, in the *Quality Tab*, the option: "*Generate a 9 Frames Long video...*" - -**update 8.55**: Flux Festival -- **Inpainting Mode** also added for *Flux Kontext* -- **Flux SRPO**: new finetune with x3 better quality vs Flux Dev according to its authors.
I have also created a *Flux SRPO USO* finetune, which is certainly the best open source *Style Transfer* tool available. -- **Flux UMO**: model specialized in combining multiple reference objects / people together. Works quite well at 768x768. - -Good luck finding your way through all the Flux model names! - -### September 5 2025: WanGP v8.4 - Take me to Outer Space -You have probably seen these short AI generated movies created using *Nano Banana* and the *First Frame - Last Frame* feature of *Kling 2.0*. The idea is to generate an image, modify a part of it with Nano Banana and give these two images to Kling, which will generate the Video between them; then use the previous Last Frame as the new First Frame, rinse and repeat, and you get a full movie. - -I have made it easier to do just that with *Qwen Edit* and *Wan*: -- **End Frames can now be combined with Continue a Video** (and not just a Start Frame) -- **Multiple End Frames can be provided**; each End Frame will be used for a different Sliding Window - -You can plan all your shots in advance (one shot = one Sliding Window): I recommend using Wan 2.2 Image to Image with multiple End Frames (one for each shot / Sliding Window), and a different Text Prompt for each shot / Sliding Window (remember to enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*). - -The results can be quite impressive. However, Wan 2.1 & 2.2 Image 2 Image are restricted to a single overlap frame when using Sliding Windows, which means only one frame is reused for the motion. This may be insufficient if you are trying to connect two shots with fast movement. - -This is where *InfiniteTalk* comes into play. Besides being one of the best models to generate animated audio-driven avatars, InfiniteTalk internally uses more than one motion frame. It is quite good at maintaining the motion between two shots. I have tweaked InfiniteTalk so that **its motion engine can be used even if no audio is provided**. -So here is how to use InfiniteTalk: enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*, and if you continue an existing Video, *Misc/Override Frames per Second* should be set to *Source Video*. Each Reference Frame provided will play the same role as the End Frame, except it won't be exactly an End Frame (it will correspond more to a middle frame; the actual End Frame will differ but will be close). - - -You will find below a 33s movie I have created using these two methods. Quality could be much better as I haven't tuned the settings at all (I couldn't be bothered; I used 10-step generation without Loras Accelerators for most of the gens). - -### September 2 2025: WanGP v8.31 - At last the pain stops - -- This single new feature should give you the strength to face all the potential bugs of this new release: -**Images Management (multiple additions or deletions, reordering) for Start Images / End Images / Images References.** - -- Unofficial **Video to Video (Non Sparse this time) for InfiniteTalk**. Use the Strength Noise slider to decide how much of the original window's motion you want to keep. I have also *greatly reduced the VRAM requirements for Multitalk / InfiniteTalk* (especially the multispeakers version & when generating at 1080p).
- -- **Experimental Sage 3 Attention support**: you will need to deserve this one, first you need a Blackwell GPU (RTX50xx) and request an access to Sage 3 Github repo, then you will have to compile Sage 3, install it and cross your fingers ... - - -*update 8.31: one shouldnt talk about bugs if one doesn't want to attract bugs* - - -See full changelog: **[Changelog](docs/CHANGELOG.md)** - -## 📋 Table of Contents - -- [🚀 Quick Start](#-quick-start) -- [📦 Installation](#-installation) -- [🎯 Usage](#-usage) -- [📚 Documentation](#-documentation) -- [🔗 Related Projects](#-related-projects) - -## 🚀 Quick Start - -**One-click installation:** -- Get started instantly with [Pinokio App](https://pinokio.computer/) -- Use Redtash1 [One Click Install with Sage](https://github.com/Redtash1/Wan2GP-Windows-One-Click-Install-With-Sage) - -**Manual installation:** -```bash -git clone https://github.com/deepbeepmeep/Wan2GP.git +git clone https://github.com/Tophness/Wan2GP.git cd Wan2GP +git checkout video_editor conda create -n wan2gp python=3.10.9 conda activate wan2gp pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 pip install -r requirements.txt ``` -**Run the application:** -```bash -python wgp.py -``` - -**Update the application:** -If using Pinokio use Pinokio to update otherwise: -Get in the directory where WanGP is installed and: -```bash -git pull -pip install -r requirements.txt -``` - -## 🐳 Docker: - -**For Debian-based systems (Ubuntu, Debian, etc.):** +## Usage +**Run the video editor:** ```bash -./run-docker-cuda-deb.sh +python videoeditor.py ``` -This automated script will: - -- Detect your GPU model and VRAM automatically -- Select optimal CUDA architecture for your GPU -- Install NVIDIA Docker runtime if needed -- Build a Docker image with all dependencies -- Run WanGP with optimal settings for your hardware - -**Docker environment includes:** - -- NVIDIA CUDA 12.4.1 with cuDNN support -- PyTorch 2.6.0 with CUDA 12.4 support -- SageAttention compiled for your specific GPU architecture -- Optimized environment variables for performance (TF32, threading, etc.) -- Automatic cache directory mounting for faster subsequent runs -- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally - -**Supported GPUs:** RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more. 
- -## 📦 Installation - -For detailed installation instructions for different GPU generations: -- **[Installation Guide](docs/INSTALLATION.md)** - Complete setup instructions for RTX 10XX to RTX 50XX - -## 🎯 Usage - -### Basic Usage -- **[Getting Started Guide](docs/GETTING_STARTED.md)** - First steps and basic usage -- **[Models Overview](docs/MODELS.md)** - Available models and their capabilities - -### Advanced Features -- **[Loras Guide](docs/LORAS.md)** - Using and managing Loras for customization -- **[Finetunes](docs/FINETUNES.md)** - Add manually new models to WanGP -- **[VACE ControlNet](docs/VACE.md)** - Advanced video control and manipulation -- **[Command Line Reference](docs/CLI.md)** - All available command line options - -## 📚 Documentation - -- **[Changelog](docs/CHANGELOG.md)** - Latest updates and version history -- **[Troubleshooting](docs/TROUBLESHOOTING.md)** - Common issues and solutions - -## 📚 Video Guides -- Nice Video that explain how to use Vace:\ -https://www.youtube.com/watch?v=FMo9oN2EAvE -- Another Vace guide:\ -https://www.youtube.com/watch?v=T5jNiEhf9xk - -## 🔗 Related Projects - -### Other Models for the GPU Poor -- **[HuanyuanVideoGP](https://github.com/deepbeepmeep/HunyuanVideoGP)** - One of the best open source Text to Video generators -- **[Hunyuan3D-2GP](https://github.com/deepbeepmeep/Hunyuan3D-2GP)** - Image to 3D and text to 3D tool -- **[FluxFillGP](https://github.com/deepbeepmeep/FluxFillGP)** - Inpainting/outpainting tools based on Flux -- **[Cosmos1GP](https://github.com/deepbeepmeep/Cosmos1GP)** - Text to world generator and image/video to world -- **[OminiControlGP](https://github.com/deepbeepmeep/OminiControlGP)** - Flux-derived application for object transfer -- **[YuE GP](https://github.com/deepbeepmeep/YuEGP)** - Song generator with instruments and singer's voice - ---- +## Screenshots +image +image +image +image -

-Made with ❤️ by DeepBeepMeep -
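Since this patch turns WAN2GP into a plugin of the video editor, here is a minimal sketch of what a third-party plugin could look like, based on the `VideoEditorPlugin` base class and the `plugins/<name>/main.py` layout introduced by `plugins.py` later in this patch. The `hello_world` folder name, the plugin name and the log messages are illustrative assumptions, not part of the patch.

```python
# plugins/hello_world/main.py -- illustrative sketch only, not shipped with this patch.
from plugins import VideoEditorPlugin  # base class defined in plugins.py

class Plugin(VideoEditorPlugin):  # PluginManager requires the class to be named 'Plugin'
    def __init__(self, app_instance, wgp_module=None):
        super().__init__(app_instance, wgp_module)
        self.name = "Hello World"
        self.description = "Prints a message when loaded and when disabled."

    def initialize(self):
        # Called once by PluginManager right after the plugin is instantiated.
        print(f"[{self.name}] loaded; wgp module available: {self.wgp is not None}")

    def disable(self):
        print(f"[{self.name}] disabled")
```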

+## Credits +The AI Video Generator plugin is built from a desktop port of WAN2GP by DeepBeepMeep. +See WAN2GP for more details. +https://github.com/deepbeepmeep/Wan2GP \ No newline at end of file diff --git a/videoeditor/plugins.py b/plugins.py similarity index 96% rename from videoeditor/plugins.py rename to plugins.py index 6e50f0963..f40f6adb1 100644 --- a/videoeditor/plugins.py +++ b/plugins.py @@ -16,8 +16,9 @@ class VideoEditorPlugin: Plugins should inherit from this class and be located in 'plugins/plugin_name/main.py'. The main class in main.py must be named 'Plugin'. """ - def __init__(self, app_instance): + def __init__(self, app_instance, wgp_module=None): self.app = app_instance + self.wgp = wgp_module self.name = "Unnamed Plugin" self.description = "No description provided." @@ -35,8 +36,9 @@ def disable(self): class PluginManager: """Manages the discovery, loading, and lifecycle of plugins.""" - def __init__(self, main_app): + def __init__(self, main_app, wgp_module=None): self.app = main_app + self.wgp = wgp_module self.plugins_dir = "plugins" self.plugins = {} # { 'plugin_name': {'instance': plugin_instance, 'enabled': False, 'module_path': path} } if not os.path.exists(self.plugins_dir): @@ -58,7 +60,7 @@ def discover_and_load_plugins(self): if hasattr(module, 'Plugin'): plugin_class = getattr(module, 'Plugin') - instance = plugin_class(self.app) + instance = plugin_class(self.app, self.wgp) instance.initialize() self.plugins[instance.name] = { 'instance': instance, @@ -70,6 +72,8 @@ def discover_and_load_plugins(self): print(f"Warning: {main_py_path} does not have a 'Plugin' class.") except Exception as e: print(f"Error loading plugin {plugin_name}: {e}") + import traceback + traceback.print_exc() def load_enabled_plugins_from_settings(self, enabled_plugins_list): """Enables plugins based on the loaded settings.""" @@ -263,4 +267,5 @@ def on_install_finished(self, message, success): self.status_label.setText(message) self.install_btn.setEnabled(True) if success: - self.url_input.clear() \ No newline at end of file + self.url_input.clear() + self.populate_list() \ No newline at end of file diff --git a/main.py b/plugins/wan2gp/main.py similarity index 69% rename from main.py rename to plugins/wan2gp/main.py index 773bcc9fa..ca201dcda 100644 --- a/main.py +++ b/plugins/wan2gp/main.py @@ -3,8 +3,11 @@ import threading import time import json -import re +import tempfile +import shutil +import uuid from unittest.mock import MagicMock +from pathlib import Path # --- Start of Gradio Hijacking --- # This block creates a mock Gradio module. When wgp.py is imported, @@ -16,12 +19,10 @@ class MockGradioComponent(MagicMock): """A smarter mock that captures constructor arguments.""" def __init__(self, *args, **kwargs): super().__init__(name=f"gr.{kwargs.get('elem_id', 'component')}") - # Store the kwargs so we can inspect them later self.kwargs = kwargs self.value = kwargs.get('value') self.choices = kwargs.get('choices') - # Mock chaining methods like .click(), .then(), etc. 
for method in ['then', 'change', 'click', 'input', 'select', 'upload', 'mount', 'launch', 'on', 'release']: setattr(self, method, lambda *a, **kw: self) @@ -34,153 +35,134 @@ def __exit__(self, exc_type, exc_val, exc_tb): class MockGradioError(Exception): pass -class MockGradioModule: # No longer inherits from MagicMock +class MockGradioModule: def __getattr__(self, name): if name == 'Error': return lambda *args, **kwargs: MockGradioError(*args) - # Nullify functions that show pop-ups if name in ['Info', 'Warning']: return lambda *args, **kwargs: print(f"Intercepted gr.{name}:", *args) return lambda *args, **kwargs: MockGradioComponent(*args, **kwargs) sys.modules['gradio'] = MockGradioModule() -sys.modules['gradio.gallery'] = MockGradioModule() # Also mock any submodules used +sys.modules['gradio.gallery'] = MockGradioModule() sys.modules['shared.gradio.gallery'] = MockGradioModule() # --- End of Gradio Hijacking --- -# Global placeholder for the wgp module. Will be None if import fails. -wgp = None - -# Load configuration and attempt to import wgp -MAIN_CONFIG_FILE = 'main_config.json' -main_config = {} - -def load_main_config(): - global main_config - try: - with open(MAIN_CONFIG_FILE, 'r') as f: - main_config = json.load(f) - except (FileNotFoundError, json.JSONDecodeError): - current_folder = os.path.dirname(os.path.abspath(sys.argv[0])) - main_config = {'wgp_path': current_folder} - -def save_main_config(): - global main_config - try: - with open(MAIN_CONFIG_FILE, 'w') as f: - json.dump(main_config, f, indent=4) - except Exception as e: - print(f"Error saving main_config.json: {e}") - -def setup_and_import_wgp(): - """Adds configured path to sys.path and tries to import wgp.""" - global wgp - wgp_path = main_config.get('wgp_path') - if wgp_path and os.path.isdir(wgp_path) and os.path.isfile(os.path.join(wgp_path, 'wgp.py')): - if wgp_path not in sys.path: - sys.path.insert(0, wgp_path) - try: - import wgp as wgp_module - wgp = wgp_module - return True - except ImportError as e: - print(f"Error: Failed to import wgp.py from the configured path '{wgp_path}'.\nDetails: {e}") - wgp = None - return False - else: - print("Info: WAN2GP folder path not set or invalid. 
Please configure it in File > Settings.") - wgp = None - return False - -# Load config and attempt import at script start -load_main_config() -wgp_loaded = setup_and_import_wgp() - -# Now import PyQt6 components +import ffmpeg + from PyQt6.QtWidgets import ( - QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QTabWidget, + QWidget, QVBoxLayout, QHBoxLayout, QTabWidget, QPushButton, QLabel, QLineEdit, QTextEdit, QSlider, QCheckBox, QComboBox, QFileDialog, QGroupBox, QFormLayout, QTableWidget, QTableWidgetItem, QHeaderView, QProgressBar, QScrollArea, QListWidget, QListWidgetItem, - QMessageBox, QRadioButton, QDialog + QMessageBox, QRadioButton, QSizePolicy ) -from PyQt6.QtCore import Qt, QThread, QObject, pyqtSignal, QTimer, pyqtSlot -from PyQt6.QtGui import QPixmap, QDropEvent, QAction +from PyQt6.QtCore import Qt, QThread, QObject, pyqtSignal, QUrl, QSize, QRectF +from PyQt6.QtGui import QPixmap, QImage, QDropEvent +from PyQt6.QtMultimedia import QMediaPlayer +from PyQt6.QtMultimediaWidgets import QVideoWidget from PIL.ImageQt import ImageQt +# Import the base plugin class from the main application's path +sys.path.append(str(Path(__file__).parent.parent.parent)) +from plugins import VideoEditorPlugin -class SettingsDialog(QDialog): - """Dialog to configure application settings like the wgp.py path.""" - def __init__(self, config, parent=None): +class VideoResultItemWidget(QWidget): + """A widget to display a generated video with a hover-to-play preview and insert button.""" + def __init__(self, video_path, plugin, parent=None): super().__init__(parent) - self.config = config - self.setWindowTitle("Settings") - self.setMinimumWidth(500) + self.video_path = video_path + self.plugin = plugin + self.app = plugin.app + self.duration = 0.0 + self.has_audio = False + self.setMinimumSize(200, 180) + self.setMaximumHeight(190) + layout = QVBoxLayout(self) - form_layout = QFormLayout() - - path_layout = QHBoxLayout() - self.wgp_path_edit = QLineEdit(self.config.get('wgp_path', '')) - self.wgp_path_edit.setPlaceholderText("Path to the folder containing wgp.py") - browse_btn = QPushButton("Browse...") - browse_btn.clicked.connect(self.browse_for_wgp_folder) - path_layout.addWidget(self.wgp_path_edit) - path_layout.addWidget(browse_btn) - - form_layout.addRow("WAN2GP Folder Path:", path_layout) - layout.addLayout(form_layout) - - button_layout = QHBoxLayout() - save_btn = QPushButton("Save") - save_restart_btn = QPushButton("Save and Restart") - cancel_btn = QPushButton("Cancel") - button_layout.addStretch() - button_layout.addWidget(save_btn) - button_layout.addWidget(save_restart_btn) - button_layout.addWidget(cancel_btn) - layout.addLayout(button_layout) - - save_btn.clicked.connect(self.save_and_close) - save_restart_btn.clicked.connect(self.save_and_restart) - cancel_btn.clicked.connect(self.reject) - - def browse_for_wgp_folder(self): - directory = QFileDialog.getExistingDirectory(self, "Select WAN2GP Folder") - if directory: - self.wgp_path_edit.setText(directory) - - def validate_path(self, path): - if not path or not os.path.isdir(path) or not os.path.isfile(os.path.join(path, 'wgp.py')): - QMessageBox.warning(self, "Invalid Path", "The selected folder does not contain 'wgp.py'. 
Please select the correct WAN2GP folder.") - return False - return True - - def _save_config(self): - path = self.wgp_path_edit.text() - if not self.validate_path(path): - return False - self.config['wgp_path'] = path - save_main_config() - return True - - def save_and_close(self): - if self._save_config(): - QMessageBox.information(self, "Settings Saved", "Settings have been saved. Please restart the application for changes to take effect.") - self.accept() - - def save_and_restart(self): - if self._save_config(): - self.parent().close() # Close the main window before restarting - os.execv(sys.executable, [sys.executable] + sys.argv) + layout.setContentsMargins(5, 5, 5, 5) + layout.setSpacing(5) + + self.media_player = QMediaPlayer() + self.video_widget = QVideoWidget() + self.video_widget.setFixedSize(160, 90) + self.media_player.setVideoOutput(self.video_widget) + self.media_player.setSource(QUrl.fromLocalFile(self.video_path)) + self.media_player.setLoops(QMediaPlayer.Loops.Infinite) + + self.info_label = QLabel(os.path.basename(video_path)) + self.info_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.info_label.setWordWrap(True) + + self.insert_button = QPushButton("Insert into Timeline") + self.insert_button.clicked.connect(self.on_insert) + + h_layout = QHBoxLayout() + h_layout.addStretch() + h_layout.addWidget(self.video_widget) + h_layout.addStretch() + + layout.addLayout(h_layout) + layout.addWidget(self.info_label) + layout.addWidget(self.insert_button) + self.probe_video() + + def probe_video(self): + try: + probe = ffmpeg.probe(self.video_path) + self.duration = float(probe['format']['duration']) + self.has_audio = any(s['codec_type'] == 'audio' for s in probe.get('streams', [])) + self.info_label.setText(f"{os.path.basename(self.video_path)}\n({self.duration:.2f}s)") + except Exception as e: + self.info_label.setText(f"Error probing:\n{os.path.basename(self.video_path)}") + print(f"Error probing video {self.video_path}: {e}") + + def enterEvent(self, event): + super().enterEvent(event) + self.media_player.play() + if not self.plugin.active_region or self.duration == 0: return + start_sec, _ = self.plugin.active_region + timeline = self.app.timeline_widget + video_rect, audio_rect = None, None + x = timeline.sec_to_x(start_sec) + w = int(self.duration * timeline.pixels_per_second) + if self.plugin.insert_on_new_track: + video_y = timeline.TIMESCALE_HEIGHT + video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT) + if self.has_audio: + audio_y = timeline.audio_tracks_y_start + self.app.timeline.num_audio_tracks * timeline.TRACK_HEIGHT + audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT) + else: + v_track_idx = 1 + visual_v_idx = self.app.timeline.num_video_tracks - v_track_idx + video_y = timeline.video_tracks_y_start + visual_v_idx * timeline.TRACK_HEIGHT + video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT) + if self.has_audio: + a_track_idx = 1 + visual_a_idx = a_track_idx - 1 + audio_y = timeline.audio_tracks_y_start + visual_a_idx * timeline.TRACK_HEIGHT + audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT) + timeline.set_hover_preview_rects(video_rect, audio_rect) + + def leaveEvent(self, event): + super().leaveEvent(event) + self.media_player.pause() + self.media_player.setPosition(0) + self.app.timeline_widget.set_hover_preview_rects(None, None) + + def on_insert(self): + self.media_player.stop() + self.media_player.setSource(QUrl()) + self.media_player.setVideoOutput(None) + self.app.timeline_widget.set_hover_preview_rects(None, None) + 
self.plugin.insert_generated_clip(self.video_path) class QueueTableWidget(QTableWidget): - """A QTableWidget with drag-and-drop reordering for rows.""" rowsMoved = pyqtSignal(int, int) - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.setDragEnabled(True) @@ -195,14 +177,8 @@ def dropEvent(self, event: QDropEvent): source_row = self.currentRow() target_item = self.itemAt(event.position().toPoint()) dest_row = target_item.row() if target_item else self.rowCount() - - # Adjust destination row if moving down - if source_row < dest_row: - dest_row -=1 - - if source_row != dest_row: - self.rowsMoved.emit(source_row, dest_row) - + if source_row < dest_row: dest_row -=1 + if source_row != dest_row: self.rowsMoved.emit(source_row, dest_row) event.acceptProposedAction() else: super().dropEvent(event) @@ -215,210 +191,78 @@ class Worker(QObject): output = pyqtSignal() finished = pyqtSignal() error = pyqtSignal(str) - - def __init__(self, state): + def __init__(self, plugin, state): super().__init__() + self.plugin = plugin + self.wgp = plugin.wgp self.state = state self._is_running = True self._last_progress_phase = None self._last_preview = None - def send_cmd(self, cmd, data=None): - if not self._is_running: - return - def run(self): def generation_target(): try: - for _ in wgp.process_tasks(self.state): - if self._is_running: - self.output.emit() - else: - break + for _ in self.wgp.process_tasks(self.state): + if self._is_running: self.output.emit() + else: break except Exception as e: import traceback print("Error in generation thread:") traceback.print_exc() - if "gradio.Error" in str(type(e)): - self.error.emit(str(e)) - else: - self.error.emit(f"An unexpected error occurred: {e}") + if "gradio.Error" in str(type(e)): self.error.emit(str(e)) + else: self.error.emit(f"An unexpected error occurred: {e}") finally: self._is_running = False - gen_thread = threading.Thread(target=generation_target, daemon=True) gen_thread.start() - while self._is_running: gen = self.state.get('gen', {}) - current_phase = gen.get("progress_phase") if current_phase and current_phase != self._last_progress_phase: self._last_progress_phase = current_phase - phase_name, step = current_phase total_steps = gen.get("num_inference_steps", 1) high_level_status = gen.get("progress_status", "") - - status_msg = wgp.merge_status_context(high_level_status, phase_name) - + status_msg = self.wgp.merge_status_context(high_level_status, phase_name) progress_args = [(step, total_steps), status_msg] self.progress.emit(progress_args) - preview_img = gen.get('preview') if preview_img is not None and preview_img is not self._last_preview: self._last_preview = preview_img self.preview.emit(preview_img) gen['preview'] = None - time.sleep(0.1) - gen_thread.join() self.finished.emit() -class ApiBridge(QObject): - """ - An object that lives in the main thread to receive signals from the API thread - and forward them to the MainWindow's slots. 
- """ - # --- CHANGE: Signal now includes model_type and duration_sec --- - generateSignal = pyqtSignal(object, object, object, object, bool) -class MainWindow(QMainWindow): - def __init__(self): +class WgpDesktopPluginWidget(QWidget): + def __init__(self, plugin): super().__init__() - - self.api_bridge = ApiBridge() - + self.plugin = plugin + self.wgp = plugin.wgp self.widgets = {} self.state = {} self.worker = None self.thread = None self.lora_map = {} self.full_resolution_choices = [] - self.latest_output_path = None + self.main_config = {} + self.processed_files = set() - self.setup_menu() - - if not wgp: - self.setWindowTitle("WanGP - Setup Required") - self.setGeometry(100, 100, 600, 200) - self.setup_placeholder_ui() - QTimer.singleShot(100, lambda: self.show_settings_dialog(first_time=True)) - else: - self.setWindowTitle(f"WanGP v{wgp.WanGP_version} - Qt Interface") - self.setGeometry(100, 100, 1400, 950) - self.setup_full_ui() - self.apply_initial_config() - self.connect_signals() - self.init_wgp_state() - - self.api_bridge.generateSignal.connect(self._api_generate) - - @pyqtSlot(str) - def _api_set_model(self, model_type): - """This slot is executed in the main GUI thread.""" - if not model_type or not wgp: return - - if model_type not in wgp.models_def: - print(f"API Error: Model type '{model_type}' is not a valid model.") - return - - if self.state.get('model_type') == model_type: - print(f"API: Model is already set to {model_type}.") - return - - self.update_model_dropdowns(model_type) - - self._on_model_changed() - - if self.state.get('model_type') == model_type: - print(f"API: Successfully set model to {model_type}.") - else: - print(f"API Error: Failed to set model to '{model_type}'. The model might be hidden by your current configuration.") - - @pyqtSlot(object, object, object, object, bool) - def _api_generate(self, start_frame, end_frame, duration_sec, model_type, start_generation): - """This slot is executed in the main GUI thread.""" - if model_type: - self._api_set_model(model_type) - if start_frame: - self.widgets['mode_s'].setChecked(True) - self.widgets['image_start'].setText(start_frame) - else: - self.widgets['mode_t'].setChecked(True) - self.widgets['image_start'].clear() - - if end_frame: - self.widgets['image_end_checkbox'].setChecked(True) - self.widgets['image_end'].setText(end_frame) - else: - self.widgets['image_end_checkbox'].setChecked(False) - self.widgets['image_end'].clear() - - if duration_sec is not None: - try: - duration = float(duration_sec) - base_fps = 16 - fps_text = self.widgets['force_fps'].itemText(0) - match = re.search(r'\((\d+)\s*fps\)', fps_text) - if match: - base_fps = int(match.group(1)) - - video_length_frames = int(duration * base_fps) - - self.widgets['video_length'].setValue(video_length_frames) - print(f"API: Calculated video length: {video_length_frames} frames for {duration:.2f}s @ {base_fps} effective FPS.") - - except (ValueError, TypeError) as e: - print(f"API Error: Invalid duration_sec '{duration_sec}': {e}") - - if start_generation: - self.generate_btn.click() - print("API: Generation started.") - else: - print("API: Parameters set without starting generation.") - - - def setup_menu(self): - menu_bar = self.menuBar() - file_menu = menu_bar.addMenu("&File") - - settings_action = QAction("&Settings", self) - settings_action.triggered.connect(self.show_settings_dialog) - file_menu.addAction(settings_action) - - def show_settings_dialog(self, first_time=False): - dialog = SettingsDialog(main_config, self) - if first_time: - 
dialog.setWindowTitle("Initial Setup: Configure WAN2GP Path") - dialog.exec() - - def setup_placeholder_ui(self): - main_widget = QWidget() - self.setCentralWidget(main_widget) - layout = QVBoxLayout(main_widget) - - placeholder_label = QLabel( - "

Welcome to WanGP" - "The path to your WAN2GP installation (the folder containing wgp.py) is not set." - "Please go to File > Settings to configure the path.

" - ) - placeholder_label.setAlignment(Qt.AlignmentFlag.AlignCenter) - placeholder_label.setWordWrap(True) - layout.addWidget(placeholder_label) - - def setup_full_ui(self): - main_widget = QWidget() - self.setCentralWidget(main_widget) - main_layout = QVBoxLayout(main_widget) + self.load_main_config() + self.setup_ui() + self.apply_initial_config() + self.connect_signals() + self.init_wgp_state() + def setup_ui(self): + main_layout = QVBoxLayout(self) self.header_info = QLabel("Header Info") main_layout.addWidget(self.header_info) - self.tabs = QTabWidget() main_layout.addWidget(self.tabs) - self.setup_generator_tab() self.setup_config_tab() @@ -431,18 +275,12 @@ def _create_slider_with_label(self, name, min_val, max_val, initial_val, scale=1 container = QWidget() hbox = QHBoxLayout(container) hbox.setContentsMargins(0, 0, 0, 0) - slider = self.create_widget(QSlider, name, Qt.Orientation.Horizontal) slider.setRange(min_val, max_val) slider.setValue(int(initial_val * scale)) - value_label = self.create_widget(QLabel, f"{name}_label", f"{initial_val:.{precision}f}") value_label.setMinimumWidth(50) - - slider.valueChanged.connect( - lambda v, lbl=value_label, s=scale, p=precision: lbl.setText(f"{v/s:.{p}f}") - ) - + slider.valueChanged.connect(lambda v, lbl=value_label, s=scale, p=precision: lbl.setText(f"{v/s:.{p}f}")) hbox.addWidget(slider) hbox.addWidget(value_label) return container @@ -451,30 +289,20 @@ def _create_file_input(self, name, label_text): container = self.create_widget(QWidget, f"{name}_container") hbox = QHBoxLayout(container) hbox.setContentsMargins(0, 0, 0, 0) - line_edit = self.create_widget(QLineEdit, name) - line_edit.setReadOnly(False) # Allow user to paste paths line_edit.setPlaceholderText("No file selected or path pasted") - button = QPushButton("Browse...") - def open_dialog(): - # Allow selecting multiple files for reference images if "refs" in name: filenames, _ = QFileDialog.getOpenFileNames(self, f"Select {label_text}") - if filenames: - line_edit.setText(";".join(filenames)) + if filenames: line_edit.setText(";".join(filenames)) else: filename, _ = QFileDialog.getOpenFileName(self, f"Select {label_text}") - if filename: - line_edit.setText(filename) - + if filename: line_edit.setText(filename) button.clicked.connect(open_dialog) - clear_button = QPushButton("X") clear_button.setFixedWidth(30) clear_button.clicked.connect(lambda: line_edit.clear()) - hbox.addWidget(QLabel(f"{label_text}:")) hbox.addWidget(line_edit, 1) hbox.addWidget(button) @@ -485,25 +313,18 @@ def setup_generator_tab(self): gen_tab = QWidget() self.tabs.addTab(gen_tab, "Video Generator") gen_layout = QHBoxLayout(gen_tab) - left_panel = QWidget() left_layout = QVBoxLayout(left_panel) gen_layout.addWidget(left_panel, 1) - right_panel = QWidget() right_layout = QVBoxLayout(right_panel) gen_layout.addWidget(right_panel, 1) - - # Left Panel (Inputs) scroll_area = QScrollArea() scroll_area.setWidgetResizable(True) left_layout.addWidget(scroll_area) - options_widget = QWidget() scroll_area.setWidget(options_widget) options_layout = QVBoxLayout(options_widget) - - # Model Selection model_layout = QHBoxLayout() self.widgets['model_family'] = QComboBox() self.widgets['model_base_type_choice'] = QComboBox() @@ -513,36 +334,26 @@ def setup_generator_tab(self): model_layout.addWidget(self.widgets['model_base_type_choice'], 3) model_layout.addWidget(self.widgets['model_choice'], 3) options_layout.addLayout(model_layout) - - # Prompt options_layout.addWidget(QLabel("Prompt:")) 
self.create_widget(QTextEdit, 'prompt').setMinimumHeight(100) options_layout.addWidget(self.widgets['prompt']) - options_layout.addWidget(QLabel("Negative Prompt:")) self.create_widget(QTextEdit, 'negative_prompt').setMinimumHeight(60) options_layout.addWidget(self.widgets['negative_prompt']) - - # Basic controls basic_group = QGroupBox("Basic Options") basic_layout = QFormLayout(basic_group) - res_container = QWidget() res_hbox = QHBoxLayout(res_container) res_hbox.setContentsMargins(0, 0, 0, 0) res_hbox.addWidget(self.create_widget(QComboBox, 'resolution_group'), 2) res_hbox.addWidget(self.create_widget(QComboBox, 'resolution'), 3) basic_layout.addRow("Resolution:", res_container) - basic_layout.addRow("Video Length:", self._create_slider_with_label('video_length', 1, 737, 81, 1.0, 0)) basic_layout.addRow("Inference Steps:", self._create_slider_with_label('num_inference_steps', 1, 100, 30, 1.0, 0)) basic_layout.addRow("Seed:", self.create_widget(QLineEdit, 'seed', '-1')) options_layout.addWidget(basic_group) - - # Generation Mode and Input Options mode_options_group = QGroupBox("Generation Mode & Input Options") mode_options_layout = QVBoxLayout(mode_options_group) - mode_hbox = QHBoxLayout() mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_t', "Text Prompt Only")) mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_s', "Start with Image")) @@ -550,15 +361,12 @@ def setup_generator_tab(self): mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_l', "Continue Last Video")) self.widgets['mode_t'].setChecked(True) mode_options_layout.addLayout(mode_hbox) - options_hbox = QHBoxLayout() options_hbox.addWidget(self.create_widget(QCheckBox, 'image_end_checkbox', "Use End Image")) options_hbox.addWidget(self.create_widget(QCheckBox, 'control_video_checkbox', "Use Control Video")) options_hbox.addWidget(self.create_widget(QCheckBox, 'ref_image_checkbox', "Use Reference Image(s)")) mode_options_layout.addLayout(options_hbox) options_layout.addWidget(mode_options_group) - - # Dynamic Inputs inputs_group = QGroupBox("Inputs") inputs_layout = QVBoxLayout(inputs_group) inputs_layout.addWidget(self._create_file_input('image_start', "Start Image")) @@ -571,16 +379,12 @@ def setup_generator_tab(self): denoising_row.addRow("Denoising Strength:", self._create_slider_with_label('denoising_strength', 0, 100, 50, 100.0, 2)) inputs_layout.addLayout(denoising_row) options_layout.addWidget(inputs_group) - - # Advanced controls self.advanced_group = self.create_widget(QGroupBox, 'advanced_group', "Advanced Options") self.advanced_group.setCheckable(True) self.advanced_group.setChecked(False) advanced_layout = QVBoxLayout(self.advanced_group) - advanced_tabs = self.create_widget(QTabWidget, 'advanced_tabs') advanced_layout.addWidget(advanced_tabs) - self._setup_adv_tab_general(advanced_tabs) self._setup_adv_tab_loras(advanced_tabs) self._setup_adv_tab_speed(advanced_tabs) @@ -589,10 +393,8 @@ def setup_generator_tab(self): self._setup_adv_tab_quality(advanced_tabs) self._setup_adv_tab_sliding_window(advanced_tabs) self._setup_adv_tab_misc(advanced_tabs) - options_layout.addWidget(self.advanced_group) - # Right Panel (Output & Queue) btn_layout = QHBoxLayout() self.generate_btn = self.create_widget(QPushButton, 'generate_btn', "Generate") self.add_to_queue_btn = self.create_widget(QPushButton, 'add_to_queue_btn', "Add to Queue") @@ -601,32 +403,34 @@ def setup_generator_tab(self): btn_layout.addWidget(self.generate_btn) btn_layout.addWidget(self.add_to_queue_btn) 
right_layout.addLayout(btn_layout) - self.status_label = self.create_widget(QLabel, 'status_label', "Idle") right_layout.addWidget(self.status_label) self.progress_bar = self.create_widget(QProgressBar, 'progress_bar') right_layout.addWidget(self.progress_bar) - preview_group = self.create_widget(QGroupBox, 'preview_group', "Preview") preview_group.setCheckable(True) preview_group.setStyleSheet("QGroupBox { border: 1px solid #cccccc; }") preview_group_layout = QVBoxLayout(preview_group) - self.preview_image = self.create_widget(QLabel, 'preview_image', "") self.preview_image.setAlignment(Qt.AlignmentFlag.AlignCenter) self.preview_image.setMinimumSize(200, 200) - preview_group_layout.addWidget(self.preview_image) right_layout.addWidget(preview_group) - right_layout.addWidget(QLabel("Output:")) - self.output_gallery = self.create_widget(QListWidget, 'output_gallery') - right_layout.addWidget(self.output_gallery) + results_group = QGroupBox("Generated Clips") + results_group.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + results_layout = QVBoxLayout(results_group) + self.results_list = self.create_widget(QListWidget, 'results_list') + self.results_list.setFlow(QListWidget.Flow.LeftToRight) + self.results_list.setWrapping(True) + self.results_list.setResizeMode(QListWidget.ResizeMode.Adjust) + self.results_list.setSpacing(10) + results_layout.addWidget(self.results_list) + right_layout.addWidget(results_group) right_layout.addWidget(QLabel("Queue:")) self.queue_table = self.create_widget(QueueTableWidget, 'queue_table') right_layout.addWidget(self.queue_table) - queue_btn_layout = QHBoxLayout() self.remove_queue_btn = self.create_widget(QPushButton, 'remove_queue_btn', "Remove Selected") self.clear_queue_btn = self.create_widget(QPushButton, 'clear_queue_btn', "Clear Queue") @@ -641,14 +445,11 @@ def _setup_adv_tab_general(self, tabs): tabs.addTab(tab, "General") layout = QFormLayout(tab) self.widgets['adv_general_layout'] = layout - guidance_group = QGroupBox("Guidance") guidance_layout = self.create_widget(QFormLayout, 'guidance_layout', guidance_group) guidance_layout.addRow("Guidance (CFG):", self._create_slider_with_label('guidance_scale', 10, 200, 5.0, 10.0, 1)) - self.widgets['guidance_phases_row_index'] = guidance_layout.rowCount() guidance_layout.addRow("Guidance Phases:", self.create_widget(QComboBox, 'guidance_phases')) - self.widgets['guidance2_row_index'] = guidance_layout.rowCount() guidance_layout.addRow("Guidance 2:", self._create_slider_with_label('guidance2_scale', 10, 200, 5.0, 10.0, 1)) self.widgets['guidance3_row_index'] = guidance_layout.rowCount() @@ -656,30 +457,24 @@ def _setup_adv_tab_general(self, tabs): self.widgets['switch_thresh_row_index'] = guidance_layout.rowCount() guidance_layout.addRow("Switch Threshold:", self._create_slider_with_label('switch_threshold', 0, 1000, 0, 1.0, 0)) layout.addRow(guidance_group) - nag_group = self.create_widget(QGroupBox, 'nag_group', "NAG (Negative Adversarial Guidance)") nag_layout = QFormLayout(nag_group) nag_layout.addRow("NAG Scale:", self._create_slider_with_label('NAG_scale', 10, 200, 1.0, 10.0, 1)) nag_layout.addRow("NAG Tau:", self._create_slider_with_label('NAG_tau', 10, 50, 3.5, 10.0, 1)) nag_layout.addRow("NAG Alpha:", self._create_slider_with_label('NAG_alpha', 0, 20, 0.5, 10.0, 1)) layout.addRow(nag_group) - self.widgets['solver_row_container'] = QWidget() solver_hbox = QHBoxLayout(self.widgets['solver_row_container']) solver_hbox.setContentsMargins(0,0,0,0) 
solver_hbox.addWidget(QLabel("Sampler Solver:")) solver_hbox.addWidget(self.create_widget(QComboBox, 'sample_solver')) layout.addRow(self.widgets['solver_row_container']) - self.widgets['flow_shift_row_index'] = layout.rowCount() layout.addRow("Shift Scale:", self._create_slider_with_label('flow_shift', 10, 250, 3.0, 10.0, 1)) - self.widgets['audio_guidance_row_index'] = layout.rowCount() layout.addRow("Audio Guidance:", self._create_slider_with_label('audio_guidance_scale', 10, 200, 4.0, 10.0, 1)) - self.widgets['repeat_generation_row_index'] = layout.rowCount() layout.addRow("Repeat Generations:", self._create_slider_with_label('repeat_generation', 1, 25, 1, 1.0, 0)) - combo = self.create_widget(QComboBox, 'multi_images_gen_type') combo.addItem("Generate all combinations", 0) combo.addItem("Match images and texts", 1) @@ -706,7 +501,6 @@ def _setup_adv_tab_speed(self, tabs): combo.addItem("Tea Cache", "tea") combo.addItem("Mag Cache", "mag") layout.addRow("Cache Type:", combo) - combo = self.create_widget(QComboBox, 'skip_steps_multiplier') combo.addItem("x1.5 speed up", 1.5) combo.addItem("x1.75 speed up", 1.75) @@ -725,13 +519,11 @@ def _setup_adv_tab_postproc(self, tabs): combo.addItem("Rife x2 frames/s", "rife2") combo.addItem("Rife x4 frames/s", "rife4") layout.addRow("Temporal Upsampling:", combo) - combo = self.create_widget(QComboBox, 'spatial_upsampling') combo.addItem("Disabled", "") combo.addItem("Lanczos x1.5", "lanczos1.5") combo.addItem("Lanczos x2.0", "lanczos2") layout.addRow("Spatial Upsampling:", combo) - layout.addRow("Film Grain Intensity:", self._create_slider_with_label('film_grain_intensity', 0, 100, 0, 100.0, 2)) layout.addRow("Film Grain Saturation:", self._create_slider_with_label('film_grain_saturation', 0, 100, 0.5, 100.0, 2)) @@ -751,7 +543,6 @@ def _setup_adv_tab_quality(self, tabs): tab = QWidget() tabs.addTab(tab, "Quality") layout = QVBoxLayout(tab) - slg_group = self.create_widget(QGroupBox, 'slg_group', "Skip Layer Guidance") slg_layout = QFormLayout(slg_group) slg_combo = self.create_widget(QComboBox, 'slg_switch') @@ -761,25 +552,20 @@ def _setup_adv_tab_quality(self, tabs): slg_layout.addRow("Start %:", self._create_slider_with_label('slg_start_perc', 0, 100, 10, 1.0, 0)) slg_layout.addRow("End %:", self._create_slider_with_label('slg_end_perc', 0, 100, 90, 1.0, 0)) layout.addWidget(slg_group) - quality_form = QFormLayout() self.widgets['quality_form_layout'] = quality_form - apg_combo = self.create_widget(QComboBox, 'apg_switch') apg_combo.addItem("OFF", 0) apg_combo.addItem("ON", 1) self.widgets['apg_switch_row_index'] = quality_form.rowCount() quality_form.addRow("Adaptive Projected Guidance:", apg_combo) - cfg_star_combo = self.create_widget(QComboBox, 'cfg_star_switch') cfg_star_combo.addItem("OFF", 0) cfg_star_combo.addItem("ON", 1) self.widgets['cfg_star_switch_row_index'] = quality_form.rowCount() quality_form.addRow("Classifier-Free Guidance Star:", cfg_star_combo) - self.widgets['cfg_zero_step_row_index'] = quality_form.rowCount() quality_form.addRow("CFG Zero below Layer:", self._create_slider_with_label('cfg_zero_step', -1, 39, -1, 1.0, 0)) - combo = self.create_widget(QComboBox, 'min_frames_if_references') combo.addItem("Disabled (1 frame)", 1) combo.addItem("Generate 5 frames", 5) @@ -795,7 +581,6 @@ def _setup_adv_tab_sliding_window(self, tabs): self.widgets['sliding_window_tab_index'] = tabs.count() tabs.addTab(tab, "Sliding Window") layout = QFormLayout(tab) - layout.addRow("Window Size:", 
self._create_slider_with_label('sliding_window_size', 5, 257, 129, 1.0, 0)) layout.addRow("Overlap:", self._create_slider_with_label('sliding_window_overlap', 1, 97, 5, 1.0, 0)) layout.addRow("Color Correction:", self._create_slider_with_label('sliding_window_color_correction_strength', 0, 100, 0, 100.0, 2)) @@ -807,23 +592,18 @@ def _setup_adv_tab_misc(self, tabs): tabs.addTab(tab, "Misc") layout = QFormLayout(tab) self.widgets['misc_layout'] = layout - riflex_combo = self.create_widget(QComboBox, 'RIFLEx_setting') riflex_combo.addItem("Auto", 0) riflex_combo.addItem("Always ON", 1) riflex_combo.addItem("Always OFF", 2) self.widgets['riflex_row_index'] = layout.rowCount() layout.addRow("RIFLEx Setting:", riflex_combo) - fps_combo = self.create_widget(QComboBox, 'force_fps') layout.addRow("Force FPS:", fps_combo) - profile_combo = self.create_widget(QComboBox, 'override_profile') profile_combo.addItem("Default Profile", -1) - for text, val in wgp.memory_profile_choices: - profile_combo.addItem(text.split(':')[0], val) + for text, val in self.wgp.memory_profile_choices: profile_combo.addItem(text.split(':')[0], val) layout.addRow("Override Memory Profile:", profile_combo) - combo = self.create_widget(QComboBox, 'multi_prompts_gen_type') combo.addItem("Generate new Video per line", 0) combo.addItem("Use line for new Sliding Window", 1) @@ -833,19 +613,15 @@ def setup_config_tab(self): config_tab = QWidget() self.tabs.addTab(config_tab, "Configuration") main_layout = QVBoxLayout(config_tab) - self.config_status_label = QLabel("Apply changes for them to take effect. Some may require a restart.") main_layout.addWidget(self.config_status_label) - config_tabs = QTabWidget() main_layout.addWidget(config_tabs) - config_tabs.addTab(self._create_general_config_tab(), "General") config_tabs.addTab(self._create_performance_config_tab(), "Performance") config_tabs.addTab(self._create_extensions_config_tab(), "Extensions") config_tabs.addTab(self._create_outputs_config_tab(), "Outputs") config_tabs.addTab(self._create_notifications_config_tab(), "Notifications") - self.apply_config_btn = QPushButton("Apply Changes") self.apply_config_btn.clicked.connect(self._on_apply_config_changes) main_layout.addWidget(self.apply_config_btn) @@ -856,18 +632,15 @@ def _create_scrollable_form_tab(self): scroll_area.setWidgetResizable(True) layout = QVBoxLayout(tab_widget) layout.addWidget(scroll_area) - content_widget = QWidget() form_layout = QFormLayout(content_widget) scroll_area.setWidget(content_widget) - return tab_widget, form_layout def _create_config_combo(self, form_layout, label, key, choices, default_value): combo = QComboBox() - for text, data in choices: - combo.addItem(text, data) - index = combo.findData(wgp.server_config.get(key, default_value)) + for text, data in choices: combo.addItem(text, data) + index = combo.findData(self.wgp.server_config.get(key, default_value)) if index != -1: combo.setCurrentIndex(index) self.widgets[f'config_{key}'] = combo form_layout.addRow(label, combo) @@ -876,34 +649,27 @@ def _create_config_slider(self, form_layout, label, key, min_val, max_val, defau container = QWidget() hbox = QHBoxLayout(container) hbox.setContentsMargins(0,0,0,0) - slider = QSlider(Qt.Orientation.Horizontal) slider.setRange(min_val, max_val) slider.setSingleStep(step) - slider.setValue(wgp.server_config.get(key, default_value)) - + slider.setValue(self.wgp.server_config.get(key, default_value)) value_label = QLabel(str(slider.value())) value_label.setMinimumWidth(40) 
slider.valueChanged.connect(lambda v, lbl=value_label: lbl.setText(str(v))) - hbox.addWidget(slider) hbox.addWidget(value_label) - self.widgets[f'config_{key}'] = slider form_layout.addRow(label, container) def _create_config_checklist(self, form_layout, label, key, choices, default_value): list_widget = QListWidget() list_widget.setMinimumHeight(100) - current_values = wgp.server_config.get(key, default_value) + current_values = self.wgp.server_config.get(key, default_value) for text, data in choices: item = QListWidgetItem(text) item.setData(Qt.ItemDataRole.UserRole, data) item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable) - if data in current_values: - item.setCheckState(Qt.CheckState.Checked) - else: - item.setCheckState(Qt.CheckState.Unchecked) + item.setCheckState(Qt.CheckState.Checked if data in current_values else Qt.CheckState.Unchecked) list_widget.addItem(item) self.widgets[f'config_{key}'] = list_widget form_layout.addRow(label, list_widget) @@ -919,48 +685,24 @@ def _create_config_textbox(self, form_layout, label, key, default_value, multi_l def _create_general_config_tab(self): tab, form = self._create_scrollable_form_tab() - - _, _, dropdown_choices = wgp.get_sorted_dropdown(wgp.displayed_model_types, None, None, False) - self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, wgp.transformer_types) - - self._create_config_combo(form, "Model Hierarchy:", "model_hierarchy_type", [ - ("Two Levels (Family > Model)", 0), - ("Three Levels (Family > Base > Finetune)", 1) - ], 1) - - self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [ - ("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1), - ("Dimensions are Output Width/Height (Cropped)", 2)], 0) - - self._create_config_combo(form, "Attention Type:", "attention_mode", [ - ("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"), - ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto") - - self._create_config_combo(form, "Metadata Handling:", "metadata_type", [ - ("Embed in file (Exif/Comment)", "metadata"), ("Export separate JSON", "json"), ("None", "none")], "metadata") - - self._create_config_checklist(form, "RAM Loading Policy:", "preload_model_policy", [ - ("Preload on App Launch", "P"), ("Preload on Model Switch", "S"), ("Unload when Queue is Done", "U")], []) - - self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [ - ("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10), - ("Keep last 20", 20), ("Keep last 30", 30)], 5) - + _, _, dropdown_choices = self.wgp.get_sorted_dropdown(self.wgp.displayed_model_types, None, None, False) + self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, self.wgp.transformer_types) + self._create_config_combo(form, "Model Hierarchy:", "model_hierarchy_type", [("Two Levels (Family > Model)", 0), ("Three Levels (Family > Base > Finetune)", 1)], 1) + self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1), ("Dimensions are Output Width/Height (Cropped)", 2)], 0) + self._create_config_combo(form, "Attention Type:", "attention_mode", [("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"), ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto") + self._create_config_combo(form, "Metadata Handling:", "metadata_type", [("Embed in file (Exif/Comment)", "metadata"), 
("Export separate JSON", "json"), ("None", "none")], "metadata") + self._create_config_checklist(form, "RAM Loading Policy:", "preload_model_policy", [("Preload on App Launch", "P"), ("Preload on Model Switch", "S"), ("Unload when Queue is Done", "U")], []) + self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10), ("Keep last 20", 20), ("Keep last 30", 30)], 5) self._create_config_combo(form, "Display RAM/VRAM Stats:", "display_stats", [("Disabled", 0), ("Enabled", 1)], 0) - - self._create_config_combo(form, "Max Frames Multiplier:", "max_frames_multiplier", - [(f"x{i}", i) for i in range(1, 8)], 1) - - checkpoints_paths_text = "\n".join(wgp.server_config.get("checkpoints_paths", wgp.fl.default_checkpoints_paths)) + self._create_config_combo(form, "Max Frames Multiplier:", "max_frames_multiplier", [(f"x{i}", i) for i in range(1, 8)], 1) + checkpoints_paths_text = "\n".join(self.wgp.server_config.get("checkpoints_paths", self.wgp.fl.default_checkpoints_paths)) checkpoints_textbox = QTextEdit() checkpoints_textbox.setPlainText(checkpoints_paths_text) checkpoints_textbox.setAcceptRichText(False) checkpoints_textbox.setMinimumHeight(60) self.widgets['config_checkpoints_paths'] = checkpoints_textbox form.addRow("Checkpoints Paths:", checkpoints_textbox) - self._create_config_combo(form, "UI Theme (requires restart):", "UI_theme", [("Blue Sky", "default"), ("Classic Gradio", "gradio")], "default") - return tab def _create_performance_config_tab(self): @@ -974,13 +716,11 @@ def _create_performance_config_tab(self): self._create_config_combo(form, "DepthAnything v2 Variant:", "depth_anything_v2_variant", [("Large (more precise)", "vitl"), ("Big (faster)", "vitb")], "vitl") self._create_config_combo(form, "VAE Tiling:", "vae_config", [("Auto", 0), ("Disabled", 1), ("256x256 (~8GB VRAM)", 2), ("128x128 (~6GB VRAM)", 3)], 0) self._create_config_combo(form, "Boost:", "boost", [("On", 1), ("Off", 2)], 1) - self._create_config_combo(form, "Memory Profile:", "profile", wgp.memory_profile_choices, wgp.profile_type.LowRAM_LowVRAM) + self._create_config_combo(form, "Memory Profile:", "profile", self.wgp.memory_profile_choices, self.wgp.profile_type.LowRAM_LowVRAM) self._create_config_slider(form, "Preload in VRAM (MB):", "preload_in_VRAM", 0, 40000, 0, 100) - release_ram_btn = QPushButton("Force Release Models from RAM") release_ram_btn.clicked.connect(self._on_release_ram) form.addRow(release_ram_btn) - return tab def _create_extensions_config_tab(self): @@ -1005,12 +745,12 @@ def _create_notifications_config_tab(self): return tab def init_wgp_state(self): + wgp = self.wgp initial_model = wgp.server_config.get("last_model_type", wgp.transformer_type) - all_models, _, _ = wgp.get_sorted_dropdown(wgp.displayed_model_types, None, None, False) + dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types + _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False) all_model_ids = [m[1] for m in all_models] - if initial_model not in all_model_ids: - initial_model = wgp.transformer_type - + if initial_model not in all_model_ids: initial_model = wgp.transformer_type state_dict = {} state_dict["model_filename"] = wgp.get_model_filename(initial_model, wgp.transformer_quantization, wgp.transformer_dtype_policy) state_dict["model_type"] = initial_model @@ -1019,49 +759,36 @@ def init_wgp_state(self): state_dict["last_model_per_type"] = 
wgp.server_config.get("last_model_per_type", {}) state_dict["last_resolution_per_group"] = wgp.server_config.get("last_resolution_per_group", {}) state_dict["gen"] = {"queue": []} - self.state = state_dict self.advanced_group.setChecked(wgp.advanced) - self.update_model_dropdowns(initial_model) self.refresh_ui_from_model_change(initial_model) - self._update_input_visibility() # Set initial visibility + self._update_input_visibility() def update_model_dropdowns(self, current_model_type): + wgp = self.wgp family_mock, base_type_mock, choice_mock = wgp.generate_dropdown_model_list(current_model_type) - - self.widgets['model_family'].blockSignals(True) - self.widgets['model_base_type_choice'].blockSignals(True) - self.widgets['model_choice'].blockSignals(True) - - self.widgets['model_family'].clear() - if family_mock.choices: - for display_name, internal_key in family_mock.choices: - self.widgets['model_family'].addItem(display_name, internal_key) - index = self.widgets['model_family'].findData(family_mock.value) - if index != -1: self.widgets['model_family'].setCurrentIndex(index) - - self.widgets['model_base_type_choice'].clear() - if base_type_mock.choices: - for label, value in base_type_mock.choices: - self.widgets['model_base_type_choice'].addItem(label, value) - index = self.widgets['model_base_type_choice'].findData(base_type_mock.value) - if index != -1: self.widgets['model_base_type_choice'].setCurrentIndex(index) - self.widgets['model_base_type_choice'].setVisible(base_type_mock.kwargs.get('visible', True)) - - self.widgets['model_choice'].clear() - if choice_mock.choices: - for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value) - index = self.widgets['model_choice'].findData(choice_mock.value) - if index != -1: self.widgets['model_choice'].setCurrentIndex(index) - self.widgets['model_choice'].setVisible(choice_mock.kwargs.get('visible', True)) - - self.widgets['model_family'].blockSignals(False) - self.widgets['model_base_type_choice'].blockSignals(False) - self.widgets['model_choice'].blockSignals(False) + for combo_name, mock in [('model_family', family_mock), ('model_base_type_choice', base_type_mock), ('model_choice', choice_mock)]: + combo = self.widgets[combo_name] + combo.blockSignals(True) + combo.clear() + if mock.choices: + for display_name, internal_key in mock.choices: combo.addItem(display_name, internal_key) + index = combo.findData(mock.value) + if index != -1: combo.setCurrentIndex(index) + + is_visible = True + if hasattr(mock, 'kwargs') and isinstance(mock.kwargs, dict): + is_visible = mock.kwargs.get('visible', True) + elif hasattr(mock, 'visible'): + is_visible = mock.visible + combo.setVisible(is_visible) + + combo.blockSignals(False) def refresh_ui_from_model_change(self, model_type): """Update UI controls with default settings when the model is changed.""" + wgp = self.wgp self.header_info.setText(wgp.generate_header(model_type, wgp.compile, wgp.attention_mode)) ui_defaults = wgp.get_default_settings(model_type) wgp.set_model_settings(self.state, model_type, ui_defaults) @@ -1128,10 +855,8 @@ def refresh_ui_from_model_change(self, model_type): self.widgets['resolution_group'].blockSignals(False) self.widgets['resolution'].blockSignals(False) - - for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'audio_source']: - if name in self.widgets: - self.widgets[name].clear() + for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'image_refs', 'audio_source']: + if 
name in self.widgets: self.widgets[name].clear() guidance_layout = self.widgets['guidance_layout'] guidance_max = model_def.get("guidance_max_phases", 1) @@ -1155,7 +880,6 @@ def refresh_ui_from_model_change(self, model_type): misc_layout = self.widgets['misc_layout'] misc_layout.setRowVisible(self.widgets['riflex_row_index'], not (recammaster or ltxv or diffusion_forcing)) - index = self.widgets['multi_images_gen_type'].findData(ui_defaults.get('multi_images_gen_type', 0)) if index != -1: self.widgets['multi_images_gen_type'].setCurrentIndex(index) @@ -1170,12 +894,11 @@ def refresh_ui_from_model_change(self, model_type): guidance3_val = ui_defaults.get("guidance3_scale", 5.0) self.widgets['guidance3_scale'].setValue(int(guidance3_val * 10)) self.widgets['guidance3_scale_label'].setText(f"{guidance3_val:.1f}") - self.widgets['guidance_phases'].clear() + self.widgets['guidance_phases'].clear() if guidance_max >= 1: self.widgets['guidance_phases'].addItem("One Phase", 1) if guidance_max >= 2: self.widgets['guidance_phases'].addItem("Two Phases", 2) if guidance_max >= 3: self.widgets['guidance_phases'].addItem("Three Phases", 3) - index = self.widgets['guidance_phases'].findData(ui_defaults.get("guidance_phases", 1)) if index != -1: self.widgets['guidance_phases'].setCurrentIndex(index) @@ -1227,9 +950,7 @@ def refresh_ui_from_model_change(self, model_type): selected_loras = ui_defaults.get('activated_loras', []) for i in range(lora_list_widget.count()): item = lora_list_widget.item(i) - is_selected = any(item.text() == os.path.basename(p) for p in selected_loras) - if is_selected: - item.setSelected(True) + if any(item.text() == os.path.basename(p) for p in selected_loras): item.setSelected(True) self.widgets['loras_multipliers'].setText(ui_defaults.get('loras_multipliers', '')) skip_cache_val = ui_defaults.get('skip_steps_cache_type', "") @@ -1329,11 +1050,11 @@ def refresh_ui_from_model_change(self, model_type): for widget in self.widgets.values(): if hasattr(widget, 'blockSignals'): widget.blockSignals(False) + self._update_dynamic_ui() self._update_input_visibility() def _update_dynamic_ui(self): - """Update UI visibility based on current selections.""" phases = self.widgets['guidance_phases'].currentData() or 1 guidance_layout = self.widgets['guidance_layout'] guidance_layout.setRowVisible(self.widgets['guidance2_row_index'], phases >= 2) @@ -1341,83 +1062,45 @@ def _update_dynamic_ui(self): guidance_layout.setRowVisible(self.widgets['switch_thresh_row_index'], phases >= 2) def _update_generation_mode_visibility(self, model_def): - """Shows/hides the main generation mode options based on the selected model.""" allowed = model_def.get("image_prompt_types_allowed", "") - choices = [] - if "T" in allowed or not allowed: - choices.append(("Text Prompt Only" if "S" in allowed else "New Video", "T")) - if "S" in allowed: - choices.append(("Start Video with Image", "S")) - if "V" in allowed: - choices.append(("Continue Video", "V")) - if "L" in allowed: - choices.append(("Continue Last Video", "L")) - - button_map = { - "T": self.widgets['mode_t'], - "S": self.widgets['mode_s'], - "V": self.widgets['mode_v'], - "L": self.widgets['mode_l'], - } - - for btn in button_map.values(): - btn.setVisible(False) - + if "T" in allowed or not allowed: choices.append(("Text Prompt Only" if "S" in allowed else "New Video", "T")) + if "S" in allowed: choices.append(("Start Video with Image", "S")) + if "V" in allowed: choices.append(("Continue Video", "V")) + if "L" in allowed: choices.append(("Continue 
Last Video", "L")) + button_map = { "T": self.widgets['mode_t'], "S": self.widgets['mode_s'], "V": self.widgets['mode_v'], "L": self.widgets['mode_l'] } + for btn in button_map.values(): btn.setVisible(False) allowed_values = [c[1] for c in choices] for label, value in choices: if value in button_map: btn = button_map[value] btn.setText(label) btn.setVisible(True) - - current_checked_value = None - for value, btn in button_map.items(): - if btn.isChecked(): - current_checked_value = value - break - - # If the currently selected mode is now hidden, reset to a visible default + current_checked_value = next((value for value, btn in button_map.items() if btn.isChecked()), None) if current_checked_value is None or not button_map[current_checked_value].isVisible(): - if allowed_values: - button_map[allowed_values[0]].setChecked(True) - - + if allowed_values: button_map[allowed_values[0]].setChecked(True) end_image_visible = "E" in allowed self.widgets['image_end_checkbox'].setVisible(end_image_visible) - if not end_image_visible: - self.widgets['image_end_checkbox'].setChecked(False) - - # Control Video Checkbox (Based on model_def.get("guide_preprocessing")) + if not end_image_visible: self.widgets['image_end_checkbox'].setChecked(False) control_video_visible = model_def.get("guide_preprocessing") is not None self.widgets['control_video_checkbox'].setVisible(control_video_visible) - if not control_video_visible: - self.widgets['control_video_checkbox'].setChecked(False) - - # Reference Image Checkbox (Based on model_def.get("image_ref_choices")) + if not control_video_visible: self.widgets['control_video_checkbox'].setChecked(False) ref_image_visible = model_def.get("image_ref_choices") is not None self.widgets['ref_image_checkbox'].setVisible(ref_image_visible) - if not ref_image_visible: - self.widgets['ref_image_checkbox'].setChecked(False) - + if not ref_image_visible: self.widgets['ref_image_checkbox'].setChecked(False) def _update_input_visibility(self): - """Shows/hides input fields based on the selected generation mode.""" is_s_mode = self.widgets['mode_s'].isChecked() is_v_mode = self.widgets['mode_v'].isChecked() is_l_mode = self.widgets['mode_l'].isChecked() - use_end = self.widgets['image_end_checkbox'].isChecked() and self.widgets['image_end_checkbox'].isVisible() use_control = self.widgets['control_video_checkbox'].isChecked() and self.widgets['control_video_checkbox'].isVisible() use_ref = self.widgets['ref_image_checkbox'].isChecked() and self.widgets['ref_image_checkbox'].isVisible() - self.widgets['image_start_container'].setVisible(is_s_mode) self.widgets['video_source_container'].setVisible(is_v_mode) - end_checkbox_enabled = is_s_mode or is_v_mode or is_l_mode self.widgets['image_end_checkbox'].setEnabled(end_checkbox_enabled) self.widgets['image_end_container'].setVisible(use_end and end_checkbox_enabled) - self.widgets['video_guide_container'].setVisible(use_control) self.widgets['video_mask_container'].setVisible(use_control) self.widgets['image_refs_container'].setVisible(use_ref) @@ -1428,17 +1111,14 @@ def connect_signals(self): self.widgets['model_choice'].currentIndexChanged.connect(self._on_model_changed) self.widgets['resolution_group'].currentIndexChanged.connect(self._on_resolution_group_changed) self.widgets['guidance_phases'].currentIndexChanged.connect(self._update_dynamic_ui) - self.widgets['mode_t'].toggled.connect(self._update_input_visibility) self.widgets['mode_s'].toggled.connect(self._update_input_visibility) 
self.widgets['mode_v'].toggled.connect(self._update_input_visibility) self.widgets['mode_l'].toggled.connect(self._update_input_visibility) - self.widgets['image_end_checkbox'].toggled.connect(self._update_input_visibility) self.widgets['control_video_checkbox'].toggled.connect(self._update_input_visibility) self.widgets['ref_image_checkbox'].toggled.connect(self._update_input_visibility) self.widgets['preview_group'].toggled.connect(self._on_preview_toggled) - self.generate_btn.clicked.connect(self._on_generate) self.add_to_queue_btn.clicked.connect(self._on_add_to_queue) self.remove_queue_btn.clicked.connect(self._on_remove_selected_from_queue) @@ -1446,214 +1126,169 @@ def connect_signals(self): self.abort_btn.clicked.connect(self._on_abort) self.queue_table.rowsMoved.connect(self._on_queue_rows_moved) + def load_main_config(self): + try: + with open('main_config.json', 'r') as f: self.main_config = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): self.main_config = {'preview_visible': False} + + def save_main_config(self): + try: + with open('main_config.json', 'w') as f: json.dump(self.main_config, f, indent=4) + except Exception as e: print(f"Error saving main_config.json: {e}") + def apply_initial_config(self): - is_visible = main_config.get('preview_visible', True) + is_visible = self.main_config.get('preview_visible', True) self.widgets['preview_group'].setChecked(is_visible) self.widgets['preview_image'].setVisible(is_visible) def _on_preview_toggled(self, checked): self.widgets['preview_image'].setVisible(checked) - main_config['preview_visible'] = checked - save_main_config() + self.main_config['preview_visible'] = checked + self.save_main_config() - def _on_family_changed(self, index): + def _on_family_changed(self): family = self.widgets['model_family'].currentData() if not family or not self.state: return - - base_type_mock, choice_mock = wgp.change_model_family(self.state, family) + base_type_mock, choice_mock = self.wgp.change_model_family(self.state, family) + if hasattr(base_type_mock, 'kwargs') and isinstance(base_type_mock.kwargs, dict): + is_visible_base = base_type_mock.kwargs.get('visible', True) + elif hasattr(base_type_mock, 'visible'): + is_visible_base = base_type_mock.visible + else: + is_visible_base = True + self.widgets['model_base_type_choice'].blockSignals(True) self.widgets['model_base_type_choice'].clear() if base_type_mock.choices: - for label, value in base_type_mock.choices: - self.widgets['model_base_type_choice'].addItem(label, value) - index = self.widgets['model_base_type_choice'].findData(base_type_mock.value) - if index != -1: - self.widgets['model_base_type_choice'].setCurrentIndex(index) - self.widgets['model_base_type_choice'].setVisible(base_type_mock.kwargs.get('visible', True)) + for label, value in base_type_mock.choices: self.widgets['model_base_type_choice'].addItem(label, value) + self.widgets['model_base_type_choice'].setCurrentIndex(self.widgets['model_base_type_choice'].findData(base_type_mock.value)) + self.widgets['model_base_type_choice'].setVisible(is_visible_base) self.widgets['model_base_type_choice'].blockSignals(False) - + + if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict): + is_visible_choice = choice_mock.kwargs.get('visible', True) + elif hasattr(choice_mock, 'visible'): + is_visible_choice = choice_mock.visible + else: + is_visible_choice = True + self.widgets['model_choice'].blockSignals(True) self.widgets['model_choice'].clear() if choice_mock.choices: - for label, value in 
choice_mock.choices: - self.widgets['model_choice'].addItem(label, value) - index = self.widgets['model_choice'].findData(choice_mock.value) - if index != -1: - self.widgets['model_choice'].setCurrentIndex(index) - self.widgets['model_choice'].setVisible(choice_mock.kwargs.get('visible', True)) + for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value) + self.widgets['model_choice'].setCurrentIndex(self.widgets['model_choice'].findData(choice_mock.value)) + self.widgets['model_choice'].setVisible(is_visible_choice) self.widgets['model_choice'].blockSignals(False) - + self._on_model_changed() - def _on_base_type_changed(self, index): + def _on_base_type_changed(self): family = self.widgets['model_family'].currentData() base_type = self.widgets['model_base_type_choice'].currentData() if not family or not base_type or not self.state: return - - base_type_mock, choice_mock = wgp.change_model_base_types(self.state, family, base_type) - + base_type_mock, choice_mock = self.wgp.change_model_base_types(self.state, family, base_type) + + if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict): + is_visible_choice = choice_mock.kwargs.get('visible', True) + elif hasattr(choice_mock, 'visible'): + is_visible_choice = choice_mock.visible + else: + is_visible_choice = True + self.widgets['model_choice'].blockSignals(True) self.widgets['model_choice'].clear() if choice_mock.choices: - for label, value in choice_mock.choices: - self.widgets['model_choice'].addItem(label, value) - index = self.widgets['model_choice'].findData(choice_mock.value) - if index != -1: - self.widgets['model_choice'].setCurrentIndex(index) - self.widgets['model_choice'].setVisible(choice_mock.kwargs.get('visible', True)) + for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value) + self.widgets['model_choice'].setCurrentIndex(self.widgets['model_choice'].findData(choice_mock.value)) + self.widgets['model_choice'].setVisible(is_visible_choice) self.widgets['model_choice'].blockSignals(False) - self._on_model_changed() def _on_model_changed(self): model_type = self.widgets['model_choice'].currentData() - if not model_type or model_type == self.state['model_type']: return - wgp.change_model(self.state, model_type) + if not model_type or model_type == self.state.get('model_type'): return + self.wgp.change_model(self.state, model_type) self.refresh_ui_from_model_change(model_type) def _on_resolution_group_changed(self): selected_group = self.widgets['resolution_group'].currentText() - if not selected_group or not hasattr(self, 'full_resolution_choices'): - return - + if not selected_group or not hasattr(self, 'full_resolution_choices'): return model_type = self.state['model_type'] - model_def = wgp.get_model_def(model_type) + model_def = self.wgp.get_model_def(model_type) model_resolutions = model_def.get("resolutions", None) - group_resolution_choices = [] if model_resolutions is None: - group_resolution_choices = [res for res in self.full_resolution_choices if wgp.categorize_resolution(res[1]) == selected_group] - else: - return - - last_resolution_per_group = self.state.get("last_resolution_per_group", {}) - last_resolution = last_resolution_per_group.get(selected_group, "") - - is_last_res_valid = any(last_resolution == res[1] for res in group_resolution_choices) - if not is_last_res_valid and group_resolution_choices: + group_resolution_choices = [res for res in self.full_resolution_choices if self.wgp.categorize_resolution(res[1]) == 
selected_group] + else: return + last_resolution = self.state.get("last_resolution_per_group", {}).get(selected_group, "") + if not any(last_resolution == res[1] for res in group_resolution_choices) and group_resolution_choices: last_resolution = group_resolution_choices[0][1] - self.widgets['resolution'].blockSignals(True) self.widgets['resolution'].clear() - for label, value in group_resolution_choices: - self.widgets['resolution'].addItem(label, value) - - index = self.widgets['resolution'].findData(last_resolution) - if index != -1: - self.widgets['resolution'].setCurrentIndex(index) + for label, value in group_resolution_choices: self.widgets['resolution'].addItem(label, value) + self.widgets['resolution'].setCurrentIndex(self.widgets['resolution'].findData(last_resolution)) self.widgets['resolution'].blockSignals(False) def collect_inputs(self): - """Gather all settings from UI widgets into a dictionary.""" - # Start with all possible defaults. This dictionary will be modified and returned. - full_inputs = wgp.get_current_model_settings(self.state).copy() - - # Add dummy/default values for UI elements present in Gradio but not yet in PyQt. - # These are expected by the backend logic. + full_inputs = self.wgp.get_current_model_settings(self.state).copy() full_inputs['lset_name'] = "" - # The PyQt UI is focused on generating videos, so image_mode is 0. full_inputs['image_mode'] = 0 - - # Defensively initialize keys that are accessed directly in wgp.py but may not - # be in the saved model settings or fully implemented in the UI yet. - # This prevents both KeyErrors and TypeErrors for missing arguments. - expected_keys = { - "audio_guide": None, "audio_guide2": None, "image_guide": None, - "image_mask": None, "speakers_locations": "", "frames_positions": "", - "keep_frames_video_guide": "", "keep_frames_video_source": "", - "video_guide_outpainting": "", "switch_threshold2": 0, - "model_switch_phase": 1, "batch_size": 1, - "control_net_weight_alt": 1.0, - "image_refs_relative_size": 50, - } + expected_keys = { "audio_guide": None, "audio_guide2": None, "image_guide": None, "image_mask": None, "speakers_locations": "", "frames_positions": "", "keep_frames_video_guide": "", "keep_frames_video_source": "", "video_guide_outpainting": "", "switch_threshold2": 0, "model_switch_phase": 1, "batch_size": 1, "control_net_weight_alt": 1.0, "image_refs_relative_size": 50, } for key, default_value in expected_keys.items(): - if key not in full_inputs: - full_inputs[key] = default_value - - # Overwrite defaults with values from the PyQt UI widgets + if key not in full_inputs: full_inputs[key] = default_value full_inputs['prompt'] = self.widgets['prompt'].toPlainText() full_inputs['negative_prompt'] = self.widgets['negative_prompt'].toPlainText() full_inputs['resolution'] = self.widgets['resolution'].currentData() full_inputs['video_length'] = self.widgets['video_length'].value() full_inputs['num_inference_steps'] = self.widgets['num_inference_steps'].value() full_inputs['seed'] = int(self.widgets['seed'].text()) - - # Build prompt_type strings based on mode selections image_prompt_type = "" video_prompt_type = "" - - if self.widgets['mode_s'].isChecked(): - image_prompt_type = 'S' - elif self.widgets['mode_v'].isChecked(): - image_prompt_type = 'V' - elif self.widgets['mode_l'].isChecked(): - image_prompt_type = 'L' - else: # mode_t is checked - image_prompt_type = '' - - if self.widgets['image_end_checkbox'].isVisible() and self.widgets['image_end_checkbox'].isChecked(): - image_prompt_type += 'E' - 
- if self.widgets['control_video_checkbox'].isVisible() and self.widgets['control_video_checkbox'].isChecked(): - video_prompt_type += 'V' # This 'V' is for Control Video (V2V) - if self.widgets['ref_image_checkbox'].isVisible() and self.widgets['ref_image_checkbox'].isChecked(): - video_prompt_type += 'I' # 'I' for Reference Image - + if self.widgets['mode_s'].isChecked(): image_prompt_type = 'S' + elif self.widgets['mode_v'].isChecked(): image_prompt_type = 'V' + elif self.widgets['mode_l'].isChecked(): image_prompt_type = 'L' + if self.widgets['image_end_checkbox'].isVisible() and self.widgets['image_end_checkbox'].isChecked(): image_prompt_type += 'E' + if self.widgets['control_video_checkbox'].isVisible() and self.widgets['control_video_checkbox'].isChecked(): video_prompt_type += 'V' + if self.widgets['ref_image_checkbox'].isVisible() and self.widgets['ref_image_checkbox'].isChecked(): video_prompt_type += 'I' full_inputs['image_prompt_type'] = image_prompt_type full_inputs['video_prompt_type'] = video_prompt_type - - # File Inputs for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'audio_source']: - if name in self.widgets: - path = self.widgets[name].text() - full_inputs[name] = path if path else None - + if name in self.widgets: full_inputs[name] = self.widgets[name].text() or None paths = self.widgets['image_refs'].text().split(';') full_inputs['image_refs'] = [p.strip() for p in paths if p.strip()] if paths and paths[0] else None - full_inputs['denoising_strength'] = self.widgets['denoising_strength'].value() / 100.0 - if self.advanced_group.isChecked(): full_inputs['guidance_scale'] = self.widgets['guidance_scale'].value() / 10.0 full_inputs['guidance_phases'] = self.widgets['guidance_phases'].currentData() full_inputs['guidance2_scale'] = self.widgets['guidance2_scale'].value() / 10.0 full_inputs['guidance3_scale'] = self.widgets['guidance3_scale'].value() / 10.0 full_inputs['switch_threshold'] = self.widgets['switch_threshold'].value() - full_inputs['NAG_scale'] = self.widgets['NAG_scale'].value() / 10.0 full_inputs['NAG_tau'] = self.widgets['NAG_tau'].value() / 10.0 full_inputs['NAG_alpha'] = self.widgets['NAG_alpha'].value() / 10.0 - full_inputs['sample_solver'] = self.widgets['sample_solver'].currentData() full_inputs['flow_shift'] = self.widgets['flow_shift'].value() / 10.0 full_inputs['audio_guidance_scale'] = self.widgets['audio_guidance_scale'].value() / 10.0 full_inputs['repeat_generation'] = self.widgets['repeat_generation'].value() full_inputs['multi_images_gen_type'] = self.widgets['multi_images_gen_type'].currentData() - - lora_list_widget = self.widgets['activated_loras'] - selected_items = lora_list_widget.selectedItems() + selected_items = self.widgets['activated_loras'].selectedItems() full_inputs['activated_loras'] = [self.lora_map[item.text()] for item in selected_items if item.text() in self.lora_map] full_inputs['loras_multipliers'] = self.widgets['loras_multipliers'].toPlainText() - full_inputs['skip_steps_cache_type'] = self.widgets['skip_steps_cache_type'].currentData() full_inputs['skip_steps_multiplier'] = self.widgets['skip_steps_multiplier'].currentData() full_inputs['skip_steps_start_step_perc'] = self.widgets['skip_steps_start_step_perc'].value() - full_inputs['temporal_upsampling'] = self.widgets['temporal_upsampling'].currentData() full_inputs['spatial_upsampling'] = self.widgets['spatial_upsampling'].currentData() full_inputs['film_grain_intensity'] = self.widgets['film_grain_intensity'].value() / 100.0 
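
The `image_prompt_type` and `video_prompt_type` values assembled in `collect_inputs()` above are plain letter-flag strings. Below is a standalone sketch of the same composition logic, with booleans standing in for the widget checks; `build_prompt_types` is a made-up helper, not part of wgp:

```python
# Standalone sketch of the letter-flag strings built in collect_inputs().
# The booleans stand in for widget state; the function name is illustrative only.
def build_prompt_types(mode: str, use_end: bool, use_control: bool, use_ref: bool):
    # mode is one of "T" (text only), "S" (start image), "V" (continue video),
    # "L" (continue last video); "T" maps to an empty image prompt type.
    image_prompt_type = "" if mode == "T" else mode
    if use_end:
        image_prompt_type += "E"      # append an end image
    video_prompt_type = ""
    if use_control:
        video_prompt_type += "V"      # control/guide video
    if use_ref:
        video_prompt_type += "I"      # reference image(s)
    return image_prompt_type, video_prompt_type

assert build_prompt_types("S", True, False, True) == ("SE", "I")
assert build_prompt_types("T", False, True, False) == ("", "V")
```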
full_inputs['film_grain_saturation'] = self.widgets['film_grain_saturation'].value() / 100.0 - full_inputs['MMAudio_setting'] = self.widgets['MMAudio_setting'].currentData() full_inputs['MMAudio_prompt'] = self.widgets['MMAudio_prompt'].text() full_inputs['MMAudio_neg_prompt'] = self.widgets['MMAudio_neg_prompt'].text() - full_inputs['RIFLEx_setting'] = self.widgets['RIFLEx_setting'].currentData() full_inputs['force_fps'] = self.widgets['force_fps'].currentData() full_inputs['override_profile'] = self.widgets['override_profile'].currentData() full_inputs['multi_prompts_gen_type'] = self.widgets['multi_prompts_gen_type'].currentData() - full_inputs['slg_switch'] = self.widgets['slg_switch'].currentData() full_inputs['slg_start_perc'] = self.widgets['slg_start_perc'].value() full_inputs['slg_end_perc'] = self.widgets['slg_end_perc'].value() @@ -1661,13 +1296,11 @@ def collect_inputs(self): full_inputs['cfg_star_switch'] = self.widgets['cfg_star_switch'].currentData() full_inputs['cfg_zero_step'] = self.widgets['cfg_zero_step'].value() full_inputs['min_frames_if_references'] = self.widgets['min_frames_if_references'].currentData() - full_inputs['sliding_window_size'] = self.widgets['sliding_window_size'].value() full_inputs['sliding_window_overlap'] = self.widgets['sliding_window_overlap'].value() full_inputs['sliding_window_color_correction_strength'] = self.widgets['sliding_window_color_correction_strength'].value() / 100.0 full_inputs['sliding_window_overlap_noise'] = self.widgets['sliding_window_overlap_noise'].value() full_inputs['sliding_window_discard_last_frames'] = self.widgets['sliding_window_discard_last_frames'].value() - return full_inputs def _prepare_state_for_generation(self): @@ -1678,62 +1311,44 @@ def _prepare_state_for_generation(self): def _on_generate(self): try: is_running = self.thread and self.thread.isRunning() - self._add_task_to_queue_and_update_ui() - if not is_running: - self.start_generation() - except Exception: - import traceback - traceback.print_exc() - + self._add_task_to_queue() + if not is_running: self.start_generation() + except Exception as e: + import traceback; traceback.print_exc() def _on_add_to_queue(self): try: - self._add_task_to_queue_and_update_ui() - except Exception: - import traceback - traceback.print_exc() - - def _add_task_to_queue_and_update_ui(self): - self._add_task_to_queue() - self.update_queue_table() - + self._add_task_to_queue() + except Exception as e: + import traceback; traceback.print_exc() + def _add_task_to_queue(self): - queue_size_before = len(self.state["gen"]["queue"]) all_inputs = self.collect_inputs() - keys_to_remove = ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality', 'base_model_type'] - for key in keys_to_remove: - all_inputs.pop(key, None) - + for key in ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality', 'base_model_type']: all_inputs.pop(key, None) all_inputs['state'] = self.state - wgp.set_model_settings(self.state, self.state['model_type'], all_inputs) - + self.wgp.set_model_settings(self.state, self.state['model_type'], all_inputs) self.state["validate_success"] = 1 - wgp.process_prompt_and_add_tasks(self.state, self.state['model_type']) - + self.wgp.process_prompt_and_add_tasks(self.state, self.state['model_type']) + self.update_queue_table() def start_generation(self): - if not self.state['gen']['queue']: - return + if not self.state['gen']['queue']: return self._prepare_state_for_generation() self.generate_btn.setEnabled(False) 
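
`_add_task_to_queue()` above strips presentation-only keys from the collected settings before calling `process_prompt_and_add_tasks`. A minimal sketch of that filtering step follows; the key list is copied from the code above, while `strip_ui_only_keys` is an invented name:

```python
# Sketch of the key-filtering step in _add_task_to_queue(): presentation-only
# fields are removed before the settings dict reaches the generation queue.
UI_ONLY_KEYS = ("type", "settings_version", "is_image",
                "video_quality", "image_quality", "base_model_type")

def strip_ui_only_keys(settings: dict) -> dict:
    """Return a copy of settings without UI/metadata-only fields."""
    return {k: v for k, v in settings.items() if k not in UI_ONLY_KEYS}

task = strip_ui_only_keys({"prompt": "a test clip", "settings_version": 2})
assert task == {"prompt": "a test clip"}
```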
self.add_to_queue_btn.setEnabled(True) - self.thread = QThread() - self.worker = Worker(self.state) + self.worker = Worker(self.plugin, self.state) self.worker.moveToThread(self.thread) - self.thread.started.connect(self.worker.run) self.worker.finished.connect(self.thread.quit) self.worker.finished.connect(self.worker.deleteLater) self.thread.finished.connect(self.thread.deleteLater) self.thread.finished.connect(self.on_generation_finished) - self.worker.status.connect(self.status_label.setText) self.worker.progress.connect(self.update_progress) self.worker.preview.connect(self.update_preview) - self.worker.output.connect(self.update_queue_and_gallery) + self.worker.output.connect(self.update_queue_and_results) self.worker.error.connect(self.on_generation_error) - self.thread.start() self.update_queue_table() @@ -1743,17 +1358,11 @@ def on_generation_finished(self): self.progress_bar.setValue(0) self.generate_btn.setEnabled(True) self.add_to_queue_btn.setEnabled(False) - self.thread = None - self.worker = None + self.thread = None; self.worker = None self.update_queue_table() def on_generation_error(self, err_msg): - msg_box = QMessageBox() - msg_box.setIcon(QMessageBox.Icon.Critical) - msg_box.setText("Generation Error") - msg_box.setInformativeText(str(err_msg)) - msg_box.setWindowTitle("Error") - msg_box.exec() + QMessageBox.critical(self, "Generation Error", str(err_msg)) self.on_generation_finished() def update_progress(self, data): @@ -1762,221 +1371,283 @@ def update_progress(self, data): self.progress_bar.setMaximum(total) self.progress_bar.setValue(step) self.status_label.setText(str(data[1])) - if step <= 1: - self.update_queue_table() - elif len(data) > 1: - self.status_label.setText(str(data[1])) + if step <= 1: self.update_queue_table() + elif len(data) > 1: self.status_label.setText(str(data[1])) def update_preview(self, pil_image): - if pil_image: + if pil_image and self.widgets['preview_group'].isChecked(): q_image = ImageQt(pil_image) pixmap = QPixmap.fromImage(q_image) - self.preview_image.setPixmap(pixmap.scaled( - self.preview_image.size(), - Qt.AspectRatioMode.KeepAspectRatio, - Qt.TransformationMode.SmoothTransformation - )) + self.preview_image.setPixmap(pixmap.scaled(self.preview_image.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)) - def update_queue_and_gallery(self): + def update_queue_and_results(self): self.update_queue_table() - file_list = self.state.get('gen', {}).get('file_list', []) - self.output_gallery.clear() - self.output_gallery.addItems(file_list) - if file_list: - self.output_gallery.setCurrentRow(len(file_list) - 1) - self.latest_output_path = file_list[-1] # <-- ADDED for API access + for file_path in file_list: + if file_path not in self.processed_files: + self.add_result_item(file_path) + self.processed_files.add(file_path) + + def add_result_item(self, video_path): + item_widget = VideoResultItemWidget(video_path, self.plugin) + list_item = QListWidgetItem(self.results_list) + list_item.setSizeHint(item_widget.sizeHint()) + self.results_list.addItem(list_item) + self.results_list.setItemWidget(list_item, item_widget) def update_queue_table(self): - with wgp.lock: + with self.wgp.lock: queue = self.state.get('gen', {}).get('queue', []) is_running = self.thread and self.thread.isRunning() queue_to_display = queue if is_running else [None] + queue - - table_data = wgp.get_queue_table(queue_to_display) - + table_data = self.wgp.get_queue_table(queue_to_display) self.queue_table.setRowCount(0) 
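
Generation runs off the GUI thread via the `QThread`/`Worker` wiring shown in `start_generation()` above. Here is a minimal, runnable sketch of that pattern, assuming PyQt6; the `Worker` below only simulates work and stands in for the real one that drives the wgp generation loop:

```python
# Minimal sketch of the QThread/Worker wiring used in start_generation().
import sys
from PyQt6.QtCore import QCoreApplication, QObject, QThread, QTimer, pyqtSignal

class Worker(QObject):
    finished = pyqtSignal()
    status = pyqtSignal(str)

    def run(self):
        self.status.emit("working in background thread...")
        # ... a long-running generation loop would live here ...
        self.finished.emit()

if __name__ == "__main__":
    app = QCoreApplication(sys.argv)
    thread = QThread()
    worker = Worker()
    worker.moveToThread(thread)
    # Same connection order as the code above: run on start, tear down on finish.
    thread.started.connect(worker.run)
    worker.finished.connect(thread.quit)
    worker.finished.connect(worker.deleteLater)
    thread.finished.connect(thread.deleteLater)
    thread.finished.connect(app.quit)
    worker.status.connect(print)
    # Start the thread once the event loop is running.
    QTimer.singleShot(0, thread.start)
    sys.exit(app.exec())
```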
self.queue_table.setRowCount(len(table_data)) self.queue_table.setColumnCount(4) self.queue_table.setHorizontalHeaderLabels(["Qty", "Prompt", "Length", "Steps"]) - for row_idx, row_data in enumerate(table_data): - prompt_html = row_data[1] - try: - prompt_text = prompt_html.split('>')[1].split('<')[0] - except IndexError: - prompt_text = str(row_data[1]) - - self.queue_table.setItem(row_idx, 0, QTableWidgetItem(str(row_data[0]))) - self.queue_table.setItem(row_idx, 1, QTableWidgetItem(prompt_text)) - self.queue_table.setItem(row_idx, 2, QTableWidgetItem(str(row_data[2]))) - self.queue_table.setItem(row_idx, 3, QTableWidgetItem(str(row_data[3]))) - + prompt_text = str(row_data[1]).split('>')[1].split('<')[0] if '>' in str(row_data[1]) else str(row_data[1]) + for col_idx, cell_data in enumerate([row_data[0], prompt_text, row_data[2], row_data[3]]): + self.queue_table.setItem(row_idx, col_idx, QTableWidgetItem(str(cell_data))) self.queue_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) self.queue_table.resizeColumnsToContents() def _on_remove_selected_from_queue(self): selected_row = self.queue_table.currentRow() - if selected_row < 0: - return - - with wgp.lock: + if selected_row < 0: return + with self.wgp.lock: is_running = self.thread and self.thread.isRunning() offset = 1 if is_running else 0 - queue = self.state.get('gen', {}).get('queue', []) - if len(queue) > selected_row + offset: - queue.pop(selected_row + offset) + if len(queue) > selected_row + offset: queue.pop(selected_row + offset) self.update_queue_table() def _on_queue_rows_moved(self, source_row, dest_row): - with wgp.lock: + with self.wgp.lock: queue = self.state.get('gen', {}).get('queue', []) is_running = self.thread and self.thread.isRunning() offset = 1 if is_running else 0 - real_source_idx = source_row + offset real_dest_idx = dest_row + offset - moved_item = queue.pop(real_source_idx) queue.insert(real_dest_idx, moved_item) self.update_queue_table() def _on_clear_queue(self): - wgp.clear_queue_action(self.state) + self.wgp.clear_queue_action(self.state) self.update_queue_table() def _on_abort(self): if self.worker: - wgp.abort_generation(self.state) + self.wgp.abort_generation(self.state) self.status_label.setText("Aborting...") self.worker._is_running = False def _on_release_ram(self): - wgp.release_RAM() + self.wgp.release_RAM() QMessageBox.information(self, "RAM Released", "Models stored in RAM have been released.") def _on_apply_config_changes(self): changes = {} - list_widget = self.widgets['config_transformer_types'] - checked_items = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] - changes['transformer_types_choices'] = checked_items - + changes['transformer_types_choices'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] list_widget = self.widgets['config_preload_model_policy'] - checked_items = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] - changes['preload_model_policy_choice'] = checked_items - + changes['preload_model_policy_choice'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] changes['model_hierarchy_type_choice'] = self.widgets['config_model_hierarchy_type'].currentData() changes['checkpoints_paths'] = 
self.widgets['config_checkpoints_paths'].toPlainText() - for key in ["fit_canvas", "attention_mode", "metadata_type", "clear_file_list", "display_stats", "max_frames_multiplier", "UI_theme"]: changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() - for key in ["transformer_quantization", "transformer_dtype_policy", "mixed_precision", "text_encoder_quantization", "vae_precision", "compile", "depth_anything_v2_variant", "vae_config", "boost", "profile"]: changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() changes['preload_in_VRAM_choice'] = self.widgets['config_preload_in_VRAM'].value() - for key in ["enhancer_enabled", "enhancer_mode", "mmaudio_enabled"]: changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() - for key in ["video_output_codec", "image_output_codec", "save_path", "image_save_path"]: widget = self.widgets[f'config_{key}'] changes[f'{key}_choice'] = widget.currentData() if isinstance(widget, QComboBox) else widget.text() - changes['notification_sound_enabled_choice'] = self.widgets['config_notification_sound_enabled'].currentData() changes['notification_sound_volume_choice'] = self.widgets['config_notification_sound_volume'].value() - changes['last_resolution_choice'] = self.widgets['resolution'].currentData() - try: - msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = wgp.apply_changes(self.state, **changes) + msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = self.wgp.apply_changes(self.state, **changes) self.config_status_label.setText("Changes applied successfully. Some settings may require a restart.") - - self.header_info.setText(wgp.generate_header(self.state['model_type'], wgp.compile, wgp.attention_mode)) - + self.header_info.setText(self.wgp.generate_header(self.state['model_type'], self.wgp.compile, self.wgp.attention_mode)) if family_mock.choices is not None or choice_mock.choices is not None: - self.update_model_dropdowns(wgp.transformer_type) - self.refresh_ui_from_model_change(wgp.transformer_type) - + self.update_model_dropdowns(self.wgp.transformer_type) + self.refresh_ui_from_model_change(self.wgp.transformer_type) except Exception as e: self.config_status_label.setText(f"Error applying changes: {e}") - import traceback - traceback.print_exc() - - def closeEvent(self, event): - if wgp: - wgp.autosave_queue() - save_main_config() - event.accept() - -# ===================================================================== -# --- START OF API SERVER ADDITION --- -# ===================================================================== -try: - from flask import Flask, request, jsonify - FLASK_AVAILABLE = True -except ImportError: - FLASK_AVAILABLE = False - print("Flask not installed. API server will not be available. Please run: pip install Flask") - -# Global reference to the main window, to be populated after instantiation. 
-main_window_instance = None -api_server = Flask(__name__) if FLASK_AVAILABLE else None - -def run_api_server(): - """Function to run the Flask server.""" - if api_server: - print("Starting API server on http://127.0.0.1:5100") - api_server.run(port=5100, host='127.0.0.1', debug=False) - -if FLASK_AVAILABLE: - - @api_server.route('/api/generate', methods=['POST']) - def generate(): - if not main_window_instance: - return jsonify({"error": "Application not ready"}), 503 - - data = request.json - start_frame = data.get('start_frame') - end_frame = data.get('end_frame') - # --- CHANGE: Get duration_sec and model_type --- - duration_sec = data.get('duration_sec') - model_type = data.get('model_type') - start_generation = data.get('start_generation', False) - - # --- CHANGE: Emit the signal with the new parameters --- - main_window_instance.api_bridge.generateSignal.emit( - start_frame, end_frame, duration_sec, model_type, start_generation - ) - - if start_generation: - return jsonify({"message": "Parameters set and generation request sent."}) + import traceback; traceback.print_exc() + +class Plugin(VideoEditorPlugin): + def initialize(self): + self.name = "WanGP AI Generator" + self.description = "Uses the integrated WanGP library to generate video clips." + self.client_widget = WgpDesktopPluginWidget(self) + self.dock_widget = None + self.active_region = None; self.temp_dir = None + self.insert_on_new_track = False; self.start_frame_path = None; self.end_frame_path = None + + def enable(self): + if not self.dock_widget: self.dock_widget = self.app.add_dock_widget(self, self.client_widget, self.name) + self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu) + self.app.status_label.setText(f"{self.name}: Enabled.") + + def disable(self): + try: self.app.timeline_widget.context_menu_requested.disconnect(self.on_timeline_context_menu) + except TypeError: pass + self._cleanup_temp_dir() + if self.client_widget.worker: self.client_widget._on_abort() + self.app.status_label.setText(f"{self.name}: Disabled.") + + def _cleanup_temp_dir(self): + if self.temp_dir and os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + self.temp_dir = None + + def _reset_state(self): + self.active_region = None; self.insert_on_new_track = False + self.start_frame_path = None; self.end_frame_path = None + self.client_widget.processed_files.clear() + self.client_widget.results_list.clear() + self.client_widget.widgets['image_start'].clear() + self.client_widget.widgets['image_end'].clear() + self.client_widget.widgets['video_source'].clear() + self._cleanup_temp_dir() + self.app.status_label.setText(f"{self.name}: Ready.") + + def on_timeline_context_menu(self, menu, event): + region = self.app.timeline_widget.get_region_at_pos(event.pos()) + if region: + menu.addSeparator() + start_sec, end_sec = region + start_data, _, _ = self.app.get_frame_data_at_time(start_sec) + end_data, _, _ = self.app.get_frame_data_at_time(end_sec) + if start_data and end_data: + join_action = menu.addAction("Join Frames With WanGP") + join_action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False)) + join_action_new_track = menu.addAction("Join Frames With WanGP (New Track)") + join_action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True)) + create_action = menu.addAction("Create Video With WanGP") + create_action.triggered.connect(lambda: self.setup_creator_for_region(region)) + + def setup_generator_for_region(self, 
region, on_new_track=False): + self._reset_state() + self.active_region = region + self.insert_on_new_track = on_new_track + + model_to_set = 'i2v_2_2' + dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types + _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False) + if any(model_to_set == m[1] for m in all_models): + if self.client_widget.state.get('model_type') != model_to_set: + self.client_widget.update_model_dropdowns(model_to_set) + self.client_widget._on_model_changed() else: - return jsonify({"message": "Parameters set without starting generation."}) - - - @api_server.route('/api/outputs', methods=['GET']) - def get_outputs(): - if not main_window_instance: - return jsonify({"error": "Application not ready"}), 503 - - file_list = main_window_instance.state.get('gen', {}).get('file_list', []) - return jsonify({"outputs": file_list}) - -# ===================================================================== -# --- END OF API SERVER ADDITION --- -# ===================================================================== - -if __name__ == '__main__': - app = QApplication(sys.argv) - - window = MainWindow() - main_window_instance = window - window.show() + print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.") + + start_sec, end_sec = region + start_data, w, h = self.app.get_frame_data_at_time(start_sec) + end_data, _, _ = self.app.get_frame_data_at_time(end_sec) + if not start_data or not end_data: + QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames.") + return + try: + self.temp_dir = tempfile.mkdtemp(prefix="wgp_plugin_") + self.start_frame_path = os.path.join(self.temp_dir, "start_frame.png") + self.end_frame_path = os.path.join(self.temp_dir, "end_frame.png") + QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path) + QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path) + + duration_sec = end_sec - start_sec + wgp = self.client_widget.wgp + model_type = self.client_widget.state['model_type'] + fps = wgp.get_model_fps(model_type) + video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) + + self.client_widget.widgets['video_length'].setValue(video_length_frames) + self.client_widget.widgets['mode_s'].setChecked(True) + self.client_widget.widgets['image_end_checkbox'].setChecked(True) + self.client_widget.widgets['image_start'].setText(self.start_frame_path) + self.client_widget.widgets['image_end'].setText(self.end_frame_path) + + self.client_widget._update_input_visibility() - if FLASK_AVAILABLE: - api_thread = threading.Thread(target=run_api_server, daemon=True) - api_thread.start() + except Exception as e: + QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame images: {e}") + self._cleanup_temp_dir() + return + self.app.status_label.setText(f"WanGP: Ready to join frames from {start_sec:.2f}s to {end_sec:.2f}s.") + self.dock_widget.show() + self.dock_widget.raise_() + + def setup_creator_for_region(self, region): + self._reset_state() + self.active_region = region + self.insert_on_new_track = True + + model_to_set = 't2v_2_2' + dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types + _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False) + if any(model_to_set == m[1] for m in all_models): + if 
self.client_widget.state.get('model_type') != model_to_set: + self.client_widget.update_model_dropdowns(model_to_set) + self.client_widget._on_model_changed() + else: + print(f"Warning: Default model '{model_to_set}' not found for AI Creator. Using current model.") - sys.exit(app.exec()) \ No newline at end of file + start_sec, end_sec = region + duration_sec = end_sec - start_sec + wgp = self.client_widget.wgp + model_type = self.client_widget.state['model_type'] + fps = wgp.get_model_fps(model_type) + video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) + + self.client_widget.widgets['video_length'].setValue(video_length_frames) + self.client_widget.widgets['mode_t'].setChecked(True) + + self.app.status_label.setText(f"WanGP: Ready to create video from {start_sec:.2f}s to {end_sec:.2f}s.") + self.dock_widget.show() + self.dock_widget.raise_() + + def insert_generated_clip(self, video_path): + from videoeditor import TimelineClip + if not self.active_region: + self.app.status_label.setText("WanGP Error: No active region to insert into."); return + if not os.path.exists(video_path): + self.app.status_label.setText(f"WanGP Error: Output file not found: {video_path}"); return + start_sec, end_sec = self.active_region + def complex_insertion_action(): + self.app._add_media_files_to_project([video_path]) + media_info = self.app.media_properties.get(video_path) + if not media_info: raise ValueError("Could not probe inserted clip.") + actual_duration, has_audio = media_info['duration'], media_info['has_audio'] + if self.insert_on_new_track: + self.app.timeline.num_video_tracks += 1 + video_track_index = self.app.timeline.num_video_tracks + audio_track_index = self.app.timeline.num_audio_tracks + 1 if has_audio else None + else: + for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec) + for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec) + clips_to_remove = [c for c in self.app.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec] + for clip in clips_to_remove: + if clip in self.app.timeline.clips: self.app.timeline.clips.remove(clip) + video_track_index, audio_track_index = 1, 1 if has_audio else None + group_id = str(uuid.uuid4()) + new_clip = TimelineClip(video_path, start_sec, 0, actual_duration, video_track_index, 'video', 'video', group_id) + self.app.timeline.add_clip(new_clip) + if audio_track_index: + if audio_track_index > self.app.timeline.num_audio_tracks: self.app.timeline.num_audio_tracks = audio_track_index + audio_clip = TimelineClip(video_path, start_sec, 0, actual_duration, audio_track_index, 'audio', 'video', group_id) + self.app.timeline.add_clip(audio_clip) + try: + self.app._perform_complex_timeline_change("Insert AI Clip", complex_insertion_action) + self.app.prune_empty_tracks() + self.app.status_label.setText("AI clip inserted successfully.") + for i in range(self.client_widget.results_list.count()): + widget = self.client_widget.results_list.itemWidget(self.client_widget.results_list.item(i)) + if widget and widget.video_path == video_path: + self.client_widget.results_list.takeItem(i); break + except Exception as e: + import traceback; traceback.print_exc() + self.app.status_label.setText(f"WanGP Error during clip insertion: {e}") \ No newline at end of file diff --git a/videoeditor/undo.py b/undo.py similarity index 100% rename from videoeditor/undo.py rename to undo.py diff --git a/videoeditor/main.py b/videoeditor.py similarity index 99% rename 
from videoeditor/main.py
rename to videoeditor.py
index 4bd3662c6..1c7466957 100644
--- a/videoeditor/main.py
+++ b/videoeditor.py
@@ -1,3 +1,4 @@
+import wgp
 import sys
 import os
 import uuid
@@ -1288,7 +1289,7 @@ def __init__(self, project_to_load=None):
         self.is_shutting_down = False
         self._load_settings()

-        self.plugin_manager = PluginManager(self)
+        self.plugin_manager = PluginManager(self, wgp)
         self.plugin_manager.discover_and_load_plugins()

         self.project_fps = 25.0
diff --git a/videoeditor/plugins/ai_frame_joiner/main.py b/videoeditor/plugins/ai_frame_joiner/main.py
deleted file mode 100644
index 410333d28..000000000
--- a/videoeditor/plugins/ai_frame_joiner/main.py
+++ /dev/null
@@ -1,512 +0,0 @@
-import sys
-import os
-import tempfile
-import shutil
-import requests
-from pathlib import Path
-import ffmpeg
-import uuid
-
-from PyQt6.QtWidgets import (
-    QWidget, QVBoxLayout, QHBoxLayout, QFormLayout,
-    QLineEdit, QPushButton, QLabel, QMessageBox, QCheckBox, QListWidget,
-    QListWidgetItem, QGroupBox
-)
-from PyQt6.QtGui import QImage, QPixmap
-from PyQt6.QtCore import pyqtSignal, Qt, QSize, QRectF, QUrl, QTimer
-from PyQt6.QtMultimedia import QMediaPlayer
-from PyQt6.QtMultimediaWidgets import QVideoWidget
-
-sys.path.append(str(Path(__file__).parent.parent.parent))
-from plugins import VideoEditorPlugin
-
-API_BASE_URL = "http://127.0.0.1:5100"
-
-class VideoResultItemWidget(QWidget):
-    def __init__(self, video_path, plugin, parent=None):
-        super().__init__(parent)
-        self.video_path = video_path
-        self.plugin = plugin
-        self.app = plugin.app
-        self.duration = 0.0
-        self.has_audio = False
-
-        self.setMinimumSize(200, 180)
-        self.setMaximumHeight(190)
-
-        layout = QVBoxLayout(self)
-        layout.setContentsMargins(5, 5, 5, 5)
-        layout.setSpacing(5)
-
-        self.media_player = QMediaPlayer()
-        self.video_widget = QVideoWidget()
-        self.video_widget.setFixedSize(160, 90)
-        self.media_player.setVideoOutput(self.video_widget)
-        self.media_player.setSource(QUrl.fromLocalFile(self.video_path))
-        self.media_player.setLoops(QMediaPlayer.Loops.Infinite)
-
-        self.info_label = QLabel(os.path.basename(video_path))
-        self.info_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
-        self.info_label.setWordWrap(True)
-
-        self.insert_button = QPushButton("Insert into Timeline")
-        self.insert_button.clicked.connect(self.on_insert)
-
-        h_layout = QHBoxLayout()
-        h_layout.addStretch()
-        h_layout.addWidget(self.video_widget)
-        h_layout.addStretch()
-
-        layout.addLayout(h_layout)
-        layout.addWidget(self.info_label)
-        layout.addWidget(self.insert_button)
-
-        self.probe_video()
-
-    def probe_video(self):
-        try:
-            probe = ffmpeg.probe(self.video_path)
-            self.duration = float(probe['format']['duration'])
-            self.has_audio = any(s['codec_type'] == 'audio' for s in probe.get('streams', []))
-
-            self.info_label.setText(f"{os.path.basename(self.video_path)}\n({self.duration:.2f}s)")
-
-        except Exception as e:
-            self.info_label.setText(f"Error probing:\n{os.path.basename(self.video_path)}")
-            print(f"Error probing video {self.video_path}: {e}")
-
-    def enterEvent(self, event):
-        super().enterEvent(event)
-        self.media_player.play()
-
-        if not self.plugin.active_region or self.duration == 0:
-            return
-
-        start_sec, _ = self.plugin.active_region
-        timeline = self.app.timeline_widget
-
-        video_rect, audio_rect = None, None
-
-        x = timeline.sec_to_x(start_sec)
-        w = int(self.duration * timeline.pixels_per_second)
-
-        if self.plugin.insert_on_new_track:
-            video_y = timeline.TIMESCALE_HEIGHT
-            video_rect = QRectF(x, video_y, w,
timeline.TRACK_HEIGHT) - if self.has_audio: - audio_y = timeline.audio_tracks_y_start + self.app.timeline.num_audio_tracks * timeline.TRACK_HEIGHT - audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT) - else: - v_track_idx = 1 - visual_v_idx = self.app.timeline.num_video_tracks - v_track_idx - video_y = timeline.video_tracks_y_start + visual_v_idx * timeline.TRACK_HEIGHT - video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT) - if self.has_audio: - a_track_idx = 1 - visual_a_idx = a_track_idx - 1 - audio_y = timeline.audio_tracks_y_start + visual_a_idx * timeline.TRACK_HEIGHT - audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT) - - timeline.set_hover_preview_rects(video_rect, audio_rect) - - def leaveEvent(self, event): - super().leaveEvent(event) - self.media_player.pause() - self.media_player.setPosition(0) - self.app.timeline_widget.set_hover_preview_rects(None, None) - - def on_insert(self): - self.media_player.stop() - self.media_player.setSource(QUrl()) - self.media_player.setVideoOutput(None) - self.app.timeline_widget.set_hover_preview_rects(None, None) - self.plugin.insert_generated_clip(self.video_path) - -class WgpClientWidget(QWidget): - status_updated = pyqtSignal(str) - - def __init__(self, plugin): - super().__init__() - self.plugin = plugin - self.processed_files = set() - - layout = QVBoxLayout(self) - form_layout = QFormLayout() - - self.previews_widget = QWidget() - previews_layout = QHBoxLayout(self.previews_widget) - previews_layout.setContentsMargins(0, 0, 0, 0) - - start_preview_layout = QVBoxLayout() - start_preview_layout.addWidget(QLabel("Start Frame")) - self.start_frame_preview = QLabel("N/A") - self.start_frame_preview.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.start_frame_preview.setFixedSize(160, 90) - self.start_frame_preview.setStyleSheet("background-color: #222; color: #888;") - start_preview_layout.addWidget(self.start_frame_preview) - previews_layout.addLayout(start_preview_layout) - - end_preview_layout = QVBoxLayout() - end_preview_layout.addWidget(QLabel("End Frame")) - self.end_frame_preview = QLabel("N/A") - self.end_frame_preview.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.end_frame_preview.setFixedSize(160, 90) - self.end_frame_preview.setStyleSheet("background-color: #222; color: #888;") - end_preview_layout.addWidget(self.end_frame_preview) - previews_layout.addLayout(end_preview_layout) - - self.start_frame_input = QLineEdit() - self.end_frame_input = QLineEdit() - self.duration_input = QLineEdit() - - self.autostart_checkbox = QCheckBox("Start Generation Immediately") - self.autostart_checkbox.setChecked(True) - - self.generate_button = QPushButton("Generate") - - layout.addWidget(self.previews_widget) - form_layout.addRow(self.autostart_checkbox) - form_layout.addRow(self.generate_button) - - layout.addLayout(form_layout) - - # --- Results Area --- - results_group = QGroupBox("Generated Clips (Hover to play, Click button to insert)") - results_layout = QVBoxLayout() - self.results_list = QListWidget() - self.results_list.setFlow(QListWidget.Flow.LeftToRight) - self.results_list.setWrapping(True) - self.results_list.setResizeMode(QListWidget.ResizeMode.Adjust) - self.results_list.setSpacing(10) - results_layout.addWidget(self.results_list) - results_group.setLayout(results_layout) - layout.addWidget(results_group) - - layout.addStretch() - - self.generate_button.clicked.connect(self.generate) - - self.poll_timer = QTimer(self) - self.poll_timer.setInterval(3000) - 
self.poll_timer.timeout.connect(self.poll_for_output) - - def set_previews(self, start_pixmap, end_pixmap): - if start_pixmap: - self.start_frame_preview.setPixmap(start_pixmap) - else: - self.start_frame_preview.setText("N/A") - self.start_frame_preview.setPixmap(QPixmap()) - - if end_pixmap: - self.end_frame_preview.setPixmap(end_pixmap) - else: - self.end_frame_preview.setText("N/A") - self.end_frame_preview.setPixmap(QPixmap()) - - def start_polling(self): - self.poll_timer.start() - - def stop_polling(self): - self.poll_timer.stop() - - def handle_api_error(self, response, action="performing action"): - try: - error_msg = response.json().get("error", "Unknown error") - except requests.exceptions.JSONDecodeError: - error_msg = response.text - self.status_updated.emit(f"AI Joiner Error: {error_msg}") - QMessageBox.warning(self, "API Error", f"Failed while {action}.\n\nServer response:\n{error_msg}") - - def check_server_status(self): - try: - requests.get(f"{API_BASE_URL}/api/outputs", timeout=1) - self.status_updated.emit("AI Joiner: Connected to WanGP server.") - return True - except requests.exceptions.ConnectionError: - self.status_updated.emit("AI Joiner Error: Cannot connect to WanGP server.") - QMessageBox.critical(self, "Connection Error", f"Could not connect to the WanGP API at {API_BASE_URL}.\n\nPlease ensure wgptool.py is running.") - return False - - def generate(self): - payload = {} - - model_type = self.model_type_to_use - if model_type: - payload['model_type'] = model_type - - if self.start_frame_input.text(): - payload['start_frame'] = self.start_frame_input.text() - if self.end_frame_input.text(): - payload['end_frame'] = self.end_frame_input.text() - - if self.duration_input.text(): - payload['duration_sec'] = self.duration_input.text() - - payload['start_generation'] = self.autostart_checkbox.isChecked() - - self.status_updated.emit("AI Joiner: Sending parameters...") - try: - response = requests.post(f"{API_BASE_URL}/api/generate", json=payload) - if response.status_code == 200: - if payload['start_generation']: - self.status_updated.emit("AI Joiner: Generation sent to server. Polling for output...") - else: - self.status_updated.emit("AI Joiner: Parameters set. Polling for manually started generation...") - self.start_polling() - else: - self.handle_api_error(response, "setting parameters") - except requests.exceptions.RequestException as e: - self.status_updated.emit(f"AI Joiner Error: Connection error: {e}") - - def poll_for_output(self): - try: - response = requests.get(f"{API_BASE_URL}/api/outputs") - if response.status_code == 200: - data = response.json() - output_files = data.get("outputs", []) - - new_files = set(output_files) - self.processed_files - if new_files: - self.status_updated.emit(f"AI Joiner: Received {len(new_files)} new clip(s).") - for file_path in sorted(list(new_files)): - self.add_result_item(file_path) - self.processed_files.add(file_path) - else: - self.status_updated.emit("AI Joiner: Polling for output...") - else: - self.status_updated.emit("AI Joiner: Polling for output...") - except requests.exceptions.RequestException: - self.status_updated.emit("AI Joiner: Polling... 
(Connection issue)") - - def add_result_item(self, video_path): - item_widget = VideoResultItemWidget(video_path, self.plugin) - list_item = QListWidgetItem(self.results_list) - list_item.setSizeHint(item_widget.sizeHint()) - self.results_list.addItem(list_item) - self.results_list.setItemWidget(list_item, item_widget) - - def clear_results(self): - self.results_list.clear() - self.processed_files.clear() - -class Plugin(VideoEditorPlugin): - def initialize(self): - self.name = "AI Frame Joiner" - self.description = "Uses a local AI server to generate a video between two frames." - self.client_widget = WgpClientWidget(self) - self.dock_widget = None - self.active_region = None - self.temp_dir = None - self.insert_on_new_track = False - self.client_widget.status_updated.connect(self.update_main_status) - - def enable(self): - if not self.dock_widget: - self.dock_widget = self.app.add_dock_widget(self, self.client_widget, "AI Frame Joiner", show_on_creation=False) - - self.dock_widget.hide() - - self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu) - - def disable(self): - try: - self.app.timeline_widget.context_menu_requested.disconnect(self.on_timeline_context_menu) - except TypeError: - pass - self._cleanup_temp_dir() - self.client_widget.stop_polling() - - def update_main_status(self, message): - self.app.status_label.setText(message) - - def _cleanup_temp_dir(self): - if self.temp_dir and os.path.exists(self.temp_dir): - shutil.rmtree(self.temp_dir) - self.temp_dir = None - - def _reset_state(self): - self.active_region = None - self.insert_on_new_track = False - self.client_widget.stop_polling() - self.client_widget.clear_results() - self._cleanup_temp_dir() - self.update_main_status("AI Joiner: Idle") - self.client_widget.set_previews(None, None) - self.client_widget.model_type_to_use = None - - def on_timeline_context_menu(self, menu, event): - region = self.app.timeline_widget.get_region_at_pos(event.pos()) - if region: - menu.addSeparator() - - start_sec, end_sec = region - start_data, _, _ = self.app.get_frame_data_at_time(start_sec) - end_data, _, _ = self.app.get_frame_data_at_time(end_sec) - - if start_data and end_data: - join_action = menu.addAction("Join Frames With AI") - join_action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False)) - - join_action_new_track = menu.addAction("Join Frames With AI (New Track)") - join_action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True)) - - create_action = menu.addAction("Create Frames With AI") - create_action.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=False)) - - create_action_new_track = menu.addAction("Create Frames With AI (New Track)") - create_action_new_track.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=True)) - - def setup_generator_for_region(self, region, on_new_track=False): - self._reset_state() - self.client_widget.model_type_to_use = "i2v_2_2" - self.active_region = region - self.insert_on_new_track = on_new_track - self.client_widget.previews_widget.setVisible(True) - - start_sec, end_sec = region - if not self.client_widget.check_server_status(): - return - - start_data, w, h = self.app.get_frame_data_at_time(start_sec) - end_data, _, _ = self.app.get_frame_data_at_time(end_sec) - - if not start_data or not end_data: - QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames for the selected region.") - return - 
- preview_size = QSize(160, 90) - start_pixmap, end_pixmap = None, None - - try: - start_img = QImage(start_data, w, h, QImage.Format.Format_RGB888) - start_pixmap = QPixmap.fromImage(start_img).scaled(preview_size, Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) - - end_img = QImage(end_data, w, h, QImage.Format.Format_RGB888) - end_pixmap = QPixmap.fromImage(end_img).scaled(preview_size, Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) - except Exception as e: - QMessageBox.critical(self.app, "Image Error", f"Could not create preview images: {e}") - - self.client_widget.set_previews(start_pixmap, end_pixmap) - - try: - self.temp_dir = tempfile.mkdtemp(prefix="ai_joiner_") - start_img_path = os.path.join(self.temp_dir, "start_frame.png") - end_img_path = os.path.join(self.temp_dir, "end_frame.png") - - QImage(start_data, w, h, QImage.Format.Format_RGB888).save(start_img_path) - QImage(end_data, w, h, QImage.Format.Format_RGB888).save(end_img_path) - except Exception as e: - QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame images: {e}") - self._cleanup_temp_dir() - return - - duration_sec = end_sec - start_sec - self.client_widget.duration_input.setText(str(duration_sec)) - - self.client_widget.start_frame_input.setText(start_img_path) - self.client_widget.end_frame_input.setText(end_img_path) - self.update_main_status(f"AI Joiner: Ready for region {start_sec:.2f}s - {end_sec:.2f}s") - - self.dock_widget.show() - self.dock_widget.raise_() - - def setup_creator_for_region(self, region, on_new_track=False): - self._reset_state() - self.client_widget.model_type_to_use = "t2v_2_2" - self.active_region = region - self.insert_on_new_track = on_new_track - self.client_widget.previews_widget.setVisible(False) - - start_sec, end_sec = region - if not self.client_widget.check_server_status(): - return - - self.client_widget.set_previews(None, None) - - duration_sec = end_sec - start_sec - self.client_widget.duration_input.setText(str(duration_sec)) - - self.client_widget.start_frame_input.clear() - self.client_widget.end_frame_input.clear() - - self.update_main_status(f"AI Creator: Ready for region {start_sec:.2f}s - {end_sec:.2f}s") - - self.dock_widget.show() - self.dock_widget.raise_() - - def insert_generated_clip(self, video_path): - from main import TimelineClip - - if not self.active_region: - self.update_main_status("AI Joiner Error: No active region to insert into.") - return - - if not os.path.exists(video_path): - self.update_main_status(f"AI Joiner Error: Output file not found: {video_path}") - return - - start_sec, end_sec = self.active_region - is_new_track_mode = self.insert_on_new_track - - self.update_main_status(f"AI Joiner: Inserting clip {os.path.basename(video_path)}") - - def complex_insertion_action(): - probe = ffmpeg.probe(video_path) - actual_duration = float(probe['format']['duration']) - - if is_new_track_mode: - self.app.timeline.num_video_tracks += 1 - new_track_index = self.app.timeline.num_video_tracks - - new_clip = TimelineClip( - source_path=video_path, - timeline_start_sec=start_sec, - clip_start_sec=0, - duration_sec=actual_duration, - track_index=new_track_index, - track_type='video', - media_type='video', - group_id=str(uuid.uuid4()) - ) - self.app.timeline.add_clip(new_clip) - else: - for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec) - for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec) - - clips_to_remove 
= [ - c for c in self.app.timeline.clips - if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec - ] - for clip in clips_to_remove: - if clip in self.app.timeline.clips: - self.app.timeline.clips.remove(clip) - - new_clip = TimelineClip( - source_path=video_path, - timeline_start_sec=start_sec, - clip_start_sec=0, - duration_sec=actual_duration, - track_index=1, - track_type='video', - media_type='video', - group_id=str(uuid.uuid4()) - ) - self.app.timeline.add_clip(new_clip) - - try: - self.app._perform_complex_timeline_change("Insert AI Clip", complex_insertion_action) - - self.app.prune_empty_tracks() - self.update_main_status("AI clip inserted successfully.") - - for i in range(self.client_widget.results_list.count()): - item = self.client_widget.results_list.item(i) - widget = self.client_widget.results_list.itemWidget(item) - if widget and widget.video_path == video_path: - self.client_widget.results_list.takeItem(i) - break - - except Exception as e: - error_message = f"AI Joiner Error during clip insertion/probing: {e}" - self.update_main_status(error_message) - print(error_message) \ No newline at end of file From 4f9c03bf4d1546bfc0837f43bdf2dd4e321931e6 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 21:19:09 +1100 Subject: [PATCH 119/155] update screenshots --- README.md | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 3726bd212..8f30a9927 100644 --- a/README.md +++ b/README.md @@ -42,10 +42,9 @@ python videoeditor.py ``` ## Screenshots -image -image -image -image +image +image +image ## Credits The AI Video Generator plugin is built from a desktop port of WAN2GP by DeepBeepMeep. From 1622655827b1f904ccaf104807013a68def53739 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 21:27:19 +1100 Subject: [PATCH 120/155] update install instructions --- README.md | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8f30a9927..15764d5d3 100644 --- a/README.md +++ b/README.md @@ -24,16 +24,32 @@ A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. It provi ## Installation + +#### **Windows** + ```bash git clone https://github.com/Tophness/Wan2GP.git cd Wan2GP git checkout video_editor -conda create -n wan2gp python=3.10.9 -conda activate wan2gp +python -m venv venv +venv\Scripts\activate +pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 +pip install -r requirements.txt +``` + +#### **Linux / macOS** + +``` +git clone https://github.com/Tophness/Wan2GP.git +cd Wan2GP +git checkout video_editor +python3 -m venv venv +source venv/bin/activate pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 pip install -r requirements.txt ``` + ## Usage **Run the video editor:** From 4aded7a158fab544fc66f607247e3211d468c333 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 22:15:37 +1100 Subject: [PATCH 121/155] Update README.md --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 15764d5d3..712fbfba4 100644 --- a/README.md +++ b/README.md @@ -59,10 +59,11 @@ python videoeditor.py ## Screenshots image -image -image +image +image + ## Credits The AI Video Generator plugin is built from a desktop port of WAN2GP by DeepBeepMeep. See WAN2GP for more details. 
-https://github.com/deepbeepmeep/Wan2GP \ No newline at end of file +https://github.com/deepbeepmeep/Wan2GP From 8c9286cbc4b80ae9a167b1317ff30e69a339d6f4 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 22:19:14 +1100 Subject: [PATCH 122/155] Update README.md --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 712fbfba4..1fc0c3473 100644 --- a/README.md +++ b/README.md @@ -59,8 +59,9 @@ python videoeditor.py ## Screenshots image -image -image +image +image + ## Credits From f7d62b539a9d0aa6f69dfc81ab72dbc246064647 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 22:22:45 +1100 Subject: [PATCH 123/155] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1fc0c3473..eff58e558 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ python videoeditor.py ``` ## Screenshots -image +image image image From 0953b56c0fc26a5aa737054139b517f2f591c568 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 23:31:21 +1100 Subject: [PATCH 124/155] Add create to new track option back --- plugins/wan2gp/main.py | 3417 +++++++++++++++++++++------------------- 1 file changed, 1764 insertions(+), 1653 deletions(-) diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py index ca201dcda..0b925227c 100644 --- a/plugins/wan2gp/main.py +++ b/plugins/wan2gp/main.py @@ -1,1653 +1,1764 @@ -import sys -import os -import threading -import time -import json -import tempfile -import shutil -import uuid -from unittest.mock import MagicMock -from pathlib import Path - -# --- Start of Gradio Hijacking --- -# This block creates a mock Gradio module. When wgp.py is imported, -# all calls to `gr.*` will be intercepted by these mock objects, -# preventing any UI from being built and allowing us to use the -# backend logic directly. 
- -class MockGradioComponent(MagicMock): - """A smarter mock that captures constructor arguments.""" - def __init__(self, *args, **kwargs): - super().__init__(name=f"gr.{kwargs.get('elem_id', 'component')}") - self.kwargs = kwargs - self.value = kwargs.get('value') - self.choices = kwargs.get('choices') - - for method in ['then', 'change', 'click', 'input', 'select', 'upload', 'mount', 'launch', 'on', 'release']: - setattr(self, method, lambda *a, **kw: self) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - -class MockGradioError(Exception): - pass - -class MockGradioModule: - def __getattr__(self, name): - if name == 'Error': - return lambda *args, **kwargs: MockGradioError(*args) - - if name in ['Info', 'Warning']: - return lambda *args, **kwargs: print(f"Intercepted gr.{name}:", *args) - - return lambda *args, **kwargs: MockGradioComponent(*args, **kwargs) - -sys.modules['gradio'] = MockGradioModule() -sys.modules['gradio.gallery'] = MockGradioModule() -sys.modules['shared.gradio.gallery'] = MockGradioModule() -# --- End of Gradio Hijacking --- - -import ffmpeg - -from PyQt6.QtWidgets import ( - QWidget, QVBoxLayout, QHBoxLayout, QTabWidget, - QPushButton, QLabel, QLineEdit, QTextEdit, QSlider, QCheckBox, QComboBox, - QFileDialog, QGroupBox, QFormLayout, QTableWidget, QTableWidgetItem, - QHeaderView, QProgressBar, QScrollArea, QListWidget, QListWidgetItem, - QMessageBox, QRadioButton, QSizePolicy -) -from PyQt6.QtCore import Qt, QThread, QObject, pyqtSignal, QUrl, QSize, QRectF -from PyQt6.QtGui import QPixmap, QImage, QDropEvent -from PyQt6.QtMultimedia import QMediaPlayer -from PyQt6.QtMultimediaWidgets import QVideoWidget -from PIL.ImageQt import ImageQt - -# Import the base plugin class from the main application's path -sys.path.append(str(Path(__file__).parent.parent.parent)) -from plugins import VideoEditorPlugin - -class VideoResultItemWidget(QWidget): - """A widget to display a generated video with a hover-to-play preview and insert button.""" - def __init__(self, video_path, plugin, parent=None): - super().__init__(parent) - self.video_path = video_path - self.plugin = plugin - self.app = plugin.app - self.duration = 0.0 - self.has_audio = False - - self.setMinimumSize(200, 180) - self.setMaximumHeight(190) - - layout = QVBoxLayout(self) - layout.setContentsMargins(5, 5, 5, 5) - layout.setSpacing(5) - - self.media_player = QMediaPlayer() - self.video_widget = QVideoWidget() - self.video_widget.setFixedSize(160, 90) - self.media_player.setVideoOutput(self.video_widget) - self.media_player.setSource(QUrl.fromLocalFile(self.video_path)) - self.media_player.setLoops(QMediaPlayer.Loops.Infinite) - - self.info_label = QLabel(os.path.basename(video_path)) - self.info_label.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.info_label.setWordWrap(True) - - self.insert_button = QPushButton("Insert into Timeline") - self.insert_button.clicked.connect(self.on_insert) - - h_layout = QHBoxLayout() - h_layout.addStretch() - h_layout.addWidget(self.video_widget) - h_layout.addStretch() - - layout.addLayout(h_layout) - layout.addWidget(self.info_label) - layout.addWidget(self.insert_button) - self.probe_video() - - def probe_video(self): - try: - probe = ffmpeg.probe(self.video_path) - self.duration = float(probe['format']['duration']) - self.has_audio = any(s['codec_type'] == 'audio' for s in probe.get('streams', [])) - self.info_label.setText(f"{os.path.basename(self.video_path)}\n({self.duration:.2f}s)") - except Exception as e: - 
self.info_label.setText(f"Error probing:\n{os.path.basename(self.video_path)}") - print(f"Error probing video {self.video_path}: {e}") - - def enterEvent(self, event): - super().enterEvent(event) - self.media_player.play() - if not self.plugin.active_region or self.duration == 0: return - start_sec, _ = self.plugin.active_region - timeline = self.app.timeline_widget - video_rect, audio_rect = None, None - x = timeline.sec_to_x(start_sec) - w = int(self.duration * timeline.pixels_per_second) - if self.plugin.insert_on_new_track: - video_y = timeline.TIMESCALE_HEIGHT - video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT) - if self.has_audio: - audio_y = timeline.audio_tracks_y_start + self.app.timeline.num_audio_tracks * timeline.TRACK_HEIGHT - audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT) - else: - v_track_idx = 1 - visual_v_idx = self.app.timeline.num_video_tracks - v_track_idx - video_y = timeline.video_tracks_y_start + visual_v_idx * timeline.TRACK_HEIGHT - video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT) - if self.has_audio: - a_track_idx = 1 - visual_a_idx = a_track_idx - 1 - audio_y = timeline.audio_tracks_y_start + visual_a_idx * timeline.TRACK_HEIGHT - audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT) - timeline.set_hover_preview_rects(video_rect, audio_rect) - - def leaveEvent(self, event): - super().leaveEvent(event) - self.media_player.pause() - self.media_player.setPosition(0) - self.app.timeline_widget.set_hover_preview_rects(None, None) - - def on_insert(self): - self.media_player.stop() - self.media_player.setSource(QUrl()) - self.media_player.setVideoOutput(None) - self.app.timeline_widget.set_hover_preview_rects(None, None) - self.plugin.insert_generated_clip(self.video_path) - - -class QueueTableWidget(QTableWidget): - rowsMoved = pyqtSignal(int, int) - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setDragEnabled(True) - self.setAcceptDrops(True) - self.setDropIndicatorShown(True) - self.setDragDropMode(self.DragDropMode.InternalMove) - self.setSelectionBehavior(self.SelectionBehavior.SelectRows) - self.setSelectionMode(self.SelectionMode.SingleSelection) - - def dropEvent(self, event: QDropEvent): - if event.source() == self and event.dropAction() == Qt.DropAction.MoveAction: - source_row = self.currentRow() - target_item = self.itemAt(event.position().toPoint()) - dest_row = target_item.row() if target_item else self.rowCount() - if source_row < dest_row: dest_row -=1 - if source_row != dest_row: self.rowsMoved.emit(source_row, dest_row) - event.acceptProposedAction() - else: - super().dropEvent(event) - - -class Worker(QObject): - progress = pyqtSignal(list) - status = pyqtSignal(str) - preview = pyqtSignal(object) - output = pyqtSignal() - finished = pyqtSignal() - error = pyqtSignal(str) - def __init__(self, plugin, state): - super().__init__() - self.plugin = plugin - self.wgp = plugin.wgp - self.state = state - self._is_running = True - self._last_progress_phase = None - self._last_preview = None - - def run(self): - def generation_target(): - try: - for _ in self.wgp.process_tasks(self.state): - if self._is_running: self.output.emit() - else: break - except Exception as e: - import traceback - print("Error in generation thread:") - traceback.print_exc() - if "gradio.Error" in str(type(e)): self.error.emit(str(e)) - else: self.error.emit(f"An unexpected error occurred: {e}") - finally: - self._is_running = False - gen_thread = threading.Thread(target=generation_target, daemon=True) - 
gen_thread.start() - while self._is_running: - gen = self.state.get('gen', {}) - current_phase = gen.get("progress_phase") - if current_phase and current_phase != self._last_progress_phase: - self._last_progress_phase = current_phase - phase_name, step = current_phase - total_steps = gen.get("num_inference_steps", 1) - high_level_status = gen.get("progress_status", "") - status_msg = self.wgp.merge_status_context(high_level_status, phase_name) - progress_args = [(step, total_steps), status_msg] - self.progress.emit(progress_args) - preview_img = gen.get('preview') - if preview_img is not None and preview_img is not self._last_preview: - self._last_preview = preview_img - self.preview.emit(preview_img) - gen['preview'] = None - time.sleep(0.1) - gen_thread.join() - self.finished.emit() - - -class WgpDesktopPluginWidget(QWidget): - def __init__(self, plugin): - super().__init__() - self.plugin = plugin - self.wgp = plugin.wgp - self.widgets = {} - self.state = {} - self.worker = None - self.thread = None - self.lora_map = {} - self.full_resolution_choices = [] - self.main_config = {} - self.processed_files = set() - - self.load_main_config() - self.setup_ui() - self.apply_initial_config() - self.connect_signals() - self.init_wgp_state() - - def setup_ui(self): - main_layout = QVBoxLayout(self) - self.header_info = QLabel("Header Info") - main_layout.addWidget(self.header_info) - self.tabs = QTabWidget() - main_layout.addWidget(self.tabs) - self.setup_generator_tab() - self.setup_config_tab() - - def create_widget(self, widget_class, name, *args, **kwargs): - widget = widget_class(*args, **kwargs) - self.widgets[name] = widget - return widget - - def _create_slider_with_label(self, name, min_val, max_val, initial_val, scale=1.0, precision=1): - container = QWidget() - hbox = QHBoxLayout(container) - hbox.setContentsMargins(0, 0, 0, 0) - slider = self.create_widget(QSlider, name, Qt.Orientation.Horizontal) - slider.setRange(min_val, max_val) - slider.setValue(int(initial_val * scale)) - value_label = self.create_widget(QLabel, f"{name}_label", f"{initial_val:.{precision}f}") - value_label.setMinimumWidth(50) - slider.valueChanged.connect(lambda v, lbl=value_label, s=scale, p=precision: lbl.setText(f"{v/s:.{p}f}")) - hbox.addWidget(slider) - hbox.addWidget(value_label) - return container - - def _create_file_input(self, name, label_text): - container = self.create_widget(QWidget, f"{name}_container") - hbox = QHBoxLayout(container) - hbox.setContentsMargins(0, 0, 0, 0) - line_edit = self.create_widget(QLineEdit, name) - line_edit.setPlaceholderText("No file selected or path pasted") - button = QPushButton("Browse...") - def open_dialog(): - if "refs" in name: - filenames, _ = QFileDialog.getOpenFileNames(self, f"Select {label_text}") - if filenames: line_edit.setText(";".join(filenames)) - else: - filename, _ = QFileDialog.getOpenFileName(self, f"Select {label_text}") - if filename: line_edit.setText(filename) - button.clicked.connect(open_dialog) - clear_button = QPushButton("X") - clear_button.setFixedWidth(30) - clear_button.clicked.connect(lambda: line_edit.clear()) - hbox.addWidget(QLabel(f"{label_text}:")) - hbox.addWidget(line_edit, 1) - hbox.addWidget(button) - hbox.addWidget(clear_button) - return container - - def setup_generator_tab(self): - gen_tab = QWidget() - self.tabs.addTab(gen_tab, "Video Generator") - gen_layout = QHBoxLayout(gen_tab) - left_panel = QWidget() - left_layout = QVBoxLayout(left_panel) - gen_layout.addWidget(left_panel, 1) - right_panel = QWidget() - right_layout 
= QVBoxLayout(right_panel) - gen_layout.addWidget(right_panel, 1) - scroll_area = QScrollArea() - scroll_area.setWidgetResizable(True) - left_layout.addWidget(scroll_area) - options_widget = QWidget() - scroll_area.setWidget(options_widget) - options_layout = QVBoxLayout(options_widget) - model_layout = QHBoxLayout() - self.widgets['model_family'] = QComboBox() - self.widgets['model_base_type_choice'] = QComboBox() - self.widgets['model_choice'] = QComboBox() - model_layout.addWidget(QLabel("Model:")) - model_layout.addWidget(self.widgets['model_family'], 2) - model_layout.addWidget(self.widgets['model_base_type_choice'], 3) - model_layout.addWidget(self.widgets['model_choice'], 3) - options_layout.addLayout(model_layout) - options_layout.addWidget(QLabel("Prompt:")) - self.create_widget(QTextEdit, 'prompt').setMinimumHeight(100) - options_layout.addWidget(self.widgets['prompt']) - options_layout.addWidget(QLabel("Negative Prompt:")) - self.create_widget(QTextEdit, 'negative_prompt').setMinimumHeight(60) - options_layout.addWidget(self.widgets['negative_prompt']) - basic_group = QGroupBox("Basic Options") - basic_layout = QFormLayout(basic_group) - res_container = QWidget() - res_hbox = QHBoxLayout(res_container) - res_hbox.setContentsMargins(0, 0, 0, 0) - res_hbox.addWidget(self.create_widget(QComboBox, 'resolution_group'), 2) - res_hbox.addWidget(self.create_widget(QComboBox, 'resolution'), 3) - basic_layout.addRow("Resolution:", res_container) - basic_layout.addRow("Video Length:", self._create_slider_with_label('video_length', 1, 737, 81, 1.0, 0)) - basic_layout.addRow("Inference Steps:", self._create_slider_with_label('num_inference_steps', 1, 100, 30, 1.0, 0)) - basic_layout.addRow("Seed:", self.create_widget(QLineEdit, 'seed', '-1')) - options_layout.addWidget(basic_group) - mode_options_group = QGroupBox("Generation Mode & Input Options") - mode_options_layout = QVBoxLayout(mode_options_group) - mode_hbox = QHBoxLayout() - mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_t', "Text Prompt Only")) - mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_s', "Start with Image")) - mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_v', "Continue Video")) - mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_l', "Continue Last Video")) - self.widgets['mode_t'].setChecked(True) - mode_options_layout.addLayout(mode_hbox) - options_hbox = QHBoxLayout() - options_hbox.addWidget(self.create_widget(QCheckBox, 'image_end_checkbox', "Use End Image")) - options_hbox.addWidget(self.create_widget(QCheckBox, 'control_video_checkbox', "Use Control Video")) - options_hbox.addWidget(self.create_widget(QCheckBox, 'ref_image_checkbox', "Use Reference Image(s)")) - mode_options_layout.addLayout(options_hbox) - options_layout.addWidget(mode_options_group) - inputs_group = QGroupBox("Inputs") - inputs_layout = QVBoxLayout(inputs_group) - inputs_layout.addWidget(self._create_file_input('image_start', "Start Image")) - inputs_layout.addWidget(self._create_file_input('image_end', "End Image")) - inputs_layout.addWidget(self._create_file_input('video_source', "Source Video")) - inputs_layout.addWidget(self._create_file_input('video_guide', "Control Video")) - inputs_layout.addWidget(self._create_file_input('video_mask', "Video Mask")) - inputs_layout.addWidget(self._create_file_input('image_refs', "Reference Image(s)")) - denoising_row = QFormLayout() - denoising_row.addRow("Denoising Strength:", self._create_slider_with_label('denoising_strength', 0, 100, 50, 100.0, 2)) - 
inputs_layout.addLayout(denoising_row) - options_layout.addWidget(inputs_group) - self.advanced_group = self.create_widget(QGroupBox, 'advanced_group', "Advanced Options") - self.advanced_group.setCheckable(True) - self.advanced_group.setChecked(False) - advanced_layout = QVBoxLayout(self.advanced_group) - advanced_tabs = self.create_widget(QTabWidget, 'advanced_tabs') - advanced_layout.addWidget(advanced_tabs) - self._setup_adv_tab_general(advanced_tabs) - self._setup_adv_tab_loras(advanced_tabs) - self._setup_adv_tab_speed(advanced_tabs) - self._setup_adv_tab_postproc(advanced_tabs) - self._setup_adv_tab_audio(advanced_tabs) - self._setup_adv_tab_quality(advanced_tabs) - self._setup_adv_tab_sliding_window(advanced_tabs) - self._setup_adv_tab_misc(advanced_tabs) - options_layout.addWidget(self.advanced_group) - - btn_layout = QHBoxLayout() - self.generate_btn = self.create_widget(QPushButton, 'generate_btn', "Generate") - self.add_to_queue_btn = self.create_widget(QPushButton, 'add_to_queue_btn', "Add to Queue") - self.generate_btn.setEnabled(True) - self.add_to_queue_btn.setEnabled(False) - btn_layout.addWidget(self.generate_btn) - btn_layout.addWidget(self.add_to_queue_btn) - right_layout.addLayout(btn_layout) - self.status_label = self.create_widget(QLabel, 'status_label', "Idle") - right_layout.addWidget(self.status_label) - self.progress_bar = self.create_widget(QProgressBar, 'progress_bar') - right_layout.addWidget(self.progress_bar) - preview_group = self.create_widget(QGroupBox, 'preview_group', "Preview") - preview_group.setCheckable(True) - preview_group.setStyleSheet("QGroupBox { border: 1px solid #cccccc; }") - preview_group_layout = QVBoxLayout(preview_group) - self.preview_image = self.create_widget(QLabel, 'preview_image', "") - self.preview_image.setAlignment(Qt.AlignmentFlag.AlignCenter) - self.preview_image.setMinimumSize(200, 200) - preview_group_layout.addWidget(self.preview_image) - right_layout.addWidget(preview_group) - - results_group = QGroupBox("Generated Clips") - results_group.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) - results_layout = QVBoxLayout(results_group) - self.results_list = self.create_widget(QListWidget, 'results_list') - self.results_list.setFlow(QListWidget.Flow.LeftToRight) - self.results_list.setWrapping(True) - self.results_list.setResizeMode(QListWidget.ResizeMode.Adjust) - self.results_list.setSpacing(10) - results_layout.addWidget(self.results_list) - right_layout.addWidget(results_group) - - right_layout.addWidget(QLabel("Queue:")) - self.queue_table = self.create_widget(QueueTableWidget, 'queue_table') - right_layout.addWidget(self.queue_table) - queue_btn_layout = QHBoxLayout() - self.remove_queue_btn = self.create_widget(QPushButton, 'remove_queue_btn', "Remove Selected") - self.clear_queue_btn = self.create_widget(QPushButton, 'clear_queue_btn', "Clear Queue") - self.abort_btn = self.create_widget(QPushButton, 'abort_btn', "Abort") - queue_btn_layout.addWidget(self.remove_queue_btn) - queue_btn_layout.addWidget(self.clear_queue_btn) - queue_btn_layout.addWidget(self.abort_btn) - right_layout.addLayout(queue_btn_layout) - - def _setup_adv_tab_general(self, tabs): - tab = QWidget() - tabs.addTab(tab, "General") - layout = QFormLayout(tab) - self.widgets['adv_general_layout'] = layout - guidance_group = QGroupBox("Guidance") - guidance_layout = self.create_widget(QFormLayout, 'guidance_layout', guidance_group) - guidance_layout.addRow("Guidance (CFG):", self._create_slider_with_label('guidance_scale', 10, 
200, 5.0, 10.0, 1)) - self.widgets['guidance_phases_row_index'] = guidance_layout.rowCount() - guidance_layout.addRow("Guidance Phases:", self.create_widget(QComboBox, 'guidance_phases')) - self.widgets['guidance2_row_index'] = guidance_layout.rowCount() - guidance_layout.addRow("Guidance 2:", self._create_slider_with_label('guidance2_scale', 10, 200, 5.0, 10.0, 1)) - self.widgets['guidance3_row_index'] = guidance_layout.rowCount() - guidance_layout.addRow("Guidance 3:", self._create_slider_with_label('guidance3_scale', 10, 200, 5.0, 10.0, 1)) - self.widgets['switch_thresh_row_index'] = guidance_layout.rowCount() - guidance_layout.addRow("Switch Threshold:", self._create_slider_with_label('switch_threshold', 0, 1000, 0, 1.0, 0)) - layout.addRow(guidance_group) - nag_group = self.create_widget(QGroupBox, 'nag_group', "NAG (Negative Adversarial Guidance)") - nag_layout = QFormLayout(nag_group) - nag_layout.addRow("NAG Scale:", self._create_slider_with_label('NAG_scale', 10, 200, 1.0, 10.0, 1)) - nag_layout.addRow("NAG Tau:", self._create_slider_with_label('NAG_tau', 10, 50, 3.5, 10.0, 1)) - nag_layout.addRow("NAG Alpha:", self._create_slider_with_label('NAG_alpha', 0, 20, 0.5, 10.0, 1)) - layout.addRow(nag_group) - self.widgets['solver_row_container'] = QWidget() - solver_hbox = QHBoxLayout(self.widgets['solver_row_container']) - solver_hbox.setContentsMargins(0,0,0,0) - solver_hbox.addWidget(QLabel("Sampler Solver:")) - solver_hbox.addWidget(self.create_widget(QComboBox, 'sample_solver')) - layout.addRow(self.widgets['solver_row_container']) - self.widgets['flow_shift_row_index'] = layout.rowCount() - layout.addRow("Shift Scale:", self._create_slider_with_label('flow_shift', 10, 250, 3.0, 10.0, 1)) - self.widgets['audio_guidance_row_index'] = layout.rowCount() - layout.addRow("Audio Guidance:", self._create_slider_with_label('audio_guidance_scale', 10, 200, 4.0, 10.0, 1)) - self.widgets['repeat_generation_row_index'] = layout.rowCount() - layout.addRow("Repeat Generations:", self._create_slider_with_label('repeat_generation', 1, 25, 1, 1.0, 0)) - combo = self.create_widget(QComboBox, 'multi_images_gen_type') - combo.addItem("Generate all combinations", 0) - combo.addItem("Match images and texts", 1) - self.widgets['multi_images_gen_type_row_index'] = layout.rowCount() - layout.addRow("Multi-Image Mode:", combo) - - def _setup_adv_tab_loras(self, tabs): - tab = QWidget() - tabs.addTab(tab, "Loras") - layout = QVBoxLayout(tab) - layout.addWidget(QLabel("Available Loras (Ctrl+Click to select multiple):")) - lora_list = self.create_widget(QListWidget, 'activated_loras') - lora_list.setSelectionMode(QListWidget.SelectionMode.MultiSelection) - layout.addWidget(lora_list) - layout.addWidget(QLabel("Loras Multipliers:")) - layout.addWidget(self.create_widget(QTextEdit, 'loras_multipliers')) - - def _setup_adv_tab_speed(self, tabs): - tab = QWidget() - tabs.addTab(tab, "Speed") - layout = QFormLayout(tab) - combo = self.create_widget(QComboBox, 'skip_steps_cache_type') - combo.addItem("None", "") - combo.addItem("Tea Cache", "tea") - combo.addItem("Mag Cache", "mag") - layout.addRow("Cache Type:", combo) - combo = self.create_widget(QComboBox, 'skip_steps_multiplier') - combo.addItem("x1.5 speed up", 1.5) - combo.addItem("x1.75 speed up", 1.75) - combo.addItem("x2.0 speed up", 2.0) - combo.addItem("x2.25 speed up", 2.25) - combo.addItem("x2.5 speed up", 2.5) - layout.addRow("Acceleration:", combo) - layout.addRow("Start %:", self._create_slider_with_label('skip_steps_start_step_perc', 0, 100, 0, 
1.0, 0)) - - def _setup_adv_tab_postproc(self, tabs): - tab = QWidget() - tabs.addTab(tab, "Post-Processing") - layout = QFormLayout(tab) - combo = self.create_widget(QComboBox, 'temporal_upsampling') - combo.addItem("Disabled", "") - combo.addItem("Rife x2 frames/s", "rife2") - combo.addItem("Rife x4 frames/s", "rife4") - layout.addRow("Temporal Upsampling:", combo) - combo = self.create_widget(QComboBox, 'spatial_upsampling') - combo.addItem("Disabled", "") - combo.addItem("Lanczos x1.5", "lanczos1.5") - combo.addItem("Lanczos x2.0", "lanczos2") - layout.addRow("Spatial Upsampling:", combo) - layout.addRow("Film Grain Intensity:", self._create_slider_with_label('film_grain_intensity', 0, 100, 0, 100.0, 2)) - layout.addRow("Film Grain Saturation:", self._create_slider_with_label('film_grain_saturation', 0, 100, 0.5, 100.0, 2)) - - def _setup_adv_tab_audio(self, tabs): - tab = QWidget() - tabs.addTab(tab, "Audio") - layout = QFormLayout(tab) - combo = self.create_widget(QComboBox, 'MMAudio_setting') - combo.addItem("Disabled", 0) - combo.addItem("Enabled", 1) - layout.addRow("MMAudio:", combo) - layout.addWidget(self.create_widget(QLineEdit, 'MMAudio_prompt', placeholderText="MMAudio Prompt")) - layout.addWidget(self.create_widget(QLineEdit, 'MMAudio_neg_prompt', placeholderText="MMAudio Negative Prompt")) - layout.addRow(self._create_file_input('audio_source', "Custom Soundtrack")) - - def _setup_adv_tab_quality(self, tabs): - tab = QWidget() - tabs.addTab(tab, "Quality") - layout = QVBoxLayout(tab) - slg_group = self.create_widget(QGroupBox, 'slg_group', "Skip Layer Guidance") - slg_layout = QFormLayout(slg_group) - slg_combo = self.create_widget(QComboBox, 'slg_switch') - slg_combo.addItem("OFF", 0) - slg_combo.addItem("ON", 1) - slg_layout.addRow("Enable SLG:", slg_combo) - slg_layout.addRow("Start %:", self._create_slider_with_label('slg_start_perc', 0, 100, 10, 1.0, 0)) - slg_layout.addRow("End %:", self._create_slider_with_label('slg_end_perc', 0, 100, 90, 1.0, 0)) - layout.addWidget(slg_group) - quality_form = QFormLayout() - self.widgets['quality_form_layout'] = quality_form - apg_combo = self.create_widget(QComboBox, 'apg_switch') - apg_combo.addItem("OFF", 0) - apg_combo.addItem("ON", 1) - self.widgets['apg_switch_row_index'] = quality_form.rowCount() - quality_form.addRow("Adaptive Projected Guidance:", apg_combo) - cfg_star_combo = self.create_widget(QComboBox, 'cfg_star_switch') - cfg_star_combo.addItem("OFF", 0) - cfg_star_combo.addItem("ON", 1) - self.widgets['cfg_star_switch_row_index'] = quality_form.rowCount() - quality_form.addRow("Classifier-Free Guidance Star:", cfg_star_combo) - self.widgets['cfg_zero_step_row_index'] = quality_form.rowCount() - quality_form.addRow("CFG Zero below Layer:", self._create_slider_with_label('cfg_zero_step', -1, 39, -1, 1.0, 0)) - combo = self.create_widget(QComboBox, 'min_frames_if_references') - combo.addItem("Disabled (1 frame)", 1) - combo.addItem("Generate 5 frames", 5) - combo.addItem("Generate 9 frames", 9) - combo.addItem("Generate 13 frames", 13) - combo.addItem("Generate 17 frames", 17) - self.widgets['min_frames_if_references_row_index'] = quality_form.rowCount() - quality_form.addRow("Min Frames for Quality:", combo) - layout.addLayout(quality_form) - - def _setup_adv_tab_sliding_window(self, tabs): - tab = QWidget() - self.widgets['sliding_window_tab_index'] = tabs.count() - tabs.addTab(tab, "Sliding Window") - layout = QFormLayout(tab) - layout.addRow("Window Size:", self._create_slider_with_label('sliding_window_size', 5, 
257, 129, 1.0, 0)) - layout.addRow("Overlap:", self._create_slider_with_label('sliding_window_overlap', 1, 97, 5, 1.0, 0)) - layout.addRow("Color Correction:", self._create_slider_with_label('sliding_window_color_correction_strength', 0, 100, 0, 100.0, 2)) - layout.addRow("Overlap Noise:", self._create_slider_with_label('sliding_window_overlap_noise', 0, 150, 20, 1.0, 0)) - layout.addRow("Discard Last Frames:", self._create_slider_with_label('sliding_window_discard_last_frames', 0, 20, 0, 1.0, 0)) - - def _setup_adv_tab_misc(self, tabs): - tab = QWidget() - tabs.addTab(tab, "Misc") - layout = QFormLayout(tab) - self.widgets['misc_layout'] = layout - riflex_combo = self.create_widget(QComboBox, 'RIFLEx_setting') - riflex_combo.addItem("Auto", 0) - riflex_combo.addItem("Always ON", 1) - riflex_combo.addItem("Always OFF", 2) - self.widgets['riflex_row_index'] = layout.rowCount() - layout.addRow("RIFLEx Setting:", riflex_combo) - fps_combo = self.create_widget(QComboBox, 'force_fps') - layout.addRow("Force FPS:", fps_combo) - profile_combo = self.create_widget(QComboBox, 'override_profile') - profile_combo.addItem("Default Profile", -1) - for text, val in self.wgp.memory_profile_choices: profile_combo.addItem(text.split(':')[0], val) - layout.addRow("Override Memory Profile:", profile_combo) - combo = self.create_widget(QComboBox, 'multi_prompts_gen_type') - combo.addItem("Generate new Video per line", 0) - combo.addItem("Use line for new Sliding Window", 1) - layout.addRow("Multi-Prompt Mode:", combo) - - def setup_config_tab(self): - config_tab = QWidget() - self.tabs.addTab(config_tab, "Configuration") - main_layout = QVBoxLayout(config_tab) - self.config_status_label = QLabel("Apply changes for them to take effect. Some may require a restart.") - main_layout.addWidget(self.config_status_label) - config_tabs = QTabWidget() - main_layout.addWidget(config_tabs) - config_tabs.addTab(self._create_general_config_tab(), "General") - config_tabs.addTab(self._create_performance_config_tab(), "Performance") - config_tabs.addTab(self._create_extensions_config_tab(), "Extensions") - config_tabs.addTab(self._create_outputs_config_tab(), "Outputs") - config_tabs.addTab(self._create_notifications_config_tab(), "Notifications") - self.apply_config_btn = QPushButton("Apply Changes") - self.apply_config_btn.clicked.connect(self._on_apply_config_changes) - main_layout.addWidget(self.apply_config_btn) - - def _create_scrollable_form_tab(self): - tab_widget = QWidget() - scroll_area = QScrollArea() - scroll_area.setWidgetResizable(True) - layout = QVBoxLayout(tab_widget) - layout.addWidget(scroll_area) - content_widget = QWidget() - form_layout = QFormLayout(content_widget) - scroll_area.setWidget(content_widget) - return tab_widget, form_layout - - def _create_config_combo(self, form_layout, label, key, choices, default_value): - combo = QComboBox() - for text, data in choices: combo.addItem(text, data) - index = combo.findData(self.wgp.server_config.get(key, default_value)) - if index != -1: combo.setCurrentIndex(index) - self.widgets[f'config_{key}'] = combo - form_layout.addRow(label, combo) - - def _create_config_slider(self, form_layout, label, key, min_val, max_val, default_value, step=1): - container = QWidget() - hbox = QHBoxLayout(container) - hbox.setContentsMargins(0,0,0,0) - slider = QSlider(Qt.Orientation.Horizontal) - slider.setRange(min_val, max_val) - slider.setSingleStep(step) - slider.setValue(self.wgp.server_config.get(key, default_value)) - value_label = QLabel(str(slider.value())) - 
value_label.setMinimumWidth(40) - slider.valueChanged.connect(lambda v, lbl=value_label: lbl.setText(str(v))) - hbox.addWidget(slider) - hbox.addWidget(value_label) - self.widgets[f'config_{key}'] = slider - form_layout.addRow(label, container) - - def _create_config_checklist(self, form_layout, label, key, choices, default_value): - list_widget = QListWidget() - list_widget.setMinimumHeight(100) - current_values = self.wgp.server_config.get(key, default_value) - for text, data in choices: - item = QListWidgetItem(text) - item.setData(Qt.ItemDataRole.UserRole, data) - item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable) - item.setCheckState(Qt.CheckState.Checked if data in current_values else Qt.CheckState.Unchecked) - list_widget.addItem(item) - self.widgets[f'config_{key}'] = list_widget - form_layout.addRow(label, list_widget) - - def _create_config_textbox(self, form_layout, label, key, default_value, multi_line=False): - if multi_line: - textbox = QTextEdit(default_value) - textbox.setAcceptRichText(False) - else: - textbox = QLineEdit(default_value) - self.widgets[f'config_{key}'] = textbox - form_layout.addRow(label, textbox) - - def _create_general_config_tab(self): - tab, form = self._create_scrollable_form_tab() - _, _, dropdown_choices = self.wgp.get_sorted_dropdown(self.wgp.displayed_model_types, None, None, False) - self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, self.wgp.transformer_types) - self._create_config_combo(form, "Model Hierarchy:", "model_hierarchy_type", [("Two Levels (Family > Model)", 0), ("Three Levels (Family > Base > Finetune)", 1)], 1) - self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1), ("Dimensions are Output Width/Height (Cropped)", 2)], 0) - self._create_config_combo(form, "Attention Type:", "attention_mode", [("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"), ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto") - self._create_config_combo(form, "Metadata Handling:", "metadata_type", [("Embed in file (Exif/Comment)", "metadata"), ("Export separate JSON", "json"), ("None", "none")], "metadata") - self._create_config_checklist(form, "RAM Loading Policy:", "preload_model_policy", [("Preload on App Launch", "P"), ("Preload on Model Switch", "S"), ("Unload when Queue is Done", "U")], []) - self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10), ("Keep last 20", 20), ("Keep last 30", 30)], 5) - self._create_config_combo(form, "Display RAM/VRAM Stats:", "display_stats", [("Disabled", 0), ("Enabled", 1)], 0) - self._create_config_combo(form, "Max Frames Multiplier:", "max_frames_multiplier", [(f"x{i}", i) for i in range(1, 8)], 1) - checkpoints_paths_text = "\n".join(self.wgp.server_config.get("checkpoints_paths", self.wgp.fl.default_checkpoints_paths)) - checkpoints_textbox = QTextEdit() - checkpoints_textbox.setPlainText(checkpoints_paths_text) - checkpoints_textbox.setAcceptRichText(False) - checkpoints_textbox.setMinimumHeight(60) - self.widgets['config_checkpoints_paths'] = checkpoints_textbox - form.addRow("Checkpoints Paths:", checkpoints_textbox) - self._create_config_combo(form, "UI Theme (requires restart):", "UI_theme", [("Blue Sky", "default"), ("Classic Gradio", "gradio")], "default") - return tab - - def _create_performance_config_tab(self): - tab, form = 
self._create_scrollable_form_tab() - self._create_config_combo(form, "Transformer Quantization:", "transformer_quantization", [("Scaled Int8 (recommended)", "int8"), ("16-bit (no quantization)", "bf16")], "int8") - self._create_config_combo(form, "Transformer Data Type:", "transformer_dtype_policy", [("Best Supported by Hardware", ""), ("FP16", "fp16"), ("BF16", "bf16")], "") - self._create_config_combo(form, "Transformer Calculation:", "mixed_precision", [("16-bit only", "0"), ("Mixed 16/32-bit (better quality)", "1")], "0") - self._create_config_combo(form, "Text Encoder:", "text_encoder_quantization", [("16-bit (more RAM, better quality)", "bf16"), ("8-bit (less RAM)", "int8")], "int8") - self._create_config_combo(form, "VAE Precision:", "vae_precision", [("16-bit (faster, less VRAM)", "16"), ("32-bit (slower, better quality)", "32")], "16") - self._create_config_combo(form, "Compile Transformer:", "compile", [("On (requires Triton)", "transformer"), ("Off", "")], "") - self._create_config_combo(form, "DepthAnything v2 Variant:", "depth_anything_v2_variant", [("Large (more precise)", "vitl"), ("Big (faster)", "vitb")], "vitl") - self._create_config_combo(form, "VAE Tiling:", "vae_config", [("Auto", 0), ("Disabled", 1), ("256x256 (~8GB VRAM)", 2), ("128x128 (~6GB VRAM)", 3)], 0) - self._create_config_combo(form, "Boost:", "boost", [("On", 1), ("Off", 2)], 1) - self._create_config_combo(form, "Memory Profile:", "profile", self.wgp.memory_profile_choices, self.wgp.profile_type.LowRAM_LowVRAM) - self._create_config_slider(form, "Preload in VRAM (MB):", "preload_in_VRAM", 0, 40000, 0, 100) - release_ram_btn = QPushButton("Force Release Models from RAM") - release_ram_btn.clicked.connect(self._on_release_ram) - form.addRow(release_ram_btn) - return tab - - def _create_extensions_config_tab(self): - tab, form = self._create_scrollable_form_tab() - self._create_config_combo(form, "Prompt Enhancer:", "enhancer_enabled", [("Off", 0), ("Florence 2 + Llama 3.2", 1), ("Florence 2 + Joy Caption (uncensored)", 2)], 0) - self._create_config_combo(form, "Enhancer Mode:", "enhancer_mode", [("Automatic on Generate", 0), ("On Demand Only", 1)], 0) - self._create_config_combo(form, "MMAudio:", "mmaudio_enabled", [("Off", 0), ("Enabled (unloaded after use)", 1), ("Enabled (persistent in RAM)", 2)], 0) - return tab - - def _create_outputs_config_tab(self): - tab, form = self._create_scrollable_form_tab() - self._create_config_combo(form, "Video Codec:", "video_output_codec", [("x265 Balanced", 'libx265_28'), ("x264 Balanced", 'libx264_8'), ("x265 High Quality", 'libx265_8'), ("x264 High Quality", 'libx264_10'), ("x264 Lossless", 'libx264_lossless')], 'libx264_8') - self._create_config_combo(form, "Image Codec:", "image_output_codec", [("JPEG Q85", 'jpeg_85'), ("WEBP Q85", 'webp_85'), ("JPEG Q95", 'jpeg_95'), ("WEBP Q95", 'webp_95'), ("WEBP Lossless", 'webp_lossless'), ("PNG Lossless", 'png')], 'jpeg_95') - self._create_config_textbox(form, "Video Output Folder:", "save_path", "outputs") - self._create_config_textbox(form, "Image Output Folder:", "image_save_path", "outputs") - return tab - - def _create_notifications_config_tab(self): - tab, form = self._create_scrollable_form_tab() - self._create_config_combo(form, "Notification Sound:", "notification_sound_enabled", [("On", 1), ("Off", 0)], 0) - self._create_config_slider(form, "Sound Volume:", "notification_sound_volume", 0, 100, 50, 5) - return tab - - def init_wgp_state(self): - wgp = self.wgp - initial_model = wgp.server_config.get("last_model_type", 
wgp.transformer_type) - dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types - _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False) - all_model_ids = [m[1] for m in all_models] - if initial_model not in all_model_ids: initial_model = wgp.transformer_type - state_dict = {} - state_dict["model_filename"] = wgp.get_model_filename(initial_model, wgp.transformer_quantization, wgp.transformer_dtype_policy) - state_dict["model_type"] = initial_model - state_dict["advanced"] = wgp.advanced - state_dict["last_model_per_family"] = wgp.server_config.get("last_model_per_family", {}) - state_dict["last_model_per_type"] = wgp.server_config.get("last_model_per_type", {}) - state_dict["last_resolution_per_group"] = wgp.server_config.get("last_resolution_per_group", {}) - state_dict["gen"] = {"queue": []} - self.state = state_dict - self.advanced_group.setChecked(wgp.advanced) - self.update_model_dropdowns(initial_model) - self.refresh_ui_from_model_change(initial_model) - self._update_input_visibility() - - def update_model_dropdowns(self, current_model_type): - wgp = self.wgp - family_mock, base_type_mock, choice_mock = wgp.generate_dropdown_model_list(current_model_type) - for combo_name, mock in [('model_family', family_mock), ('model_base_type_choice', base_type_mock), ('model_choice', choice_mock)]: - combo = self.widgets[combo_name] - combo.blockSignals(True) - combo.clear() - if mock.choices: - for display_name, internal_key in mock.choices: combo.addItem(display_name, internal_key) - index = combo.findData(mock.value) - if index != -1: combo.setCurrentIndex(index) - - is_visible = True - if hasattr(mock, 'kwargs') and isinstance(mock.kwargs, dict): - is_visible = mock.kwargs.get('visible', True) - elif hasattr(mock, 'visible'): - is_visible = mock.visible - combo.setVisible(is_visible) - - combo.blockSignals(False) - - def refresh_ui_from_model_change(self, model_type): - """Update UI controls with default settings when the model is changed.""" - wgp = self.wgp - self.header_info.setText(wgp.generate_header(model_type, wgp.compile, wgp.attention_mode)) - ui_defaults = wgp.get_default_settings(model_type) - wgp.set_model_settings(self.state, model_type, ui_defaults) - - model_def = wgp.get_model_def(model_type) - base_model_type = wgp.get_base_model_type(model_type) - model_filename = self.state.get('model_filename', '') - - image_outputs = model_def.get("image_outputs", False) - vace = wgp.test_vace_module(model_type) - t2v = base_model_type in ['t2v', 't2v_2_2'] - i2v = wgp.test_class_i2v(model_type) - fantasy = base_model_type in ["fantasy"] - multitalk = model_def.get("multitalk_class", False) - any_audio_guidance = fantasy or multitalk - sliding_window_enabled = wgp.test_any_sliding_window(model_type) - recammaster = base_model_type in ["recam_1.3B"] - ltxv = "ltxv" in model_filename - diffusion_forcing = "diffusion_forcing" in model_filename - any_skip_layer_guidance = model_def.get("skip_layer_guidance", False) - any_cfg_zero = model_def.get("cfg_zero", False) - any_cfg_star = model_def.get("cfg_star", False) - any_apg = model_def.get("adaptive_projected_guidance", False) - v2i_switch_supported = model_def.get("v2i_switch_supported", False) - - self._update_generation_mode_visibility(model_def) - - for widget in self.widgets.values(): - if hasattr(widget, 'blockSignals'): widget.blockSignals(True) - - self.widgets['prompt'].setText(ui_defaults.get("prompt", "")) - 
self.widgets['negative_prompt'].setText(ui_defaults.get("negative_prompt", "")) - self.widgets['seed'].setText(str(ui_defaults.get("seed", -1))) - - video_length_val = ui_defaults.get("video_length", 81) - self.widgets['video_length'].setValue(video_length_val) - self.widgets['video_length_label'].setText(str(video_length_val)) - - steps_val = ui_defaults.get("num_inference_steps", 30) - self.widgets['num_inference_steps'].setValue(steps_val) - self.widgets['num_inference_steps_label'].setText(str(steps_val)) - - self.widgets['resolution_group'].blockSignals(True) - self.widgets['resolution'].blockSignals(True) - - current_res_choice = ui_defaults.get("resolution") - model_resolutions = model_def.get("resolutions", None) - self.full_resolution_choices, current_res_choice = wgp.get_resolution_choices(current_res_choice, model_resolutions) - available_groups, selected_group_resolutions, selected_group = wgp.group_resolutions(model_def, self.full_resolution_choices, current_res_choice) - - self.widgets['resolution_group'].clear() - self.widgets['resolution_group'].addItems(available_groups) - group_index = self.widgets['resolution_group'].findText(selected_group) - if group_index != -1: - self.widgets['resolution_group'].setCurrentIndex(group_index) - - self.widgets['resolution'].clear() - for label, value in selected_group_resolutions: - self.widgets['resolution'].addItem(label, value) - res_index = self.widgets['resolution'].findData(current_res_choice) - if res_index != -1: - self.widgets['resolution'].setCurrentIndex(res_index) - - self.widgets['resolution_group'].blockSignals(False) - self.widgets['resolution'].blockSignals(False) - - for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'image_refs', 'audio_source']: - if name in self.widgets: self.widgets[name].clear() - - guidance_layout = self.widgets['guidance_layout'] - guidance_max = model_def.get("guidance_max_phases", 1) - guidance_layout.setRowVisible(self.widgets['guidance_phases_row_index'], guidance_max > 1) - - adv_general_layout = self.widgets['adv_general_layout'] - adv_general_layout.setRowVisible(self.widgets['flow_shift_row_index'], not image_outputs) - adv_general_layout.setRowVisible(self.widgets['audio_guidance_row_index'], any_audio_guidance) - adv_general_layout.setRowVisible(self.widgets['repeat_generation_row_index'], not image_outputs) - adv_general_layout.setRowVisible(self.widgets['multi_images_gen_type_row_index'], i2v) - - self.widgets['slg_group'].setVisible(any_skip_layer_guidance) - quality_form_layout = self.widgets['quality_form_layout'] - quality_form_layout.setRowVisible(self.widgets['apg_switch_row_index'], any_apg) - quality_form_layout.setRowVisible(self.widgets['cfg_star_switch_row_index'], any_cfg_star) - quality_form_layout.setRowVisible(self.widgets['cfg_zero_step_row_index'], any_cfg_zero) - quality_form_layout.setRowVisible(self.widgets['min_frames_if_references_row_index'], v2i_switch_supported and image_outputs) - - self.widgets['advanced_tabs'].setTabVisible(self.widgets['sliding_window_tab_index'], sliding_window_enabled and not image_outputs) - - misc_layout = self.widgets['misc_layout'] - misc_layout.setRowVisible(self.widgets['riflex_row_index'], not (recammaster or ltxv or diffusion_forcing)) - - index = self.widgets['multi_images_gen_type'].findData(ui_defaults.get('multi_images_gen_type', 0)) - if index != -1: self.widgets['multi_images_gen_type'].setCurrentIndex(index) - - guidance_val = ui_defaults.get("guidance_scale", 5.0) - 
self.widgets['guidance_scale'].setValue(int(guidance_val * 10)) - self.widgets['guidance_scale_label'].setText(f"{guidance_val:.1f}") - - guidance2_val = ui_defaults.get("guidance2_scale", 5.0) - self.widgets['guidance2_scale'].setValue(int(guidance2_val * 10)) - self.widgets['guidance2_scale_label'].setText(f"{guidance2_val:.1f}") - - guidance3_val = ui_defaults.get("guidance3_scale", 5.0) - self.widgets['guidance3_scale'].setValue(int(guidance3_val * 10)) - self.widgets['guidance3_scale_label'].setText(f"{guidance3_val:.1f}") - - self.widgets['guidance_phases'].clear() - if guidance_max >= 1: self.widgets['guidance_phases'].addItem("One Phase", 1) - if guidance_max >= 2: self.widgets['guidance_phases'].addItem("Two Phases", 2) - if guidance_max >= 3: self.widgets['guidance_phases'].addItem("Three Phases", 3) - index = self.widgets['guidance_phases'].findData(ui_defaults.get("guidance_phases", 1)) - if index != -1: self.widgets['guidance_phases'].setCurrentIndex(index) - - switch_thresh_val = ui_defaults.get("switch_threshold", 0) - self.widgets['switch_threshold'].setValue(switch_thresh_val) - self.widgets['switch_threshold_label'].setText(str(switch_thresh_val)) - - nag_scale_val = ui_defaults.get('NAG_scale', 1.0) - self.widgets['NAG_scale'].setValue(int(nag_scale_val * 10)) - self.widgets['NAG_scale_label'].setText(f"{nag_scale_val:.1f}") - - nag_tau_val = ui_defaults.get('NAG_tau', 3.5) - self.widgets['NAG_tau'].setValue(int(nag_tau_val * 10)) - self.widgets['NAG_tau_label'].setText(f"{nag_tau_val:.1f}") - - nag_alpha_val = ui_defaults.get('NAG_alpha', 0.5) - self.widgets['NAG_alpha'].setValue(int(nag_alpha_val * 10)) - self.widgets['NAG_alpha_label'].setText(f"{nag_alpha_val:.1f}") - - self.widgets['nag_group'].setVisible(vace or t2v or i2v) - - self.widgets['sample_solver'].clear() - sampler_choices = model_def.get("sample_solvers", []) - self.widgets['solver_row_container'].setVisible(bool(sampler_choices)) - if sampler_choices: - for label, value in sampler_choices: self.widgets['sample_solver'].addItem(label, value) - solver_val = ui_defaults.get('sample_solver', sampler_choices[0][1]) - index = self.widgets['sample_solver'].findData(solver_val) - if index != -1: self.widgets['sample_solver'].setCurrentIndex(index) - - flow_val = ui_defaults.get("flow_shift", 3.0) - self.widgets['flow_shift'].setValue(int(flow_val * 10)) - self.widgets['flow_shift_label'].setText(f"{flow_val:.1f}") - - audio_guidance_val = ui_defaults.get("audio_guidance_scale", 4.0) - self.widgets['audio_guidance_scale'].setValue(int(audio_guidance_val * 10)) - self.widgets['audio_guidance_scale_label'].setText(f"{audio_guidance_val:.1f}") - - repeat_val = ui_defaults.get("repeat_generation", 1) - self.widgets['repeat_generation'].setValue(repeat_val) - self.widgets['repeat_generation_label'].setText(str(repeat_val)) - - available_loras, _, _, _, _, _ = wgp.setup_loras(model_type, None, wgp.get_lora_dir(model_type), "") - self.state['loras'] = available_loras - self.lora_map = {os.path.basename(p): p for p in available_loras} - lora_list_widget = self.widgets['activated_loras'] - lora_list_widget.clear() - lora_list_widget.addItems(sorted(self.lora_map.keys())) - selected_loras = ui_defaults.get('activated_loras', []) - for i in range(lora_list_widget.count()): - item = lora_list_widget.item(i) - if any(item.text() == os.path.basename(p) for p in selected_loras): item.setSelected(True) - self.widgets['loras_multipliers'].setText(ui_defaults.get('loras_multipliers', '')) - - skip_cache_val = 
ui_defaults.get('skip_steps_cache_type', "") - index = self.widgets['skip_steps_cache_type'].findData(skip_cache_val) - if index != -1: self.widgets['skip_steps_cache_type'].setCurrentIndex(index) - - skip_mult = ui_defaults.get('skip_steps_multiplier', 1.5) - index = self.widgets['skip_steps_multiplier'].findData(skip_mult) - if index != -1: self.widgets['skip_steps_multiplier'].setCurrentIndex(index) - - skip_perc_val = ui_defaults.get('skip_steps_start_step_perc', 0) - self.widgets['skip_steps_start_step_perc'].setValue(skip_perc_val) - self.widgets['skip_steps_start_step_perc_label'].setText(str(skip_perc_val)) - - temp_up_val = ui_defaults.get('temporal_upsampling', "") - index = self.widgets['temporal_upsampling'].findData(temp_up_val) - if index != -1: self.widgets['temporal_upsampling'].setCurrentIndex(index) - - spat_up_val = ui_defaults.get('spatial_upsampling', "") - index = self.widgets['spatial_upsampling'].findData(spat_up_val) - if index != -1: self.widgets['spatial_upsampling'].setCurrentIndex(index) - - film_grain_i = ui_defaults.get('film_grain_intensity', 0) - self.widgets['film_grain_intensity'].setValue(int(film_grain_i * 100)) - self.widgets['film_grain_intensity_label'].setText(f"{film_grain_i:.2f}") - - film_grain_s = ui_defaults.get('film_grain_saturation', 0.5) - self.widgets['film_grain_saturation'].setValue(int(film_grain_s * 100)) - self.widgets['film_grain_saturation_label'].setText(f"{film_grain_s:.2f}") - - self.widgets['MMAudio_setting'].setCurrentIndex(ui_defaults.get('MMAudio_setting', 0)) - self.widgets['MMAudio_prompt'].setText(ui_defaults.get('MMAudio_prompt', '')) - self.widgets['MMAudio_neg_prompt'].setText(ui_defaults.get('MMAudio_neg_prompt', '')) - - self.widgets['slg_switch'].setCurrentIndex(ui_defaults.get('slg_switch', 0)) - slg_start_val = ui_defaults.get('slg_start_perc', 10) - self.widgets['slg_start_perc'].setValue(slg_start_val) - self.widgets['slg_start_perc_label'].setText(str(slg_start_val)) - slg_end_val = ui_defaults.get('slg_end_perc', 90) - self.widgets['slg_end_perc'].setValue(slg_end_val) - self.widgets['slg_end_perc_label'].setText(str(slg_end_val)) - - self.widgets['apg_switch'].setCurrentIndex(ui_defaults.get('apg_switch', 0)) - self.widgets['cfg_star_switch'].setCurrentIndex(ui_defaults.get('cfg_star_switch', 0)) - - cfg_zero_val = ui_defaults.get('cfg_zero_step', -1) - self.widgets['cfg_zero_step'].setValue(cfg_zero_val) - self.widgets['cfg_zero_step_label'].setText(str(cfg_zero_val)) - - min_frames_val = ui_defaults.get('min_frames_if_references', 1) - index = self.widgets['min_frames_if_references'].findData(min_frames_val) - if index != -1: self.widgets['min_frames_if_references'].setCurrentIndex(index) - - self.widgets['RIFLEx_setting'].setCurrentIndex(ui_defaults.get('RIFLEx_setting', 0)) - - fps = wgp.get_model_fps(model_type) - force_fps_choices = [ - (f"Model Default ({fps} fps)", ""), ("Auto", "auto"), ("Control Video fps", "control"), - ("Source Video fps", "source"), ("15", "15"), ("16", "16"), ("23", "23"), - ("24", "24"), ("25", "25"), ("30", "30") - ] - self.widgets['force_fps'].clear() - for label, value in force_fps_choices: self.widgets['force_fps'].addItem(label, value) - force_fps_val = ui_defaults.get('force_fps', "") - index = self.widgets['force_fps'].findData(force_fps_val) - if index != -1: self.widgets['force_fps'].setCurrentIndex(index) - - override_prof_val = ui_defaults.get('override_profile', -1) - index = self.widgets['override_profile'].findData(override_prof_val) - if index != -1: 
self.widgets['override_profile'].setCurrentIndex(index) - - self.widgets['multi_prompts_gen_type'].setCurrentIndex(ui_defaults.get('multi_prompts_gen_type', 0)) - - denoising_val = ui_defaults.get("denoising_strength", 0.5) - self.widgets['denoising_strength'].setValue(int(denoising_val * 100)) - self.widgets['denoising_strength_label'].setText(f"{denoising_val:.2f}") - - sw_size = ui_defaults.get("sliding_window_size", 129) - self.widgets['sliding_window_size'].setValue(sw_size) - self.widgets['sliding_window_size_label'].setText(str(sw_size)) - - sw_overlap = ui_defaults.get("sliding_window_overlap", 5) - self.widgets['sliding_window_overlap'].setValue(sw_overlap) - self.widgets['sliding_window_overlap_label'].setText(str(sw_overlap)) - - sw_color = ui_defaults.get("sliding_window_color_correction_strength", 0) - self.widgets['sliding_window_color_correction_strength'].setValue(int(sw_color * 100)) - self.widgets['sliding_window_color_correction_strength_label'].setText(f"{sw_color:.2f}") - - sw_noise = ui_defaults.get("sliding_window_overlap_noise", 20) - self.widgets['sliding_window_overlap_noise'].setValue(sw_noise) - self.widgets['sliding_window_overlap_noise_label'].setText(str(sw_noise)) - - sw_discard = ui_defaults.get("sliding_window_discard_last_frames", 0) - self.widgets['sliding_window_discard_last_frames'].setValue(sw_discard) - self.widgets['sliding_window_discard_last_frames_label'].setText(str(sw_discard)) - - for widget in self.widgets.values(): - if hasattr(widget, 'blockSignals'): widget.blockSignals(False) - - self._update_dynamic_ui() - self._update_input_visibility() - - def _update_dynamic_ui(self): - phases = self.widgets['guidance_phases'].currentData() or 1 - guidance_layout = self.widgets['guidance_layout'] - guidance_layout.setRowVisible(self.widgets['guidance2_row_index'], phases >= 2) - guidance_layout.setRowVisible(self.widgets['guidance3_row_index'], phases >= 3) - guidance_layout.setRowVisible(self.widgets['switch_thresh_row_index'], phases >= 2) - - def _update_generation_mode_visibility(self, model_def): - allowed = model_def.get("image_prompt_types_allowed", "") - choices = [] - if "T" in allowed or not allowed: choices.append(("Text Prompt Only" if "S" in allowed else "New Video", "T")) - if "S" in allowed: choices.append(("Start Video with Image", "S")) - if "V" in allowed: choices.append(("Continue Video", "V")) - if "L" in allowed: choices.append(("Continue Last Video", "L")) - button_map = { "T": self.widgets['mode_t'], "S": self.widgets['mode_s'], "V": self.widgets['mode_v'], "L": self.widgets['mode_l'] } - for btn in button_map.values(): btn.setVisible(False) - allowed_values = [c[1] for c in choices] - for label, value in choices: - if value in button_map: - btn = button_map[value] - btn.setText(label) - btn.setVisible(True) - current_checked_value = next((value for value, btn in button_map.items() if btn.isChecked()), None) - if current_checked_value is None or not button_map[current_checked_value].isVisible(): - if allowed_values: button_map[allowed_values[0]].setChecked(True) - end_image_visible = "E" in allowed - self.widgets['image_end_checkbox'].setVisible(end_image_visible) - if not end_image_visible: self.widgets['image_end_checkbox'].setChecked(False) - control_video_visible = model_def.get("guide_preprocessing") is not None - self.widgets['control_video_checkbox'].setVisible(control_video_visible) - if not control_video_visible: self.widgets['control_video_checkbox'].setChecked(False) - ref_image_visible = 
model_def.get("image_ref_choices") is not None - self.widgets['ref_image_checkbox'].setVisible(ref_image_visible) - if not ref_image_visible: self.widgets['ref_image_checkbox'].setChecked(False) - - def _update_input_visibility(self): - is_s_mode = self.widgets['mode_s'].isChecked() - is_v_mode = self.widgets['mode_v'].isChecked() - is_l_mode = self.widgets['mode_l'].isChecked() - use_end = self.widgets['image_end_checkbox'].isChecked() and self.widgets['image_end_checkbox'].isVisible() - use_control = self.widgets['control_video_checkbox'].isChecked() and self.widgets['control_video_checkbox'].isVisible() - use_ref = self.widgets['ref_image_checkbox'].isChecked() and self.widgets['ref_image_checkbox'].isVisible() - self.widgets['image_start_container'].setVisible(is_s_mode) - self.widgets['video_source_container'].setVisible(is_v_mode) - end_checkbox_enabled = is_s_mode or is_v_mode or is_l_mode - self.widgets['image_end_checkbox'].setEnabled(end_checkbox_enabled) - self.widgets['image_end_container'].setVisible(use_end and end_checkbox_enabled) - self.widgets['video_guide_container'].setVisible(use_control) - self.widgets['video_mask_container'].setVisible(use_control) - self.widgets['image_refs_container'].setVisible(use_ref) - - def connect_signals(self): - self.widgets['model_family'].currentIndexChanged.connect(self._on_family_changed) - self.widgets['model_base_type_choice'].currentIndexChanged.connect(self._on_base_type_changed) - self.widgets['model_choice'].currentIndexChanged.connect(self._on_model_changed) - self.widgets['resolution_group'].currentIndexChanged.connect(self._on_resolution_group_changed) - self.widgets['guidance_phases'].currentIndexChanged.connect(self._update_dynamic_ui) - self.widgets['mode_t'].toggled.connect(self._update_input_visibility) - self.widgets['mode_s'].toggled.connect(self._update_input_visibility) - self.widgets['mode_v'].toggled.connect(self._update_input_visibility) - self.widgets['mode_l'].toggled.connect(self._update_input_visibility) - self.widgets['image_end_checkbox'].toggled.connect(self._update_input_visibility) - self.widgets['control_video_checkbox'].toggled.connect(self._update_input_visibility) - self.widgets['ref_image_checkbox'].toggled.connect(self._update_input_visibility) - self.widgets['preview_group'].toggled.connect(self._on_preview_toggled) - self.generate_btn.clicked.connect(self._on_generate) - self.add_to_queue_btn.clicked.connect(self._on_add_to_queue) - self.remove_queue_btn.clicked.connect(self._on_remove_selected_from_queue) - self.clear_queue_btn.clicked.connect(self._on_clear_queue) - self.abort_btn.clicked.connect(self._on_abort) - self.queue_table.rowsMoved.connect(self._on_queue_rows_moved) - - def load_main_config(self): - try: - with open('main_config.json', 'r') as f: self.main_config = json.load(f) - except (FileNotFoundError, json.JSONDecodeError): self.main_config = {'preview_visible': False} - - def save_main_config(self): - try: - with open('main_config.json', 'w') as f: json.dump(self.main_config, f, indent=4) - except Exception as e: print(f"Error saving main_config.json: {e}") - - def apply_initial_config(self): - is_visible = self.main_config.get('preview_visible', True) - self.widgets['preview_group'].setChecked(is_visible) - self.widgets['preview_image'].setVisible(is_visible) - - def _on_preview_toggled(self, checked): - self.widgets['preview_image'].setVisible(checked) - self.main_config['preview_visible'] = checked - self.save_main_config() - - def _on_family_changed(self): - family = 
self.widgets['model_family'].currentData() - if not family or not self.state: return - base_type_mock, choice_mock = self.wgp.change_model_family(self.state, family) - - if hasattr(base_type_mock, 'kwargs') and isinstance(base_type_mock.kwargs, dict): - is_visible_base = base_type_mock.kwargs.get('visible', True) - elif hasattr(base_type_mock, 'visible'): - is_visible_base = base_type_mock.visible - else: - is_visible_base = True - - self.widgets['model_base_type_choice'].blockSignals(True) - self.widgets['model_base_type_choice'].clear() - if base_type_mock.choices: - for label, value in base_type_mock.choices: self.widgets['model_base_type_choice'].addItem(label, value) - self.widgets['model_base_type_choice'].setCurrentIndex(self.widgets['model_base_type_choice'].findData(base_type_mock.value)) - self.widgets['model_base_type_choice'].setVisible(is_visible_base) - self.widgets['model_base_type_choice'].blockSignals(False) - - if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict): - is_visible_choice = choice_mock.kwargs.get('visible', True) - elif hasattr(choice_mock, 'visible'): - is_visible_choice = choice_mock.visible - else: - is_visible_choice = True - - self.widgets['model_choice'].blockSignals(True) - self.widgets['model_choice'].clear() - if choice_mock.choices: - for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value) - self.widgets['model_choice'].setCurrentIndex(self.widgets['model_choice'].findData(choice_mock.value)) - self.widgets['model_choice'].setVisible(is_visible_choice) - self.widgets['model_choice'].blockSignals(False) - - self._on_model_changed() - - def _on_base_type_changed(self): - family = self.widgets['model_family'].currentData() - base_type = self.widgets['model_base_type_choice'].currentData() - if not family or not base_type or not self.state: return - base_type_mock, choice_mock = self.wgp.change_model_base_types(self.state, family, base_type) - - if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict): - is_visible_choice = choice_mock.kwargs.get('visible', True) - elif hasattr(choice_mock, 'visible'): - is_visible_choice = choice_mock.visible - else: - is_visible_choice = True - - self.widgets['model_choice'].blockSignals(True) - self.widgets['model_choice'].clear() - if choice_mock.choices: - for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value) - self.widgets['model_choice'].setCurrentIndex(self.widgets['model_choice'].findData(choice_mock.value)) - self.widgets['model_choice'].setVisible(is_visible_choice) - self.widgets['model_choice'].blockSignals(False) - self._on_model_changed() - - def _on_model_changed(self): - model_type = self.widgets['model_choice'].currentData() - if not model_type or model_type == self.state.get('model_type'): return - self.wgp.change_model(self.state, model_type) - self.refresh_ui_from_model_change(model_type) - - def _on_resolution_group_changed(self): - selected_group = self.widgets['resolution_group'].currentText() - if not selected_group or not hasattr(self, 'full_resolution_choices'): return - model_type = self.state['model_type'] - model_def = self.wgp.get_model_def(model_type) - model_resolutions = model_def.get("resolutions", None) - group_resolution_choices = [] - if model_resolutions is None: - group_resolution_choices = [res for res in self.full_resolution_choices if self.wgp.categorize_resolution(res[1]) == selected_group] - else: return - last_resolution = 
self.state.get("last_resolution_per_group", {}).get(selected_group, "") - if not any(last_resolution == res[1] for res in group_resolution_choices) and group_resolution_choices: - last_resolution = group_resolution_choices[0][1] - self.widgets['resolution'].blockSignals(True) - self.widgets['resolution'].clear() - for label, value in group_resolution_choices: self.widgets['resolution'].addItem(label, value) - self.widgets['resolution'].setCurrentIndex(self.widgets['resolution'].findData(last_resolution)) - self.widgets['resolution'].blockSignals(False) - - def collect_inputs(self): - full_inputs = self.wgp.get_current_model_settings(self.state).copy() - full_inputs['lset_name'] = "" - full_inputs['image_mode'] = 0 - expected_keys = { "audio_guide": None, "audio_guide2": None, "image_guide": None, "image_mask": None, "speakers_locations": "", "frames_positions": "", "keep_frames_video_guide": "", "keep_frames_video_source": "", "video_guide_outpainting": "", "switch_threshold2": 0, "model_switch_phase": 1, "batch_size": 1, "control_net_weight_alt": 1.0, "image_refs_relative_size": 50, } - for key, default_value in expected_keys.items(): - if key not in full_inputs: full_inputs[key] = default_value - full_inputs['prompt'] = self.widgets['prompt'].toPlainText() - full_inputs['negative_prompt'] = self.widgets['negative_prompt'].toPlainText() - full_inputs['resolution'] = self.widgets['resolution'].currentData() - full_inputs['video_length'] = self.widgets['video_length'].value() - full_inputs['num_inference_steps'] = self.widgets['num_inference_steps'].value() - full_inputs['seed'] = int(self.widgets['seed'].text()) - image_prompt_type = "" - video_prompt_type = "" - if self.widgets['mode_s'].isChecked(): image_prompt_type = 'S' - elif self.widgets['mode_v'].isChecked(): image_prompt_type = 'V' - elif self.widgets['mode_l'].isChecked(): image_prompt_type = 'L' - if self.widgets['image_end_checkbox'].isVisible() and self.widgets['image_end_checkbox'].isChecked(): image_prompt_type += 'E' - if self.widgets['control_video_checkbox'].isVisible() and self.widgets['control_video_checkbox'].isChecked(): video_prompt_type += 'V' - if self.widgets['ref_image_checkbox'].isVisible() and self.widgets['ref_image_checkbox'].isChecked(): video_prompt_type += 'I' - full_inputs['image_prompt_type'] = image_prompt_type - full_inputs['video_prompt_type'] = video_prompt_type - for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'audio_source']: - if name in self.widgets: full_inputs[name] = self.widgets[name].text() or None - paths = self.widgets['image_refs'].text().split(';') - full_inputs['image_refs'] = [p.strip() for p in paths if p.strip()] if paths and paths[0] else None - full_inputs['denoising_strength'] = self.widgets['denoising_strength'].value() / 100.0 - if self.advanced_group.isChecked(): - full_inputs['guidance_scale'] = self.widgets['guidance_scale'].value() / 10.0 - full_inputs['guidance_phases'] = self.widgets['guidance_phases'].currentData() - full_inputs['guidance2_scale'] = self.widgets['guidance2_scale'].value() / 10.0 - full_inputs['guidance3_scale'] = self.widgets['guidance3_scale'].value() / 10.0 - full_inputs['switch_threshold'] = self.widgets['switch_threshold'].value() - full_inputs['NAG_scale'] = self.widgets['NAG_scale'].value() / 10.0 - full_inputs['NAG_tau'] = self.widgets['NAG_tau'].value() / 10.0 - full_inputs['NAG_alpha'] = self.widgets['NAG_alpha'].value() / 10.0 - full_inputs['sample_solver'] = self.widgets['sample_solver'].currentData() - 
full_inputs['flow_shift'] = self.widgets['flow_shift'].value() / 10.0 - full_inputs['audio_guidance_scale'] = self.widgets['audio_guidance_scale'].value() / 10.0 - full_inputs['repeat_generation'] = self.widgets['repeat_generation'].value() - full_inputs['multi_images_gen_type'] = self.widgets['multi_images_gen_type'].currentData() - selected_items = self.widgets['activated_loras'].selectedItems() - full_inputs['activated_loras'] = [self.lora_map[item.text()] for item in selected_items if item.text() in self.lora_map] - full_inputs['loras_multipliers'] = self.widgets['loras_multipliers'].toPlainText() - full_inputs['skip_steps_cache_type'] = self.widgets['skip_steps_cache_type'].currentData() - full_inputs['skip_steps_multiplier'] = self.widgets['skip_steps_multiplier'].currentData() - full_inputs['skip_steps_start_step_perc'] = self.widgets['skip_steps_start_step_perc'].value() - full_inputs['temporal_upsampling'] = self.widgets['temporal_upsampling'].currentData() - full_inputs['spatial_upsampling'] = self.widgets['spatial_upsampling'].currentData() - full_inputs['film_grain_intensity'] = self.widgets['film_grain_intensity'].value() / 100.0 - full_inputs['film_grain_saturation'] = self.widgets['film_grain_saturation'].value() / 100.0 - full_inputs['MMAudio_setting'] = self.widgets['MMAudio_setting'].currentData() - full_inputs['MMAudio_prompt'] = self.widgets['MMAudio_prompt'].text() - full_inputs['MMAudio_neg_prompt'] = self.widgets['MMAudio_neg_prompt'].text() - full_inputs['RIFLEx_setting'] = self.widgets['RIFLEx_setting'].currentData() - full_inputs['force_fps'] = self.widgets['force_fps'].currentData() - full_inputs['override_profile'] = self.widgets['override_profile'].currentData() - full_inputs['multi_prompts_gen_type'] = self.widgets['multi_prompts_gen_type'].currentData() - full_inputs['slg_switch'] = self.widgets['slg_switch'].currentData() - full_inputs['slg_start_perc'] = self.widgets['slg_start_perc'].value() - full_inputs['slg_end_perc'] = self.widgets['slg_end_perc'].value() - full_inputs['apg_switch'] = self.widgets['apg_switch'].currentData() - full_inputs['cfg_star_switch'] = self.widgets['cfg_star_switch'].currentData() - full_inputs['cfg_zero_step'] = self.widgets['cfg_zero_step'].value() - full_inputs['min_frames_if_references'] = self.widgets['min_frames_if_references'].currentData() - full_inputs['sliding_window_size'] = self.widgets['sliding_window_size'].value() - full_inputs['sliding_window_overlap'] = self.widgets['sliding_window_overlap'].value() - full_inputs['sliding_window_color_correction_strength'] = self.widgets['sliding_window_color_correction_strength'].value() / 100.0 - full_inputs['sliding_window_overlap_noise'] = self.widgets['sliding_window_overlap_noise'].value() - full_inputs['sliding_window_discard_last_frames'] = self.widgets['sliding_window_discard_last_frames'].value() - return full_inputs - - def _prepare_state_for_generation(self): - if 'gen' in self.state: - self.state['gen'].pop('abort', None) - self.state['gen'].pop('in_progress', None) - - def _on_generate(self): - try: - is_running = self.thread and self.thread.isRunning() - self._add_task_to_queue() - if not is_running: self.start_generation() - except Exception as e: - import traceback; traceback.print_exc() - - def _on_add_to_queue(self): - try: - self._add_task_to_queue() - except Exception as e: - import traceback; traceback.print_exc() - - def _add_task_to_queue(self): - all_inputs = self.collect_inputs() - for key in ['type', 'settings_version', 'is_image', 'video_quality', 
'image_quality', 'base_model_type']: all_inputs.pop(key, None) - all_inputs['state'] = self.state - self.wgp.set_model_settings(self.state, self.state['model_type'], all_inputs) - self.state["validate_success"] = 1 - self.wgp.process_prompt_and_add_tasks(self.state, self.state['model_type']) - self.update_queue_table() - - def start_generation(self): - if not self.state['gen']['queue']: return - self._prepare_state_for_generation() - self.generate_btn.setEnabled(False) - self.add_to_queue_btn.setEnabled(True) - self.thread = QThread() - self.worker = Worker(self.plugin, self.state) - self.worker.moveToThread(self.thread) - self.thread.started.connect(self.worker.run) - self.worker.finished.connect(self.thread.quit) - self.worker.finished.connect(self.worker.deleteLater) - self.thread.finished.connect(self.thread.deleteLater) - self.thread.finished.connect(self.on_generation_finished) - self.worker.status.connect(self.status_label.setText) - self.worker.progress.connect(self.update_progress) - self.worker.preview.connect(self.update_preview) - self.worker.output.connect(self.update_queue_and_results) - self.worker.error.connect(self.on_generation_error) - self.thread.start() - self.update_queue_table() - - def on_generation_finished(self): - time.sleep(0.1) - self.status_label.setText("Finished.") - self.progress_bar.setValue(0) - self.generate_btn.setEnabled(True) - self.add_to_queue_btn.setEnabled(False) - self.thread = None; self.worker = None - self.update_queue_table() - - def on_generation_error(self, err_msg): - QMessageBox.critical(self, "Generation Error", str(err_msg)) - self.on_generation_finished() - - def update_progress(self, data): - if len(data) > 1 and isinstance(data[0], tuple): - step, total = data[0] - self.progress_bar.setMaximum(total) - self.progress_bar.setValue(step) - self.status_label.setText(str(data[1])) - if step <= 1: self.update_queue_table() - elif len(data) > 1: self.status_label.setText(str(data[1])) - - def update_preview(self, pil_image): - if pil_image and self.widgets['preview_group'].isChecked(): - q_image = ImageQt(pil_image) - pixmap = QPixmap.fromImage(q_image) - self.preview_image.setPixmap(pixmap.scaled(self.preview_image.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)) - - def update_queue_and_results(self): - self.update_queue_table() - file_list = self.state.get('gen', {}).get('file_list', []) - for file_path in file_list: - if file_path not in self.processed_files: - self.add_result_item(file_path) - self.processed_files.add(file_path) - - def add_result_item(self, video_path): - item_widget = VideoResultItemWidget(video_path, self.plugin) - list_item = QListWidgetItem(self.results_list) - list_item.setSizeHint(item_widget.sizeHint()) - self.results_list.addItem(list_item) - self.results_list.setItemWidget(list_item, item_widget) - - def update_queue_table(self): - with self.wgp.lock: - queue = self.state.get('gen', {}).get('queue', []) - is_running = self.thread and self.thread.isRunning() - queue_to_display = queue if is_running else [None] + queue - table_data = self.wgp.get_queue_table(queue_to_display) - self.queue_table.setRowCount(0) - self.queue_table.setRowCount(len(table_data)) - self.queue_table.setColumnCount(4) - self.queue_table.setHorizontalHeaderLabels(["Qty", "Prompt", "Length", "Steps"]) - for row_idx, row_data in enumerate(table_data): - prompt_text = str(row_data[1]).split('>')[1].split('<')[0] if '>' in str(row_data[1]) else str(row_data[1]) - for col_idx, cell_data in 
enumerate([row_data[0], prompt_text, row_data[2], row_data[3]]): - self.queue_table.setItem(row_idx, col_idx, QTableWidgetItem(str(cell_data))) - self.queue_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) - self.queue_table.resizeColumnsToContents() - - def _on_remove_selected_from_queue(self): - selected_row = self.queue_table.currentRow() - if selected_row < 0: return - with self.wgp.lock: - is_running = self.thread and self.thread.isRunning() - offset = 1 if is_running else 0 - queue = self.state.get('gen', {}).get('queue', []) - if len(queue) > selected_row + offset: queue.pop(selected_row + offset) - self.update_queue_table() - - def _on_queue_rows_moved(self, source_row, dest_row): - with self.wgp.lock: - queue = self.state.get('gen', {}).get('queue', []) - is_running = self.thread and self.thread.isRunning() - offset = 1 if is_running else 0 - real_source_idx = source_row + offset - real_dest_idx = dest_row + offset - moved_item = queue.pop(real_source_idx) - queue.insert(real_dest_idx, moved_item) - self.update_queue_table() - - def _on_clear_queue(self): - self.wgp.clear_queue_action(self.state) - self.update_queue_table() - - def _on_abort(self): - if self.worker: - self.wgp.abort_generation(self.state) - self.status_label.setText("Aborting...") - self.worker._is_running = False - - def _on_release_ram(self): - self.wgp.release_RAM() - QMessageBox.information(self, "RAM Released", "Models stored in RAM have been released.") - - def _on_apply_config_changes(self): - changes = {} - list_widget = self.widgets['config_transformer_types'] - changes['transformer_types_choices'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] - list_widget = self.widgets['config_preload_model_policy'] - changes['preload_model_policy_choice'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] - changes['model_hierarchy_type_choice'] = self.widgets['config_model_hierarchy_type'].currentData() - changes['checkpoints_paths'] = self.widgets['config_checkpoints_paths'].toPlainText() - for key in ["fit_canvas", "attention_mode", "metadata_type", "clear_file_list", "display_stats", "max_frames_multiplier", "UI_theme"]: - changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() - for key in ["transformer_quantization", "transformer_dtype_policy", "mixed_precision", "text_encoder_quantization", "vae_precision", "compile", "depth_anything_v2_variant", "vae_config", "boost", "profile"]: - changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() - changes['preload_in_VRAM_choice'] = self.widgets['config_preload_in_VRAM'].value() - for key in ["enhancer_enabled", "enhancer_mode", "mmaudio_enabled"]: - changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() - for key in ["video_output_codec", "image_output_codec", "save_path", "image_save_path"]: - widget = self.widgets[f'config_{key}'] - changes[f'{key}_choice'] = widget.currentData() if isinstance(widget, QComboBox) else widget.text() - changes['notification_sound_enabled_choice'] = self.widgets['config_notification_sound_enabled'].currentData() - changes['notification_sound_volume_choice'] = self.widgets['config_notification_sound_volume'].value() - changes['last_resolution_choice'] = self.widgets['resolution'].currentData() - try: - msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = 
self.wgp.apply_changes(self.state, **changes) - self.config_status_label.setText("Changes applied successfully. Some settings may require a restart.") - self.header_info.setText(self.wgp.generate_header(self.state['model_type'], self.wgp.compile, self.wgp.attention_mode)) - if family_mock.choices is not None or choice_mock.choices is not None: - self.update_model_dropdowns(self.wgp.transformer_type) - self.refresh_ui_from_model_change(self.wgp.transformer_type) - except Exception as e: - self.config_status_label.setText(f"Error applying changes: {e}") - import traceback; traceback.print_exc() - -class Plugin(VideoEditorPlugin): - def initialize(self): - self.name = "WanGP AI Generator" - self.description = "Uses the integrated WanGP library to generate video clips." - self.client_widget = WgpDesktopPluginWidget(self) - self.dock_widget = None - self.active_region = None; self.temp_dir = None - self.insert_on_new_track = False; self.start_frame_path = None; self.end_frame_path = None - - def enable(self): - if not self.dock_widget: self.dock_widget = self.app.add_dock_widget(self, self.client_widget, self.name) - self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu) - self.app.status_label.setText(f"{self.name}: Enabled.") - - def disable(self): - try: self.app.timeline_widget.context_menu_requested.disconnect(self.on_timeline_context_menu) - except TypeError: pass - self._cleanup_temp_dir() - if self.client_widget.worker: self.client_widget._on_abort() - self.app.status_label.setText(f"{self.name}: Disabled.") - - def _cleanup_temp_dir(self): - if self.temp_dir and os.path.exists(self.temp_dir): - shutil.rmtree(self.temp_dir) - self.temp_dir = None - - def _reset_state(self): - self.active_region = None; self.insert_on_new_track = False - self.start_frame_path = None; self.end_frame_path = None - self.client_widget.processed_files.clear() - self.client_widget.results_list.clear() - self.client_widget.widgets['image_start'].clear() - self.client_widget.widgets['image_end'].clear() - self.client_widget.widgets['video_source'].clear() - self._cleanup_temp_dir() - self.app.status_label.setText(f"{self.name}: Ready.") - - def on_timeline_context_menu(self, menu, event): - region = self.app.timeline_widget.get_region_at_pos(event.pos()) - if region: - menu.addSeparator() - start_sec, end_sec = region - start_data, _, _ = self.app.get_frame_data_at_time(start_sec) - end_data, _, _ = self.app.get_frame_data_at_time(end_sec) - if start_data and end_data: - join_action = menu.addAction("Join Frames With WanGP") - join_action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False)) - join_action_new_track = menu.addAction("Join Frames With WanGP (New Track)") - join_action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True)) - create_action = menu.addAction("Create Video With WanGP") - create_action.triggered.connect(lambda: self.setup_creator_for_region(region)) - - def setup_generator_for_region(self, region, on_new_track=False): - self._reset_state() - self.active_region = region - self.insert_on_new_track = on_new_track - - model_to_set = 'i2v_2_2' - dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types - _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False) - if any(model_to_set == m[1] for m in all_models): - if self.client_widget.state.get('model_type') != model_to_set: - 
self.client_widget.update_model_dropdowns(model_to_set) - self.client_widget._on_model_changed() - else: - print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.") - - start_sec, end_sec = region - start_data, w, h = self.app.get_frame_data_at_time(start_sec) - end_data, _, _ = self.app.get_frame_data_at_time(end_sec) - if not start_data or not end_data: - QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames.") - return - try: - self.temp_dir = tempfile.mkdtemp(prefix="wgp_plugin_") - self.start_frame_path = os.path.join(self.temp_dir, "start_frame.png") - self.end_frame_path = os.path.join(self.temp_dir, "end_frame.png") - QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path) - QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path) - - duration_sec = end_sec - start_sec - wgp = self.client_widget.wgp - model_type = self.client_widget.state['model_type'] - fps = wgp.get_model_fps(model_type) - video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) - - self.client_widget.widgets['video_length'].setValue(video_length_frames) - self.client_widget.widgets['mode_s'].setChecked(True) - self.client_widget.widgets['image_end_checkbox'].setChecked(True) - self.client_widget.widgets['image_start'].setText(self.start_frame_path) - self.client_widget.widgets['image_end'].setText(self.end_frame_path) - - self.client_widget._update_input_visibility() - - except Exception as e: - QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame images: {e}") - self._cleanup_temp_dir() - return - self.app.status_label.setText(f"WanGP: Ready to join frames from {start_sec:.2f}s to {end_sec:.2f}s.") - self.dock_widget.show() - self.dock_widget.raise_() - - def setup_creator_for_region(self, region): - self._reset_state() - self.active_region = region - self.insert_on_new_track = True - - model_to_set = 't2v_2_2' - dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types - _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False) - if any(model_to_set == m[1] for m in all_models): - if self.client_widget.state.get('model_type') != model_to_set: - self.client_widget.update_model_dropdowns(model_to_set) - self.client_widget._on_model_changed() - else: - print(f"Warning: Default model '{model_to_set}' not found for AI Creator. 
Using current model.") - - start_sec, end_sec = region - duration_sec = end_sec - start_sec - wgp = self.client_widget.wgp - model_type = self.client_widget.state['model_type'] - fps = wgp.get_model_fps(model_type) - video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) - - self.client_widget.widgets['video_length'].setValue(video_length_frames) - self.client_widget.widgets['mode_t'].setChecked(True) - - self.app.status_label.setText(f"WanGP: Ready to create video from {start_sec:.2f}s to {end_sec:.2f}s.") - self.dock_widget.show() - self.dock_widget.raise_() - - def insert_generated_clip(self, video_path): - from videoeditor import TimelineClip - if not self.active_region: - self.app.status_label.setText("WanGP Error: No active region to insert into."); return - if not os.path.exists(video_path): - self.app.status_label.setText(f"WanGP Error: Output file not found: {video_path}"); return - start_sec, end_sec = self.active_region - def complex_insertion_action(): - self.app._add_media_files_to_project([video_path]) - media_info = self.app.media_properties.get(video_path) - if not media_info: raise ValueError("Could not probe inserted clip.") - actual_duration, has_audio = media_info['duration'], media_info['has_audio'] - if self.insert_on_new_track: - self.app.timeline.num_video_tracks += 1 - video_track_index = self.app.timeline.num_video_tracks - audio_track_index = self.app.timeline.num_audio_tracks + 1 if has_audio else None - else: - for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec) - for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec) - clips_to_remove = [c for c in self.app.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec] - for clip in clips_to_remove: - if clip in self.app.timeline.clips: self.app.timeline.clips.remove(clip) - video_track_index, audio_track_index = 1, 1 if has_audio else None - group_id = str(uuid.uuid4()) - new_clip = TimelineClip(video_path, start_sec, 0, actual_duration, video_track_index, 'video', 'video', group_id) - self.app.timeline.add_clip(new_clip) - if audio_track_index: - if audio_track_index > self.app.timeline.num_audio_tracks: self.app.timeline.num_audio_tracks = audio_track_index - audio_clip = TimelineClip(video_path, start_sec, 0, actual_duration, audio_track_index, 'audio', 'video', group_id) - self.app.timeline.add_clip(audio_clip) - try: - self.app._perform_complex_timeline_change("Insert AI Clip", complex_insertion_action) - self.app.prune_empty_tracks() - self.app.status_label.setText("AI clip inserted successfully.") - for i in range(self.client_widget.results_list.count()): - widget = self.client_widget.results_list.itemWidget(self.client_widget.results_list.item(i)) - if widget and widget.video_path == video_path: - self.client_widget.results_list.takeItem(i); break - except Exception as e: - import traceback; traceback.print_exc() - self.app.status_label.setText(f"WanGP Error during clip insertion: {e}") \ No newline at end of file +import sys +import os +import threading +import time +import json +import tempfile +import shutil +import uuid +from unittest.mock import MagicMock +from pathlib import Path + +# --- Start of Gradio Hijacking --- +# This block creates a mock Gradio module. When wgp.py is imported, +# all calls to `gr.*` will be intercepted by these mock objects, +# preventing any UI from being built and allowing us to use the +# backend logic directly. 
+
+class MockGradioComponent(MagicMock):
+    """A smarter mock that captures constructor arguments."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(name=f"gr.{kwargs.get('elem_id', 'component')}")
+        self.kwargs = kwargs
+        self.value = kwargs.get('value')
+        self.choices = kwargs.get('choices')
+
+        for method in ['then', 'change', 'click', 'input', 'select', 'upload', 'mount', 'launch', 'on', 'release']:
+            setattr(self, method, lambda *a, **kw: self)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+class MockGradioError(Exception):
+    pass
+
+class MockGradioModule:
+    def __getattr__(self, name):
+        if name == 'Error':
+            return lambda *args, **kwargs: MockGradioError(*args)
+
+        if name in ['Info', 'Warning']:
+            return lambda *args, **kwargs: print(f"Intercepted gr.{name}:", *args)
+
+        return lambda *args, **kwargs: MockGradioComponent(*args, **kwargs)
+
+sys.modules['gradio'] = MockGradioModule()
+sys.modules['gradio.gallery'] = MockGradioModule()
+sys.modules['shared.gradio.gallery'] = MockGradioModule()
+# --- End of Gradio Hijacking ---
+
+import ffmpeg
+
+from PyQt6.QtWidgets import (
+    QWidget, QVBoxLayout, QHBoxLayout, QTabWidget,
+    QPushButton, QLabel, QLineEdit, QTextEdit, QSlider, QCheckBox, QComboBox,
+    QFileDialog, QGroupBox, QFormLayout, QTableWidget, QTableWidgetItem,
+    QHeaderView, QProgressBar, QScrollArea, QListWidget, QListWidgetItem,
+    QMessageBox, QRadioButton, QSizePolicy
+)
+from PyQt6.QtCore import Qt, QThread, QObject, pyqtSignal, QUrl, QSize, QRectF
+from PyQt6.QtGui import QPixmap, QImage, QDropEvent
+from PyQt6.QtMultimedia import QMediaPlayer
+from PyQt6.QtMultimediaWidgets import QVideoWidget
+from PIL.ImageQt import ImageQt
+
+# Import the base plugin class from the main application's path
+sys.path.append(str(Path(__file__).parent.parent.parent))
+from plugins import VideoEditorPlugin
+
+class VideoResultItemWidget(QWidget):
+    """A widget to display a generated video with a hover-to-play preview and insert button."""
+    def __init__(self, video_path, plugin, parent=None):
+        super().__init__(parent)
+        self.video_path = video_path
+        self.plugin = plugin
+        self.app = plugin.app
+        self.duration = 0.0
+        self.has_audio = False
+
+        self.setMinimumSize(200, 180)
+        self.setMaximumHeight(190)
+
+        layout = QVBoxLayout(self)
+        layout.setContentsMargins(5, 5, 5, 5)
+        layout.setSpacing(5)
+
+        self.media_player = QMediaPlayer()
+        self.video_widget = QVideoWidget()
+        self.video_widget.setFixedSize(160, 90)
+        self.media_player.setVideoOutput(self.video_widget)
+        self.media_player.setSource(QUrl.fromLocalFile(self.video_path))
+        self.media_player.setLoops(QMediaPlayer.Loops.Infinite)
+
+        self.info_label = QLabel(os.path.basename(video_path))
+        self.info_label.setAlignment(Qt.AlignmentFlag.AlignCenter)
+        self.info_label.setWordWrap(True)
+
+        self.insert_button = QPushButton("Insert into Timeline")
+        self.insert_button.clicked.connect(self.on_insert)
+
+        h_layout = QHBoxLayout()
+        h_layout.addStretch()
+        h_layout.addWidget(self.video_widget)
+        h_layout.addStretch()
+
+        layout.addLayout(h_layout)
+        layout.addWidget(self.info_label)
+        layout.addWidget(self.insert_button)
+        self.probe_video()
+
+    def probe_video(self):
+        try:
+            probe = ffmpeg.probe(self.video_path)
+            self.duration = float(probe['format']['duration'])
+            self.has_audio = any(s['codec_type'] == 'audio' for s in probe.get('streams', []))
+            self.info_label.setText(f"{os.path.basename(self.video_path)}\n({self.duration:.2f}s)")
+        except Exception as e:
self.info_label.setText(f"Error probing:\n{os.path.basename(self.video_path)}") + print(f"Error probing video {self.video_path}: {e}") + + def enterEvent(self, event): + super().enterEvent(event) + self.media_player.play() + if not self.plugin.active_region or self.duration == 0: return + start_sec, _ = self.plugin.active_region + timeline = self.app.timeline_widget + video_rect, audio_rect = None, None + x = timeline.sec_to_x(start_sec) + w = int(self.duration * timeline.pixels_per_second) + if self.plugin.insert_on_new_track: + video_y = timeline.TIMESCALE_HEIGHT + video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT) + if self.has_audio: + audio_y = timeline.audio_tracks_y_start + self.app.timeline.num_audio_tracks * timeline.TRACK_HEIGHT + audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT) + else: + v_track_idx = 1 + visual_v_idx = self.app.timeline.num_video_tracks - v_track_idx + video_y = timeline.video_tracks_y_start + visual_v_idx * timeline.TRACK_HEIGHT + video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT) + if self.has_audio: + a_track_idx = 1 + visual_a_idx = a_track_idx - 1 + audio_y = timeline.audio_tracks_y_start + visual_a_idx * timeline.TRACK_HEIGHT + audio_rect = QRectF(x, audio_y, w, timeline.TRACK_HEIGHT) + timeline.set_hover_preview_rects(video_rect, audio_rect) + + def leaveEvent(self, event): + super().leaveEvent(event) + self.media_player.pause() + self.media_player.setPosition(0) + self.app.timeline_widget.set_hover_preview_rects(None, None) + + def on_insert(self): + self.media_player.stop() + self.media_player.setSource(QUrl()) + self.media_player.setVideoOutput(None) + self.app.timeline_widget.set_hover_preview_rects(None, None) + self.plugin.insert_generated_clip(self.video_path) + + +class QueueTableWidget(QTableWidget): + rowsMoved = pyqtSignal(int, int) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setDragEnabled(True) + self.setAcceptDrops(True) + self.setDropIndicatorShown(True) + self.setDragDropMode(self.DragDropMode.InternalMove) + self.setSelectionBehavior(self.SelectionBehavior.SelectRows) + self.setSelectionMode(self.SelectionMode.SingleSelection) + + def dropEvent(self, event: QDropEvent): + if event.source() == self and event.dropAction() == Qt.DropAction.MoveAction: + source_row = self.currentRow() + target_item = self.itemAt(event.position().toPoint()) + dest_row = target_item.row() if target_item else self.rowCount() + if source_row < dest_row: dest_row -=1 + if source_row != dest_row: self.rowsMoved.emit(source_row, dest_row) + event.acceptProposedAction() + else: + super().dropEvent(event) + +class HoverVideoPreview(QWidget): + def __init__(self, player, video_widget, parent=None): + super().__init__(parent) + self.player = player + layout = QVBoxLayout(self) + layout.setContentsMargins(0,0,0,0) + layout.addWidget(video_widget) + self.setFixedSize(160, 90) + + def enterEvent(self, event): + super().enterEvent(event) + if self.player.source().isValid(): + self.player.play() + + def leaveEvent(self, event): + super().leaveEvent(event) + self.player.pause() + self.player.setPosition(0) + +class Worker(QObject): + progress = pyqtSignal(list) + status = pyqtSignal(str) + preview = pyqtSignal(object) + output = pyqtSignal() + finished = pyqtSignal() + error = pyqtSignal(str) + def __init__(self, plugin, state): + super().__init__() + self.plugin = plugin + self.wgp = plugin.wgp + self.state = state + self._is_running = True + self._last_progress_phase = None + self._last_preview = None + + def 
run(self): + def generation_target(): + try: + for _ in self.wgp.process_tasks(self.state): + if self._is_running: self.output.emit() + else: break + except Exception as e: + import traceback + print("Error in generation thread:") + traceback.print_exc() + if "gradio.Error" in str(type(e)): self.error.emit(str(e)) + else: self.error.emit(f"An unexpected error occurred: {e}") + finally: + self._is_running = False + gen_thread = threading.Thread(target=generation_target, daemon=True) + gen_thread.start() + while self._is_running: + gen = self.state.get('gen', {}) + current_phase = gen.get("progress_phase") + if current_phase and current_phase != self._last_progress_phase: + self._last_progress_phase = current_phase + phase_name, step = current_phase + total_steps = gen.get("num_inference_steps", 1) + high_level_status = gen.get("progress_status", "") + status_msg = self.wgp.merge_status_context(high_level_status, phase_name) + progress_args = [(step, total_steps), status_msg] + self.progress.emit(progress_args) + preview_img = gen.get('preview') + if preview_img is not None and preview_img is not self._last_preview: + self._last_preview = preview_img + self.preview.emit(preview_img) + gen['preview'] = None + time.sleep(0.1) + gen_thread.join() + self.finished.emit() + + +class WgpDesktopPluginWidget(QWidget): + def __init__(self, plugin): + super().__init__() + self.plugin = plugin + self.wgp = plugin.wgp + self.widgets = {} + self.state = {} + self.worker = None + self.thread = None + self.lora_map = {} + self.full_resolution_choices = [] + self.main_config = {} + self.processed_files = set() + + self.load_main_config() + self.setup_ui() + self.apply_initial_config() + self.connect_signals() + self.init_wgp_state() + + def setup_ui(self): + main_layout = QVBoxLayout(self) + self.header_info = QLabel("Header Info") + main_layout.addWidget(self.header_info) + self.tabs = QTabWidget() + main_layout.addWidget(self.tabs) + self.setup_generator_tab() + self.setup_config_tab() + + def create_widget(self, widget_class, name, *args, **kwargs): + widget = widget_class(*args, **kwargs) + self.widgets[name] = widget + return widget + + def _create_slider_with_label(self, name, min_val, max_val, initial_val, scale=1.0, precision=1): + container = QWidget() + hbox = QHBoxLayout(container) + hbox.setContentsMargins(0, 0, 0, 0) + slider = self.create_widget(QSlider, name, Qt.Orientation.Horizontal) + slider.setRange(min_val, max_val) + slider.setValue(int(initial_val * scale)) + value_label = self.create_widget(QLabel, f"{name}_label", f"{initial_val:.{precision}f}") + value_label.setMinimumWidth(50) + slider.valueChanged.connect(lambda v, lbl=value_label, s=scale, p=precision: lbl.setText(f"{v/s:.{p}f}")) + hbox.addWidget(slider) + hbox.addWidget(value_label) + return container + + def _create_file_input(self, name, label_text): + container = self.create_widget(QWidget, f"{name}_container") + vbox = QVBoxLayout(container) + vbox.setContentsMargins(0, 0, 0, 0) + vbox.setSpacing(5) + + input_widget = QWidget() + hbox = QHBoxLayout(input_widget) + hbox.setContentsMargins(0, 0, 0, 0) + + line_edit = self.create_widget(QLineEdit, name) + line_edit.setPlaceholderText("No file selected or path pasted") + button = QPushButton("Browse...") + + def open_dialog(): + if "refs" in name: + filenames, _ = QFileDialog.getOpenFileNames(self, f"Select {label_text}", filter="Image Files (*.png *.jpg *.jpeg *.bmp *.webp);;All Files (*)") + if filenames: line_edit.setText(";".join(filenames)) + else: + filter_str = "All Files 
(*)" + if 'video' in name: + filter_str = "Video Files (*.mp4 *.mkv *.mov *.avi);;All Files (*)" + elif 'image' in name: + filter_str = "Image Files (*.png *.jpg *.jpeg *.bmp *.webp);;All Files (*)" + elif 'audio' in name: + filter_str = "Audio Files (*.wav *.mp3 *.flac);;All Files (*)" + + filename, _ = QFileDialog.getOpenFileName(self, f"Select {label_text}", filter=filter_str) + if filename: line_edit.setText(filename) + + button.clicked.connect(open_dialog) + clear_button = QPushButton("X") + clear_button.setFixedWidth(30) + clear_button.clicked.connect(lambda: line_edit.clear()) + + hbox.addWidget(QLabel(f"{label_text}:")) + hbox.addWidget(line_edit, 1) + hbox.addWidget(button) + hbox.addWidget(clear_button) + vbox.addWidget(input_widget) + + # Preview Widget + preview_container = self.create_widget(QWidget, f"{name}_preview_container") + preview_hbox = QHBoxLayout(preview_container) + preview_hbox.setContentsMargins(0, 0, 0, 0) + preview_hbox.addStretch() + + is_image_input = 'image' in name and 'audio' not in name + is_video_input = 'video' in name and 'audio' not in name + + if is_image_input: + preview_widget = self.create_widget(QLabel, f"{name}_preview") + preview_widget.setAlignment(Qt.AlignmentFlag.AlignCenter) + preview_widget.setFixedSize(160, 90) + preview_widget.setStyleSheet("border: 1px solid #cccccc; background-color: #f0f0f0;") + preview_widget.setText("Image Preview") + preview_hbox.addWidget(preview_widget) + elif is_video_input: + media_player = QMediaPlayer() + video_widget = QVideoWidget() + video_widget.setFixedSize(160, 90) + media_player.setVideoOutput(video_widget) + media_player.setLoops(QMediaPlayer.Loops.Infinite) + + self.widgets[f"{name}_player"] = media_player + + preview_widget = HoverVideoPreview(media_player, video_widget) + preview_hbox.addWidget(preview_widget) + else: + preview_widget = self.create_widget(QLabel, f"{name}_preview") + preview_widget.setText("No preview available") + preview_hbox.addWidget(preview_widget) + + preview_hbox.addStretch() + vbox.addWidget(preview_container) + preview_container.setVisible(False) + + def update_preview(path): + container = self.widgets.get(f"{name}_preview_container") + if not container: return + + first_path = path.split(';')[0] if path else '' + + if not first_path or not os.path.exists(first_path): + container.setVisible(False) + if is_video_input: + player = self.widgets.get(f"{name}_player") + if player: player.setSource(QUrl()) + return + + container.setVisible(True) + + if is_image_input: + preview_label = self.widgets.get(f"{name}_preview") + if preview_label: + pixmap = QPixmap(first_path) + if not pixmap.isNull(): + preview_label.setPixmap(pixmap.scaled(preview_label.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)) + else: + preview_label.setText("Invalid Image") + + elif is_video_input: + player = self.widgets.get(f"{name}_player") + if player: + player.setSource(QUrl.fromLocalFile(first_path)) + + else: # Audio or other + preview_label = self.widgets.get(f"{name}_preview") + if preview_label: + preview_label.setText(os.path.basename(path)) + + line_edit.textChanged.connect(update_preview) + + return container + + def setup_generator_tab(self): + gen_tab = QWidget() + self.tabs.addTab(gen_tab, "Video Generator") + gen_layout = QHBoxLayout(gen_tab) + left_panel = QWidget() + left_layout = QVBoxLayout(left_panel) + gen_layout.addWidget(left_panel, 1) + right_panel = QWidget() + right_layout = QVBoxLayout(right_panel) + gen_layout.addWidget(right_panel, 1) + 
+        scroll_area = QScrollArea()
+        scroll_area.setWidgetResizable(True)
+        left_layout.addWidget(scroll_area)
+        options_widget = QWidget()
+        scroll_area.setWidget(options_widget)
+        options_layout = QVBoxLayout(options_widget)
+        model_layout = QHBoxLayout()
+        self.widgets['model_family'] = QComboBox()
+        self.widgets['model_base_type_choice'] = QComboBox()
+        self.widgets['model_choice'] = QComboBox()
+        model_layout.addWidget(QLabel("Model:"))
+        model_layout.addWidget(self.widgets['model_family'], 2)
+        model_layout.addWidget(self.widgets['model_base_type_choice'], 3)
+        model_layout.addWidget(self.widgets['model_choice'], 3)
+        options_layout.addLayout(model_layout)
+        options_layout.addWidget(QLabel("Prompt:"))
+        self.create_widget(QTextEdit, 'prompt').setMinimumHeight(100)
+        options_layout.addWidget(self.widgets['prompt'])
+        options_layout.addWidget(QLabel("Negative Prompt:"))
+        self.create_widget(QTextEdit, 'negative_prompt').setMinimumHeight(60)
+        options_layout.addWidget(self.widgets['negative_prompt'])
+        basic_group = QGroupBox("Basic Options")
+        basic_layout = QFormLayout(basic_group)
+        res_container = QWidget()
+        res_hbox = QHBoxLayout(res_container)
+        res_hbox.setContentsMargins(0, 0, 0, 0)
+        res_hbox.addWidget(self.create_widget(QComboBox, 'resolution_group'), 2)
+        res_hbox.addWidget(self.create_widget(QComboBox, 'resolution'), 3)
+        basic_layout.addRow("Resolution:", res_container)
+        basic_layout.addRow("Video Length:", self._create_slider_with_label('video_length', 1, 737, 81, 1.0, 0))
+        basic_layout.addRow("Inference Steps:", self._create_slider_with_label('num_inference_steps', 1, 100, 30, 1.0, 0))
+        basic_layout.addRow("Seed:", self.create_widget(QLineEdit, 'seed', '-1'))
+        options_layout.addWidget(basic_group)
+        mode_options_group = QGroupBox("Generation Mode & Input Options")
+        mode_options_layout = QVBoxLayout(mode_options_group)
+        mode_hbox = QHBoxLayout()
+        mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_t', "Text Prompt Only"))
+        mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_s', "Start with Image"))
+        mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_v', "Continue Video"))
+        mode_hbox.addWidget(self.create_widget(QRadioButton, 'mode_l', "Continue Last Video"))
+        self.widgets['mode_t'].setChecked(True)
+        mode_options_layout.addLayout(mode_hbox)
+        options_hbox = QHBoxLayout()
+        options_hbox.addWidget(self.create_widget(QCheckBox, 'image_end_checkbox', "Use End Image"))
+        options_hbox.addWidget(self.create_widget(QCheckBox, 'control_video_checkbox', "Use Control Video"))
+        options_hbox.addWidget(self.create_widget(QCheckBox, 'ref_image_checkbox', "Use Reference Image(s)"))
+        mode_options_layout.addLayout(options_hbox)
+        options_layout.addWidget(mode_options_group)
+        inputs_group = QGroupBox("Inputs")
+        inputs_layout = QVBoxLayout(inputs_group)
+        inputs_layout.addWidget(self._create_file_input('image_start', "Start Image"))
+        inputs_layout.addWidget(self._create_file_input('image_end', "End Image"))
+        inputs_layout.addWidget(self._create_file_input('video_source', "Source Video"))
+        inputs_layout.addWidget(self._create_file_input('video_guide', "Control Video"))
+        inputs_layout.addWidget(self._create_file_input('video_mask', "Video Mask"))
+        inputs_layout.addWidget(self._create_file_input('image_refs', "Reference Image(s)"))
+        denoising_row = QFormLayout()
+        denoising_row.addRow("Denoising Strength:", self._create_slider_with_label('denoising_strength', 0, 100, 50, 100.0, 2))
+        inputs_layout.addLayout(denoising_row)
options_layout.addWidget(inputs_group) + self.advanced_group = self.create_widget(QGroupBox, 'advanced_group', "Advanced Options") + self.advanced_group.setCheckable(True) + self.advanced_group.setChecked(False) + advanced_layout = QVBoxLayout(self.advanced_group) + advanced_tabs = self.create_widget(QTabWidget, 'advanced_tabs') + advanced_layout.addWidget(advanced_tabs) + self._setup_adv_tab_general(advanced_tabs) + self._setup_adv_tab_loras(advanced_tabs) + self._setup_adv_tab_speed(advanced_tabs) + self._setup_adv_tab_postproc(advanced_tabs) + self._setup_adv_tab_audio(advanced_tabs) + self._setup_adv_tab_quality(advanced_tabs) + self._setup_adv_tab_sliding_window(advanced_tabs) + self._setup_adv_tab_misc(advanced_tabs) + options_layout.addWidget(self.advanced_group) + + btn_layout = QHBoxLayout() + self.generate_btn = self.create_widget(QPushButton, 'generate_btn', "Generate") + self.add_to_queue_btn = self.create_widget(QPushButton, 'add_to_queue_btn', "Add to Queue") + self.generate_btn.setEnabled(True) + self.add_to_queue_btn.setEnabled(False) + btn_layout.addWidget(self.generate_btn) + btn_layout.addWidget(self.add_to_queue_btn) + right_layout.addLayout(btn_layout) + self.status_label = self.create_widget(QLabel, 'status_label', "Idle") + right_layout.addWidget(self.status_label) + self.progress_bar = self.create_widget(QProgressBar, 'progress_bar') + right_layout.addWidget(self.progress_bar) + preview_group = self.create_widget(QGroupBox, 'preview_group', "Preview") + preview_group.setCheckable(True) + preview_group.setStyleSheet("QGroupBox { border: 1px solid #cccccc; }") + preview_group_layout = QVBoxLayout(preview_group) + self.preview_image = self.create_widget(QLabel, 'preview_image', "") + self.preview_image.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.preview_image.setMinimumSize(200, 200) + preview_group_layout.addWidget(self.preview_image) + right_layout.addWidget(preview_group) + + results_group = QGroupBox("Generated Clips") + results_group.setSizePolicy(QSizePolicy.Policy.Expanding, QSizePolicy.Policy.Expanding) + results_layout = QVBoxLayout(results_group) + self.results_list = self.create_widget(QListWidget, 'results_list') + self.results_list.setFlow(QListWidget.Flow.LeftToRight) + self.results_list.setWrapping(True) + self.results_list.setResizeMode(QListWidget.ResizeMode.Adjust) + self.results_list.setSpacing(10) + results_layout.addWidget(self.results_list) + right_layout.addWidget(results_group) + + right_layout.addWidget(QLabel("Queue:")) + self.queue_table = self.create_widget(QueueTableWidget, 'queue_table') + right_layout.addWidget(self.queue_table) + queue_btn_layout = QHBoxLayout() + self.remove_queue_btn = self.create_widget(QPushButton, 'remove_queue_btn', "Remove Selected") + self.clear_queue_btn = self.create_widget(QPushButton, 'clear_queue_btn', "Clear Queue") + self.abort_btn = self.create_widget(QPushButton, 'abort_btn', "Abort") + queue_btn_layout.addWidget(self.remove_queue_btn) + queue_btn_layout.addWidget(self.clear_queue_btn) + queue_btn_layout.addWidget(self.abort_btn) + right_layout.addLayout(queue_btn_layout) + + def _setup_adv_tab_general(self, tabs): + tab = QWidget() + tabs.addTab(tab, "General") + layout = QFormLayout(tab) + self.widgets['adv_general_layout'] = layout + guidance_group = QGroupBox("Guidance") + guidance_layout = self.create_widget(QFormLayout, 'guidance_layout', guidance_group) + guidance_layout.addRow("Guidance (CFG):", self._create_slider_with_label('guidance_scale', 10, 200, 5.0, 10.0, 1)) + 
self.widgets['guidance_phases_row_index'] = guidance_layout.rowCount() + guidance_layout.addRow("Guidance Phases:", self.create_widget(QComboBox, 'guidance_phases')) + self.widgets['guidance2_row_index'] = guidance_layout.rowCount() + guidance_layout.addRow("Guidance 2:", self._create_slider_with_label('guidance2_scale', 10, 200, 5.0, 10.0, 1)) + self.widgets['guidance3_row_index'] = guidance_layout.rowCount() + guidance_layout.addRow("Guidance 3:", self._create_slider_with_label('guidance3_scale', 10, 200, 5.0, 10.0, 1)) + self.widgets['switch_thresh_row_index'] = guidance_layout.rowCount() + guidance_layout.addRow("Switch Threshold:", self._create_slider_with_label('switch_threshold', 0, 1000, 0, 1.0, 0)) + layout.addRow(guidance_group) + nag_group = self.create_widget(QGroupBox, 'nag_group', "NAG (Negative Adversarial Guidance)") + nag_layout = QFormLayout(nag_group) + nag_layout.addRow("NAG Scale:", self._create_slider_with_label('NAG_scale', 10, 200, 1.0, 10.0, 1)) + nag_layout.addRow("NAG Tau:", self._create_slider_with_label('NAG_tau', 10, 50, 3.5, 10.0, 1)) + nag_layout.addRow("NAG Alpha:", self._create_slider_with_label('NAG_alpha', 0, 20, 0.5, 10.0, 1)) + layout.addRow(nag_group) + self.widgets['solver_row_container'] = QWidget() + solver_hbox = QHBoxLayout(self.widgets['solver_row_container']) + solver_hbox.setContentsMargins(0,0,0,0) + solver_hbox.addWidget(QLabel("Sampler Solver:")) + solver_hbox.addWidget(self.create_widget(QComboBox, 'sample_solver')) + layout.addRow(self.widgets['solver_row_container']) + self.widgets['flow_shift_row_index'] = layout.rowCount() + layout.addRow("Shift Scale:", self._create_slider_with_label('flow_shift', 10, 250, 3.0, 10.0, 1)) + self.widgets['audio_guidance_row_index'] = layout.rowCount() + layout.addRow("Audio Guidance:", self._create_slider_with_label('audio_guidance_scale', 10, 200, 4.0, 10.0, 1)) + self.widgets['repeat_generation_row_index'] = layout.rowCount() + layout.addRow("Repeat Generations:", self._create_slider_with_label('repeat_generation', 1, 25, 1, 1.0, 0)) + combo = self.create_widget(QComboBox, 'multi_images_gen_type') + combo.addItem("Generate all combinations", 0) + combo.addItem("Match images and texts", 1) + self.widgets['multi_images_gen_type_row_index'] = layout.rowCount() + layout.addRow("Multi-Image Mode:", combo) + + def _setup_adv_tab_loras(self, tabs): + tab = QWidget() + tabs.addTab(tab, "Loras") + layout = QVBoxLayout(tab) + layout.addWidget(QLabel("Available Loras (Ctrl+Click to select multiple):")) + lora_list = self.create_widget(QListWidget, 'activated_loras') + lora_list.setSelectionMode(QListWidget.SelectionMode.MultiSelection) + layout.addWidget(lora_list) + layout.addWidget(QLabel("Loras Multipliers:")) + layout.addWidget(self.create_widget(QTextEdit, 'loras_multipliers')) + + def _setup_adv_tab_speed(self, tabs): + tab = QWidget() + tabs.addTab(tab, "Speed") + layout = QFormLayout(tab) + combo = self.create_widget(QComboBox, 'skip_steps_cache_type') + combo.addItem("None", "") + combo.addItem("Tea Cache", "tea") + combo.addItem("Mag Cache", "mag") + layout.addRow("Cache Type:", combo) + combo = self.create_widget(QComboBox, 'skip_steps_multiplier') + combo.addItem("x1.5 speed up", 1.5) + combo.addItem("x1.75 speed up", 1.75) + combo.addItem("x2.0 speed up", 2.0) + combo.addItem("x2.25 speed up", 2.25) + combo.addItem("x2.5 speed up", 2.5) + layout.addRow("Acceleration:", combo) + layout.addRow("Start %:", self._create_slider_with_label('skip_steps_start_step_perc', 0, 100, 0, 1.0, 0)) + + def 
_setup_adv_tab_postproc(self, tabs): + tab = QWidget() + tabs.addTab(tab, "Post-Processing") + layout = QFormLayout(tab) + combo = self.create_widget(QComboBox, 'temporal_upsampling') + combo.addItem("Disabled", "") + combo.addItem("Rife x2 frames/s", "rife2") + combo.addItem("Rife x4 frames/s", "rife4") + layout.addRow("Temporal Upsampling:", combo) + combo = self.create_widget(QComboBox, 'spatial_upsampling') + combo.addItem("Disabled", "") + combo.addItem("Lanczos x1.5", "lanczos1.5") + combo.addItem("Lanczos x2.0", "lanczos2") + layout.addRow("Spatial Upsampling:", combo) + layout.addRow("Film Grain Intensity:", self._create_slider_with_label('film_grain_intensity', 0, 100, 0, 100.0, 2)) + layout.addRow("Film Grain Saturation:", self._create_slider_with_label('film_grain_saturation', 0, 100, 0.5, 100.0, 2)) + + def _setup_adv_tab_audio(self, tabs): + tab = QWidget() + tabs.addTab(tab, "Audio") + layout = QFormLayout(tab) + combo = self.create_widget(QComboBox, 'MMAudio_setting') + combo.addItem("Disabled", 0) + combo.addItem("Enabled", 1) + layout.addRow("MMAudio:", combo) + layout.addWidget(self.create_widget(QLineEdit, 'MMAudio_prompt', placeholderText="MMAudio Prompt")) + layout.addWidget(self.create_widget(QLineEdit, 'MMAudio_neg_prompt', placeholderText="MMAudio Negative Prompt")) + layout.addRow(self._create_file_input('audio_source', "Custom Soundtrack")) + + def _setup_adv_tab_quality(self, tabs): + tab = QWidget() + tabs.addTab(tab, "Quality") + layout = QVBoxLayout(tab) + slg_group = self.create_widget(QGroupBox, 'slg_group', "Skip Layer Guidance") + slg_layout = QFormLayout(slg_group) + slg_combo = self.create_widget(QComboBox, 'slg_switch') + slg_combo.addItem("OFF", 0) + slg_combo.addItem("ON", 1) + slg_layout.addRow("Enable SLG:", slg_combo) + slg_layout.addRow("Start %:", self._create_slider_with_label('slg_start_perc', 0, 100, 10, 1.0, 0)) + slg_layout.addRow("End %:", self._create_slider_with_label('slg_end_perc', 0, 100, 90, 1.0, 0)) + layout.addWidget(slg_group) + quality_form = QFormLayout() + self.widgets['quality_form_layout'] = quality_form + apg_combo = self.create_widget(QComboBox, 'apg_switch') + apg_combo.addItem("OFF", 0) + apg_combo.addItem("ON", 1) + self.widgets['apg_switch_row_index'] = quality_form.rowCount() + quality_form.addRow("Adaptive Projected Guidance:", apg_combo) + cfg_star_combo = self.create_widget(QComboBox, 'cfg_star_switch') + cfg_star_combo.addItem("OFF", 0) + cfg_star_combo.addItem("ON", 1) + self.widgets['cfg_star_switch_row_index'] = quality_form.rowCount() + quality_form.addRow("Classifier-Free Guidance Star:", cfg_star_combo) + self.widgets['cfg_zero_step_row_index'] = quality_form.rowCount() + quality_form.addRow("CFG Zero below Layer:", self._create_slider_with_label('cfg_zero_step', -1, 39, -1, 1.0, 0)) + combo = self.create_widget(QComboBox, 'min_frames_if_references') + combo.addItem("Disabled (1 frame)", 1) + combo.addItem("Generate 5 frames", 5) + combo.addItem("Generate 9 frames", 9) + combo.addItem("Generate 13 frames", 13) + combo.addItem("Generate 17 frames", 17) + self.widgets['min_frames_if_references_row_index'] = quality_form.rowCount() + quality_form.addRow("Min Frames for Quality:", combo) + layout.addLayout(quality_form) + + def _setup_adv_tab_sliding_window(self, tabs): + tab = QWidget() + self.widgets['sliding_window_tab_index'] = tabs.count() + tabs.addTab(tab, "Sliding Window") + layout = QFormLayout(tab) + layout.addRow("Window Size:", self._create_slider_with_label('sliding_window_size', 5, 257, 129, 1.0, 0)) 
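+ # The remaining rows tune how consecutive windows are blended: overlap length, colour correction strength, noise added in the overlap, and frames discarded at the end of each window.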
+ layout.addRow("Overlap:", self._create_slider_with_label('sliding_window_overlap', 1, 97, 5, 1.0, 0)) + layout.addRow("Color Correction:", self._create_slider_with_label('sliding_window_color_correction_strength', 0, 100, 0, 100.0, 2)) + layout.addRow("Overlap Noise:", self._create_slider_with_label('sliding_window_overlap_noise', 0, 150, 20, 1.0, 0)) + layout.addRow("Discard Last Frames:", self._create_slider_with_label('sliding_window_discard_last_frames', 0, 20, 0, 1.0, 0)) + + def _setup_adv_tab_misc(self, tabs): + tab = QWidget() + tabs.addTab(tab, "Misc") + layout = QFormLayout(tab) + self.widgets['misc_layout'] = layout + riflex_combo = self.create_widget(QComboBox, 'RIFLEx_setting') + riflex_combo.addItem("Auto", 0) + riflex_combo.addItem("Always ON", 1) + riflex_combo.addItem("Always OFF", 2) + self.widgets['riflex_row_index'] = layout.rowCount() + layout.addRow("RIFLEx Setting:", riflex_combo) + fps_combo = self.create_widget(QComboBox, 'force_fps') + layout.addRow("Force FPS:", fps_combo) + profile_combo = self.create_widget(QComboBox, 'override_profile') + profile_combo.addItem("Default Profile", -1) + for text, val in self.wgp.memory_profile_choices: profile_combo.addItem(text.split(':')[0], val) + layout.addRow("Override Memory Profile:", profile_combo) + combo = self.create_widget(QComboBox, 'multi_prompts_gen_type') + combo.addItem("Generate new Video per line", 0) + combo.addItem("Use line for new Sliding Window", 1) + layout.addRow("Multi-Prompt Mode:", combo) + + def setup_config_tab(self): + config_tab = QWidget() + self.tabs.addTab(config_tab, "Configuration") + main_layout = QVBoxLayout(config_tab) + self.config_status_label = QLabel("Apply changes for them to take effect. Some may require a restart.") + main_layout.addWidget(self.config_status_label) + config_tabs = QTabWidget() + main_layout.addWidget(config_tabs) + config_tabs.addTab(self._create_general_config_tab(), "General") + config_tabs.addTab(self._create_performance_config_tab(), "Performance") + config_tabs.addTab(self._create_extensions_config_tab(), "Extensions") + config_tabs.addTab(self._create_outputs_config_tab(), "Outputs") + config_tabs.addTab(self._create_notifications_config_tab(), "Notifications") + self.apply_config_btn = QPushButton("Apply Changes") + self.apply_config_btn.clicked.connect(self._on_apply_config_changes) + main_layout.addWidget(self.apply_config_btn) + + def _create_scrollable_form_tab(self): + tab_widget = QWidget() + scroll_area = QScrollArea() + scroll_area.setWidgetResizable(True) + layout = QVBoxLayout(tab_widget) + layout.addWidget(scroll_area) + content_widget = QWidget() + form_layout = QFormLayout(content_widget) + scroll_area.setWidget(content_widget) + return tab_widget, form_layout + + def _create_config_combo(self, form_layout, label, key, choices, default_value): + combo = QComboBox() + for text, data in choices: combo.addItem(text, data) + index = combo.findData(self.wgp.server_config.get(key, default_value)) + if index != -1: combo.setCurrentIndex(index) + self.widgets[f'config_{key}'] = combo + form_layout.addRow(label, combo) + + def _create_config_slider(self, form_layout, label, key, min_val, max_val, default_value, step=1): + container = QWidget() + hbox = QHBoxLayout(container) + hbox.setContentsMargins(0,0,0,0) + slider = QSlider(Qt.Orientation.Horizontal) + slider.setRange(min_val, max_val) + slider.setSingleStep(step) + slider.setValue(self.wgp.server_config.get(key, default_value)) + value_label = QLabel(str(slider.value())) + 
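+ # Give the value label a fixed minimum width so the form does not shift as values change, and keep its text in sync with the slider.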
value_label.setMinimumWidth(40) + slider.valueChanged.connect(lambda v, lbl=value_label: lbl.setText(str(v))) + hbox.addWidget(slider) + hbox.addWidget(value_label) + self.widgets[f'config_{key}'] = slider + form_layout.addRow(label, container) + + def _create_config_checklist(self, form_layout, label, key, choices, default_value): + list_widget = QListWidget() + list_widget.setMinimumHeight(100) + current_values = self.wgp.server_config.get(key, default_value) + for text, data in choices: + item = QListWidgetItem(text) + item.setData(Qt.ItemDataRole.UserRole, data) + item.setFlags(item.flags() | Qt.ItemFlag.ItemIsUserCheckable) + item.setCheckState(Qt.CheckState.Checked if data in current_values else Qt.CheckState.Unchecked) + list_widget.addItem(item) + self.widgets[f'config_{key}'] = list_widget + form_layout.addRow(label, list_widget) + + def _create_config_textbox(self, form_layout, label, key, default_value, multi_line=False): + if multi_line: + textbox = QTextEdit(default_value) + textbox.setAcceptRichText(False) + else: + textbox = QLineEdit(default_value) + self.widgets[f'config_{key}'] = textbox + form_layout.addRow(label, textbox) + + def _create_general_config_tab(self): + tab, form = self._create_scrollable_form_tab() + _, _, dropdown_choices = self.wgp.get_sorted_dropdown(self.wgp.displayed_model_types, None, None, False) + self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, self.wgp.transformer_types) + self._create_config_combo(form, "Model Hierarchy:", "model_hierarchy_type", [("Two Levels (Family > Model)", 0), ("Three Levels (Family > Base > Finetune)", 1)], 1) + self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1), ("Dimensions are Output Width/Height (Cropped)", 2)], 0) + self._create_config_combo(form, "Attention Type:", "attention_mode", [("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"), ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto") + self._create_config_combo(form, "Metadata Handling:", "metadata_type", [("Embed in file (Exif/Comment)", "metadata"), ("Export separate JSON", "json"), ("None", "none")], "metadata") + self._create_config_checklist(form, "RAM Loading Policy:", "preload_model_policy", [("Preload on App Launch", "P"), ("Preload on Model Switch", "S"), ("Unload when Queue is Done", "U")], []) + self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10), ("Keep last 20", 20), ("Keep last 30", 30)], 5) + self._create_config_combo(form, "Display RAM/VRAM Stats:", "display_stats", [("Disabled", 0), ("Enabled", 1)], 0) + self._create_config_combo(form, "Max Frames Multiplier:", "max_frames_multiplier", [(f"x{i}", i) for i in range(1, 8)], 1) + checkpoints_paths_text = "\n".join(self.wgp.server_config.get("checkpoints_paths", self.wgp.fl.default_checkpoints_paths)) + checkpoints_textbox = QTextEdit() + checkpoints_textbox.setPlainText(checkpoints_paths_text) + checkpoints_textbox.setAcceptRichText(False) + checkpoints_textbox.setMinimumHeight(60) + self.widgets['config_checkpoints_paths'] = checkpoints_textbox + form.addRow("Checkpoints Paths:", checkpoints_textbox) + self._create_config_combo(form, "UI Theme (requires restart):", "UI_theme", [("Blue Sky", "default"), ("Classic Gradio", "gradio")], "default") + return tab + + def _create_performance_config_tab(self): + tab, form = 
self._create_scrollable_form_tab() + self._create_config_combo(form, "Transformer Quantization:", "transformer_quantization", [("Scaled Int8 (recommended)", "int8"), ("16-bit (no quantization)", "bf16")], "int8") + self._create_config_combo(form, "Transformer Data Type:", "transformer_dtype_policy", [("Best Supported by Hardware", ""), ("FP16", "fp16"), ("BF16", "bf16")], "") + self._create_config_combo(form, "Transformer Calculation:", "mixed_precision", [("16-bit only", "0"), ("Mixed 16/32-bit (better quality)", "1")], "0") + self._create_config_combo(form, "Text Encoder:", "text_encoder_quantization", [("16-bit (more RAM, better quality)", "bf16"), ("8-bit (less RAM)", "int8")], "int8") + self._create_config_combo(form, "VAE Precision:", "vae_precision", [("16-bit (faster, less VRAM)", "16"), ("32-bit (slower, better quality)", "32")], "16") + self._create_config_combo(form, "Compile Transformer:", "compile", [("On (requires Triton)", "transformer"), ("Off", "")], "") + self._create_config_combo(form, "DepthAnything v2 Variant:", "depth_anything_v2_variant", [("Large (more precise)", "vitl"), ("Big (faster)", "vitb")], "vitl") + self._create_config_combo(form, "VAE Tiling:", "vae_config", [("Auto", 0), ("Disabled", 1), ("256x256 (~8GB VRAM)", 2), ("128x128 (~6GB VRAM)", 3)], 0) + self._create_config_combo(form, "Boost:", "boost", [("On", 1), ("Off", 2)], 1) + self._create_config_combo(form, "Memory Profile:", "profile", self.wgp.memory_profile_choices, self.wgp.profile_type.LowRAM_LowVRAM) + self._create_config_slider(form, "Preload in VRAM (MB):", "preload_in_VRAM", 0, 40000, 0, 100) + release_ram_btn = QPushButton("Force Release Models from RAM") + release_ram_btn.clicked.connect(self._on_release_ram) + form.addRow(release_ram_btn) + return tab + + def _create_extensions_config_tab(self): + tab, form = self._create_scrollable_form_tab() + self._create_config_combo(form, "Prompt Enhancer:", "enhancer_enabled", [("Off", 0), ("Florence 2 + Llama 3.2", 1), ("Florence 2 + Joy Caption (uncensored)", 2)], 0) + self._create_config_combo(form, "Enhancer Mode:", "enhancer_mode", [("Automatic on Generate", 0), ("On Demand Only", 1)], 0) + self._create_config_combo(form, "MMAudio:", "mmaudio_enabled", [("Off", 0), ("Enabled (unloaded after use)", 1), ("Enabled (persistent in RAM)", 2)], 0) + return tab + + def _create_outputs_config_tab(self): + tab, form = self._create_scrollable_form_tab() + self._create_config_combo(form, "Video Codec:", "video_output_codec", [("x265 Balanced", 'libx265_28'), ("x264 Balanced", 'libx264_8'), ("x265 High Quality", 'libx265_8'), ("x264 High Quality", 'libx264_10'), ("x264 Lossless", 'libx264_lossless')], 'libx264_8') + self._create_config_combo(form, "Image Codec:", "image_output_codec", [("JPEG Q85", 'jpeg_85'), ("WEBP Q85", 'webp_85'), ("JPEG Q95", 'jpeg_95'), ("WEBP Q95", 'webp_95'), ("WEBP Lossless", 'webp_lossless'), ("PNG Lossless", 'png')], 'jpeg_95') + self._create_config_textbox(form, "Video Output Folder:", "save_path", "outputs") + self._create_config_textbox(form, "Image Output Folder:", "image_save_path", "outputs") + return tab + + def _create_notifications_config_tab(self): + tab, form = self._create_scrollable_form_tab() + self._create_config_combo(form, "Notification Sound:", "notification_sound_enabled", [("On", 1), ("Off", 0)], 0) + self._create_config_slider(form, "Sound Volume:", "notification_sound_volume", 0, 100, 50, 5) + return tab + + def init_wgp_state(self): + wgp = self.wgp + initial_model = wgp.server_config.get("last_model_type", 
wgp.transformer_type) + dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types + _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False) + all_model_ids = [m[1] for m in all_models] + if initial_model not in all_model_ids: initial_model = wgp.transformer_type + state_dict = {} + state_dict["model_filename"] = wgp.get_model_filename(initial_model, wgp.transformer_quantization, wgp.transformer_dtype_policy) + state_dict["model_type"] = initial_model + state_dict["advanced"] = wgp.advanced + state_dict["last_model_per_family"] = wgp.server_config.get("last_model_per_family", {}) + state_dict["last_model_per_type"] = wgp.server_config.get("last_model_per_type", {}) + state_dict["last_resolution_per_group"] = wgp.server_config.get("last_resolution_per_group", {}) + state_dict["gen"] = {"queue": []} + self.state = state_dict + self.advanced_group.setChecked(wgp.advanced) + self.update_model_dropdowns(initial_model) + self.refresh_ui_from_model_change(initial_model) + self._update_input_visibility() + + def update_model_dropdowns(self, current_model_type): + wgp = self.wgp + family_mock, base_type_mock, choice_mock = wgp.generate_dropdown_model_list(current_model_type) + for combo_name, mock in [('model_family', family_mock), ('model_base_type_choice', base_type_mock), ('model_choice', choice_mock)]: + combo = self.widgets[combo_name] + combo.blockSignals(True) + combo.clear() + if mock.choices: + for display_name, internal_key in mock.choices: combo.addItem(display_name, internal_key) + index = combo.findData(mock.value) + if index != -1: combo.setCurrentIndex(index) + + is_visible = True + if hasattr(mock, 'kwargs') and isinstance(mock.kwargs, dict): + is_visible = mock.kwargs.get('visible', True) + elif hasattr(mock, 'visible'): + is_visible = mock.visible + combo.setVisible(is_visible) + + combo.blockSignals(False) + + def refresh_ui_from_model_change(self, model_type): + """Update UI controls with default settings when the model is changed.""" + wgp = self.wgp + self.header_info.setText(wgp.generate_header(model_type, wgp.compile, wgp.attention_mode)) + ui_defaults = wgp.get_default_settings(model_type) + wgp.set_model_settings(self.state, model_type, ui_defaults) + + model_def = wgp.get_model_def(model_type) + base_model_type = wgp.get_base_model_type(model_type) + model_filename = self.state.get('model_filename', '') + + image_outputs = model_def.get("image_outputs", False) + vace = wgp.test_vace_module(model_type) + t2v = base_model_type in ['t2v', 't2v_2_2'] + i2v = wgp.test_class_i2v(model_type) + fantasy = base_model_type in ["fantasy"] + multitalk = model_def.get("multitalk_class", False) + any_audio_guidance = fantasy or multitalk + sliding_window_enabled = wgp.test_any_sliding_window(model_type) + recammaster = base_model_type in ["recam_1.3B"] + ltxv = "ltxv" in model_filename + diffusion_forcing = "diffusion_forcing" in model_filename + any_skip_layer_guidance = model_def.get("skip_layer_guidance", False) + any_cfg_zero = model_def.get("cfg_zero", False) + any_cfg_star = model_def.get("cfg_star", False) + any_apg = model_def.get("adaptive_projected_guidance", False) + v2i_switch_supported = model_def.get("v2i_switch_supported", False) + + self._update_generation_mode_visibility(model_def) + + for widget in self.widgets.values(): + if hasattr(widget, 'blockSignals'): widget.blockSignals(True) + + self.widgets['prompt'].setText(ui_defaults.get("prompt", "")) + 
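+ # Populate the rest of the form from ui_defaults while signals are still blocked, so these updates do not fire the change handlers wired up in connect_signals().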
self.widgets['negative_prompt'].setText(ui_defaults.get("negative_prompt", "")) + self.widgets['seed'].setText(str(ui_defaults.get("seed", -1))) + + video_length_val = ui_defaults.get("video_length", 81) + self.widgets['video_length'].setValue(video_length_val) + self.widgets['video_length_label'].setText(str(video_length_val)) + + steps_val = ui_defaults.get("num_inference_steps", 30) + self.widgets['num_inference_steps'].setValue(steps_val) + self.widgets['num_inference_steps_label'].setText(str(steps_val)) + + self.widgets['resolution_group'].blockSignals(True) + self.widgets['resolution'].blockSignals(True) + + current_res_choice = ui_defaults.get("resolution") + model_resolutions = model_def.get("resolutions", None) + self.full_resolution_choices, current_res_choice = wgp.get_resolution_choices(current_res_choice, model_resolutions) + available_groups, selected_group_resolutions, selected_group = wgp.group_resolutions(model_def, self.full_resolution_choices, current_res_choice) + + self.widgets['resolution_group'].clear() + self.widgets['resolution_group'].addItems(available_groups) + group_index = self.widgets['resolution_group'].findText(selected_group) + if group_index != -1: + self.widgets['resolution_group'].setCurrentIndex(group_index) + + self.widgets['resolution'].clear() + for label, value in selected_group_resolutions: + self.widgets['resolution'].addItem(label, value) + res_index = self.widgets['resolution'].findData(current_res_choice) + if res_index != -1: + self.widgets['resolution'].setCurrentIndex(res_index) + + self.widgets['resolution_group'].blockSignals(False) + self.widgets['resolution'].blockSignals(False) + + for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'image_refs', 'audio_source']: + if name in self.widgets: self.widgets[name].clear() + + guidance_layout = self.widgets['guidance_layout'] + guidance_max = model_def.get("guidance_max_phases", 1) + guidance_layout.setRowVisible(self.widgets['guidance_phases_row_index'], guidance_max > 1) + + adv_general_layout = self.widgets['adv_general_layout'] + adv_general_layout.setRowVisible(self.widgets['flow_shift_row_index'], not image_outputs) + adv_general_layout.setRowVisible(self.widgets['audio_guidance_row_index'], any_audio_guidance) + adv_general_layout.setRowVisible(self.widgets['repeat_generation_row_index'], not image_outputs) + adv_general_layout.setRowVisible(self.widgets['multi_images_gen_type_row_index'], i2v) + + self.widgets['slg_group'].setVisible(any_skip_layer_guidance) + quality_form_layout = self.widgets['quality_form_layout'] + quality_form_layout.setRowVisible(self.widgets['apg_switch_row_index'], any_apg) + quality_form_layout.setRowVisible(self.widgets['cfg_star_switch_row_index'], any_cfg_star) + quality_form_layout.setRowVisible(self.widgets['cfg_zero_step_row_index'], any_cfg_zero) + quality_form_layout.setRowVisible(self.widgets['min_frames_if_references_row_index'], v2i_switch_supported and image_outputs) + + self.widgets['advanced_tabs'].setTabVisible(self.widgets['sliding_window_tab_index'], sliding_window_enabled and not image_outputs) + + misc_layout = self.widgets['misc_layout'] + misc_layout.setRowVisible(self.widgets['riflex_row_index'], not (recammaster or ltxv or diffusion_forcing)) + + index = self.widgets['multi_images_gen_type'].findData(ui_defaults.get('multi_images_gen_type', 0)) + if index != -1: self.widgets['multi_images_gen_type'].setCurrentIndex(index) + + guidance_val = ui_defaults.get("guidance_scale", 5.0) + 
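+ # Float-valued settings live in integer sliders scaled by 10 (or 100); the companion *_label widget displays the real value.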
self.widgets['guidance_scale'].setValue(int(guidance_val * 10)) + self.widgets['guidance_scale_label'].setText(f"{guidance_val:.1f}") + + guidance2_val = ui_defaults.get("guidance2_scale", 5.0) + self.widgets['guidance2_scale'].setValue(int(guidance2_val * 10)) + self.widgets['guidance2_scale_label'].setText(f"{guidance2_val:.1f}") + + guidance3_val = ui_defaults.get("guidance3_scale", 5.0) + self.widgets['guidance3_scale'].setValue(int(guidance3_val * 10)) + self.widgets['guidance3_scale_label'].setText(f"{guidance3_val:.1f}") + + self.widgets['guidance_phases'].clear() + if guidance_max >= 1: self.widgets['guidance_phases'].addItem("One Phase", 1) + if guidance_max >= 2: self.widgets['guidance_phases'].addItem("Two Phases", 2) + if guidance_max >= 3: self.widgets['guidance_phases'].addItem("Three Phases", 3) + index = self.widgets['guidance_phases'].findData(ui_defaults.get("guidance_phases", 1)) + if index != -1: self.widgets['guidance_phases'].setCurrentIndex(index) + + switch_thresh_val = ui_defaults.get("switch_threshold", 0) + self.widgets['switch_threshold'].setValue(switch_thresh_val) + self.widgets['switch_threshold_label'].setText(str(switch_thresh_val)) + + nag_scale_val = ui_defaults.get('NAG_scale', 1.0) + self.widgets['NAG_scale'].setValue(int(nag_scale_val * 10)) + self.widgets['NAG_scale_label'].setText(f"{nag_scale_val:.1f}") + + nag_tau_val = ui_defaults.get('NAG_tau', 3.5) + self.widgets['NAG_tau'].setValue(int(nag_tau_val * 10)) + self.widgets['NAG_tau_label'].setText(f"{nag_tau_val:.1f}") + + nag_alpha_val = ui_defaults.get('NAG_alpha', 0.5) + self.widgets['NAG_alpha'].setValue(int(nag_alpha_val * 10)) + self.widgets['NAG_alpha_label'].setText(f"{nag_alpha_val:.1f}") + + self.widgets['nag_group'].setVisible(vace or t2v or i2v) + + self.widgets['sample_solver'].clear() + sampler_choices = model_def.get("sample_solvers", []) + self.widgets['solver_row_container'].setVisible(bool(sampler_choices)) + if sampler_choices: + for label, value in sampler_choices: self.widgets['sample_solver'].addItem(label, value) + solver_val = ui_defaults.get('sample_solver', sampler_choices[0][1]) + index = self.widgets['sample_solver'].findData(solver_val) + if index != -1: self.widgets['sample_solver'].setCurrentIndex(index) + + flow_val = ui_defaults.get("flow_shift", 3.0) + self.widgets['flow_shift'].setValue(int(flow_val * 10)) + self.widgets['flow_shift_label'].setText(f"{flow_val:.1f}") + + audio_guidance_val = ui_defaults.get("audio_guidance_scale", 4.0) + self.widgets['audio_guidance_scale'].setValue(int(audio_guidance_val * 10)) + self.widgets['audio_guidance_scale_label'].setText(f"{audio_guidance_val:.1f}") + + repeat_val = ui_defaults.get("repeat_generation", 1) + self.widgets['repeat_generation'].setValue(repeat_val) + self.widgets['repeat_generation_label'].setText(str(repeat_val)) + + available_loras, _, _, _, _, _ = wgp.setup_loras(model_type, None, wgp.get_lora_dir(model_type), "") + self.state['loras'] = available_loras + self.lora_map = {os.path.basename(p): p for p in available_loras} + lora_list_widget = self.widgets['activated_loras'] + lora_list_widget.clear() + lora_list_widget.addItems(sorted(self.lora_map.keys())) + selected_loras = ui_defaults.get('activated_loras', []) + for i in range(lora_list_widget.count()): + item = lora_list_widget.item(i) + if any(item.text() == os.path.basename(p) for p in selected_loras): item.setSelected(True) + self.widgets['loras_multipliers'].setText(ui_defaults.get('loras_multipliers', '')) + + skip_cache_val = 
ui_defaults.get('skip_steps_cache_type', "") + index = self.widgets['skip_steps_cache_type'].findData(skip_cache_val) + if index != -1: self.widgets['skip_steps_cache_type'].setCurrentIndex(index) + + skip_mult = ui_defaults.get('skip_steps_multiplier', 1.5) + index = self.widgets['skip_steps_multiplier'].findData(skip_mult) + if index != -1: self.widgets['skip_steps_multiplier'].setCurrentIndex(index) + + skip_perc_val = ui_defaults.get('skip_steps_start_step_perc', 0) + self.widgets['skip_steps_start_step_perc'].setValue(skip_perc_val) + self.widgets['skip_steps_start_step_perc_label'].setText(str(skip_perc_val)) + + temp_up_val = ui_defaults.get('temporal_upsampling', "") + index = self.widgets['temporal_upsampling'].findData(temp_up_val) + if index != -1: self.widgets['temporal_upsampling'].setCurrentIndex(index) + + spat_up_val = ui_defaults.get('spatial_upsampling', "") + index = self.widgets['spatial_upsampling'].findData(spat_up_val) + if index != -1: self.widgets['spatial_upsampling'].setCurrentIndex(index) + + film_grain_i = ui_defaults.get('film_grain_intensity', 0) + self.widgets['film_grain_intensity'].setValue(int(film_grain_i * 100)) + self.widgets['film_grain_intensity_label'].setText(f"{film_grain_i:.2f}") + + film_grain_s = ui_defaults.get('film_grain_saturation', 0.5) + self.widgets['film_grain_saturation'].setValue(int(film_grain_s * 100)) + self.widgets['film_grain_saturation_label'].setText(f"{film_grain_s:.2f}") + + self.widgets['MMAudio_setting'].setCurrentIndex(ui_defaults.get('MMAudio_setting', 0)) + self.widgets['MMAudio_prompt'].setText(ui_defaults.get('MMAudio_prompt', '')) + self.widgets['MMAudio_neg_prompt'].setText(ui_defaults.get('MMAudio_neg_prompt', '')) + + self.widgets['slg_switch'].setCurrentIndex(ui_defaults.get('slg_switch', 0)) + slg_start_val = ui_defaults.get('slg_start_perc', 10) + self.widgets['slg_start_perc'].setValue(slg_start_val) + self.widgets['slg_start_perc_label'].setText(str(slg_start_val)) + slg_end_val = ui_defaults.get('slg_end_perc', 90) + self.widgets['slg_end_perc'].setValue(slg_end_val) + self.widgets['slg_end_perc_label'].setText(str(slg_end_val)) + + self.widgets['apg_switch'].setCurrentIndex(ui_defaults.get('apg_switch', 0)) + self.widgets['cfg_star_switch'].setCurrentIndex(ui_defaults.get('cfg_star_switch', 0)) + + cfg_zero_val = ui_defaults.get('cfg_zero_step', -1) + self.widgets['cfg_zero_step'].setValue(cfg_zero_val) + self.widgets['cfg_zero_step_label'].setText(str(cfg_zero_val)) + + min_frames_val = ui_defaults.get('min_frames_if_references', 1) + index = self.widgets['min_frames_if_references'].findData(min_frames_val) + if index != -1: self.widgets['min_frames_if_references'].setCurrentIndex(index) + + self.widgets['RIFLEx_setting'].setCurrentIndex(ui_defaults.get('RIFLEx_setting', 0)) + + fps = wgp.get_model_fps(model_type) + force_fps_choices = [ + (f"Model Default ({fps} fps)", ""), ("Auto", "auto"), ("Control Video fps", "control"), + ("Source Video fps", "source"), ("15", "15"), ("16", "16"), ("23", "23"), + ("24", "24"), ("25", "25"), ("30", "30") + ] + self.widgets['force_fps'].clear() + for label, value in force_fps_choices: self.widgets['force_fps'].addItem(label, value) + force_fps_val = ui_defaults.get('force_fps', "") + index = self.widgets['force_fps'].findData(force_fps_val) + if index != -1: self.widgets['force_fps'].setCurrentIndex(index) + + override_prof_val = ui_defaults.get('override_profile', -1) + index = self.widgets['override_profile'].findData(override_prof_val) + if index != -1: 
self.widgets['override_profile'].setCurrentIndex(index) + + self.widgets['multi_prompts_gen_type'].setCurrentIndex(ui_defaults.get('multi_prompts_gen_type', 0)) + + denoising_val = ui_defaults.get("denoising_strength", 0.5) + self.widgets['denoising_strength'].setValue(int(denoising_val * 100)) + self.widgets['denoising_strength_label'].setText(f"{denoising_val:.2f}") + + sw_size = ui_defaults.get("sliding_window_size", 129) + self.widgets['sliding_window_size'].setValue(sw_size) + self.widgets['sliding_window_size_label'].setText(str(sw_size)) + + sw_overlap = ui_defaults.get("sliding_window_overlap", 5) + self.widgets['sliding_window_overlap'].setValue(sw_overlap) + self.widgets['sliding_window_overlap_label'].setText(str(sw_overlap)) + + sw_color = ui_defaults.get("sliding_window_color_correction_strength", 0) + self.widgets['sliding_window_color_correction_strength'].setValue(int(sw_color * 100)) + self.widgets['sliding_window_color_correction_strength_label'].setText(f"{sw_color:.2f}") + + sw_noise = ui_defaults.get("sliding_window_overlap_noise", 20) + self.widgets['sliding_window_overlap_noise'].setValue(sw_noise) + self.widgets['sliding_window_overlap_noise_label'].setText(str(sw_noise)) + + sw_discard = ui_defaults.get("sliding_window_discard_last_frames", 0) + self.widgets['sliding_window_discard_last_frames'].setValue(sw_discard) + self.widgets['sliding_window_discard_last_frames_label'].setText(str(sw_discard)) + + for widget in self.widgets.values(): + if hasattr(widget, 'blockSignals'): widget.blockSignals(False) + + self._update_dynamic_ui() + self._update_input_visibility() + + def _update_dynamic_ui(self): + phases = self.widgets['guidance_phases'].currentData() or 1 + guidance_layout = self.widgets['guidance_layout'] + guidance_layout.setRowVisible(self.widgets['guidance2_row_index'], phases >= 2) + guidance_layout.setRowVisible(self.widgets['guidance3_row_index'], phases >= 3) + guidance_layout.setRowVisible(self.widgets['switch_thresh_row_index'], phases >= 2) + + def _update_generation_mode_visibility(self, model_def): + allowed = model_def.get("image_prompt_types_allowed", "") + choices = [] + if "T" in allowed or not allowed: choices.append(("Text Prompt Only" if "S" in allowed else "New Video", "T")) + if "S" in allowed: choices.append(("Start Video with Image", "S")) + if "V" in allowed: choices.append(("Continue Video", "V")) + if "L" in allowed: choices.append(("Continue Last Video", "L")) + button_map = { "T": self.widgets['mode_t'], "S": self.widgets['mode_s'], "V": self.widgets['mode_v'], "L": self.widgets['mode_l'] } + for btn in button_map.values(): btn.setVisible(False) + allowed_values = [c[1] for c in choices] + for label, value in choices: + if value in button_map: + btn = button_map[value] + btn.setText(label) + btn.setVisible(True) + current_checked_value = next((value for value, btn in button_map.items() if btn.isChecked()), None) + if current_checked_value is None or not button_map[current_checked_value].isVisible(): + if allowed_values: button_map[allowed_values[0]].setChecked(True) + end_image_visible = "E" in allowed + self.widgets['image_end_checkbox'].setVisible(end_image_visible) + if not end_image_visible: self.widgets['image_end_checkbox'].setChecked(False) + control_video_visible = model_def.get("guide_preprocessing") is not None + self.widgets['control_video_checkbox'].setVisible(control_video_visible) + if not control_video_visible: self.widgets['control_video_checkbox'].setChecked(False) + ref_image_visible = 
model_def.get("image_ref_choices") is not None + self.widgets['ref_image_checkbox'].setVisible(ref_image_visible) + if not ref_image_visible: self.widgets['ref_image_checkbox'].setChecked(False) + + def _update_input_visibility(self): + is_s_mode = self.widgets['mode_s'].isChecked() + is_v_mode = self.widgets['mode_v'].isChecked() + is_l_mode = self.widgets['mode_l'].isChecked() + use_end = self.widgets['image_end_checkbox'].isChecked() and self.widgets['image_end_checkbox'].isVisible() + use_control = self.widgets['control_video_checkbox'].isChecked() and self.widgets['control_video_checkbox'].isVisible() + use_ref = self.widgets['ref_image_checkbox'].isChecked() and self.widgets['ref_image_checkbox'].isVisible() + self.widgets['image_start_container'].setVisible(is_s_mode) + self.widgets['video_source_container'].setVisible(is_v_mode) + end_checkbox_enabled = is_s_mode or is_v_mode or is_l_mode + self.widgets['image_end_checkbox'].setEnabled(end_checkbox_enabled) + self.widgets['image_end_container'].setVisible(use_end and end_checkbox_enabled) + self.widgets['video_guide_container'].setVisible(use_control) + self.widgets['video_mask_container'].setVisible(use_control) + self.widgets['image_refs_container'].setVisible(use_ref) + + def connect_signals(self): + self.widgets['model_family'].currentIndexChanged.connect(self._on_family_changed) + self.widgets['model_base_type_choice'].currentIndexChanged.connect(self._on_base_type_changed) + self.widgets['model_choice'].currentIndexChanged.connect(self._on_model_changed) + self.widgets['resolution_group'].currentIndexChanged.connect(self._on_resolution_group_changed) + self.widgets['guidance_phases'].currentIndexChanged.connect(self._update_dynamic_ui) + self.widgets['mode_t'].toggled.connect(self._update_input_visibility) + self.widgets['mode_s'].toggled.connect(self._update_input_visibility) + self.widgets['mode_v'].toggled.connect(self._update_input_visibility) + self.widgets['mode_l'].toggled.connect(self._update_input_visibility) + self.widgets['image_end_checkbox'].toggled.connect(self._update_input_visibility) + self.widgets['control_video_checkbox'].toggled.connect(self._update_input_visibility) + self.widgets['ref_image_checkbox'].toggled.connect(self._update_input_visibility) + self.widgets['preview_group'].toggled.connect(self._on_preview_toggled) + self.generate_btn.clicked.connect(self._on_generate) + self.add_to_queue_btn.clicked.connect(self._on_add_to_queue) + self.remove_queue_btn.clicked.connect(self._on_remove_selected_from_queue) + self.clear_queue_btn.clicked.connect(self._on_clear_queue) + self.abort_btn.clicked.connect(self._on_abort) + self.queue_table.rowsMoved.connect(self._on_queue_rows_moved) + + def load_main_config(self): + try: + with open('main_config.json', 'r') as f: self.main_config = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): self.main_config = {'preview_visible': False} + + def save_main_config(self): + try: + with open('main_config.json', 'w') as f: json.dump(self.main_config, f, indent=4) + except Exception as e: print(f"Error saving main_config.json: {e}") + + def apply_initial_config(self): + is_visible = self.main_config.get('preview_visible', True) + self.widgets['preview_group'].setChecked(is_visible) + self.widgets['preview_image'].setVisible(is_visible) + + def _on_preview_toggled(self, checked): + self.widgets['preview_image'].setVisible(checked) + self.main_config['preview_visible'] = checked + self.save_main_config() + + def _on_family_changed(self): + family = 
self.widgets['model_family'].currentData() + if not family or not self.state: return + base_type_mock, choice_mock = self.wgp.change_model_family(self.state, family) + + if hasattr(base_type_mock, 'kwargs') and isinstance(base_type_mock.kwargs, dict): + is_visible_base = base_type_mock.kwargs.get('visible', True) + elif hasattr(base_type_mock, 'visible'): + is_visible_base = base_type_mock.visible + else: + is_visible_base = True + + self.widgets['model_base_type_choice'].blockSignals(True) + self.widgets['model_base_type_choice'].clear() + if base_type_mock.choices: + for label, value in base_type_mock.choices: self.widgets['model_base_type_choice'].addItem(label, value) + self.widgets['model_base_type_choice'].setCurrentIndex(self.widgets['model_base_type_choice'].findData(base_type_mock.value)) + self.widgets['model_base_type_choice'].setVisible(is_visible_base) + self.widgets['model_base_type_choice'].blockSignals(False) + + if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict): + is_visible_choice = choice_mock.kwargs.get('visible', True) + elif hasattr(choice_mock, 'visible'): + is_visible_choice = choice_mock.visible + else: + is_visible_choice = True + + self.widgets['model_choice'].blockSignals(True) + self.widgets['model_choice'].clear() + if choice_mock.choices: + for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value) + self.widgets['model_choice'].setCurrentIndex(self.widgets['model_choice'].findData(choice_mock.value)) + self.widgets['model_choice'].setVisible(is_visible_choice) + self.widgets['model_choice'].blockSignals(False) + + self._on_model_changed() + + def _on_base_type_changed(self): + family = self.widgets['model_family'].currentData() + base_type = self.widgets['model_base_type_choice'].currentData() + if not family or not base_type or not self.state: return + base_type_mock, choice_mock = self.wgp.change_model_base_types(self.state, family, base_type) + + if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict): + is_visible_choice = choice_mock.kwargs.get('visible', True) + elif hasattr(choice_mock, 'visible'): + is_visible_choice = choice_mock.visible + else: + is_visible_choice = True + + self.widgets['model_choice'].blockSignals(True) + self.widgets['model_choice'].clear() + if choice_mock.choices: + for label, value in choice_mock.choices: self.widgets['model_choice'].addItem(label, value) + self.widgets['model_choice'].setCurrentIndex(self.widgets['model_choice'].findData(choice_mock.value)) + self.widgets['model_choice'].setVisible(is_visible_choice) + self.widgets['model_choice'].blockSignals(False) + self._on_model_changed() + + def _on_model_changed(self): + model_type = self.widgets['model_choice'].currentData() + if not model_type or model_type == self.state.get('model_type'): return + self.wgp.change_model(self.state, model_type) + self.refresh_ui_from_model_change(model_type) + + def _on_resolution_group_changed(self): + selected_group = self.widgets['resolution_group'].currentText() + if not selected_group or not hasattr(self, 'full_resolution_choices'): return + model_type = self.state['model_type'] + model_def = self.wgp.get_model_def(model_type) + model_resolutions = model_def.get("resolutions", None) + group_resolution_choices = [] + if model_resolutions is None: + group_resolution_choices = [res for res in self.full_resolution_choices if self.wgp.categorize_resolution(res[1]) == selected_group] + else: return + last_resolution = 
self.state.get("last_resolution_per_group", {}).get(selected_group, "") + if not any(last_resolution == res[1] for res in group_resolution_choices) and group_resolution_choices: + last_resolution = group_resolution_choices[0][1] + self.widgets['resolution'].blockSignals(True) + self.widgets['resolution'].clear() + for label, value in group_resolution_choices: self.widgets['resolution'].addItem(label, value) + self.widgets['resolution'].setCurrentIndex(self.widgets['resolution'].findData(last_resolution)) + self.widgets['resolution'].blockSignals(False) + + def collect_inputs(self): + full_inputs = self.wgp.get_current_model_settings(self.state).copy() + full_inputs['lset_name'] = "" + full_inputs['image_mode'] = 0 + expected_keys = { "audio_guide": None, "audio_guide2": None, "image_guide": None, "image_mask": None, "speakers_locations": "", "frames_positions": "", "keep_frames_video_guide": "", "keep_frames_video_source": "", "video_guide_outpainting": "", "switch_threshold2": 0, "model_switch_phase": 1, "batch_size": 1, "control_net_weight_alt": 1.0, "image_refs_relative_size": 50, } + for key, default_value in expected_keys.items(): + if key not in full_inputs: full_inputs[key] = default_value + full_inputs['prompt'] = self.widgets['prompt'].toPlainText() + full_inputs['negative_prompt'] = self.widgets['negative_prompt'].toPlainText() + full_inputs['resolution'] = self.widgets['resolution'].currentData() + full_inputs['video_length'] = self.widgets['video_length'].value() + full_inputs['num_inference_steps'] = self.widgets['num_inference_steps'].value() + full_inputs['seed'] = int(self.widgets['seed'].text()) + image_prompt_type = "" + video_prompt_type = "" + if self.widgets['mode_s'].isChecked(): image_prompt_type = 'S' + elif self.widgets['mode_v'].isChecked(): image_prompt_type = 'V' + elif self.widgets['mode_l'].isChecked(): image_prompt_type = 'L' + if self.widgets['image_end_checkbox'].isVisible() and self.widgets['image_end_checkbox'].isChecked(): image_prompt_type += 'E' + if self.widgets['control_video_checkbox'].isVisible() and self.widgets['control_video_checkbox'].isChecked(): video_prompt_type += 'V' + if self.widgets['ref_image_checkbox'].isVisible() and self.widgets['ref_image_checkbox'].isChecked(): video_prompt_type += 'I' + full_inputs['image_prompt_type'] = image_prompt_type + full_inputs['video_prompt_type'] = video_prompt_type + for name in ['video_source', 'image_start', 'image_end', 'video_guide', 'video_mask', 'audio_source']: + if name in self.widgets: full_inputs[name] = self.widgets[name].text() or None + paths = self.widgets['image_refs'].text().split(';') + full_inputs['image_refs'] = [p.strip() for p in paths if p.strip()] if paths and paths[0] else None + full_inputs['denoising_strength'] = self.widgets['denoising_strength'].value() / 100.0 + if self.advanced_group.isChecked(): + full_inputs['guidance_scale'] = self.widgets['guidance_scale'].value() / 10.0 + full_inputs['guidance_phases'] = self.widgets['guidance_phases'].currentData() + full_inputs['guidance2_scale'] = self.widgets['guidance2_scale'].value() / 10.0 + full_inputs['guidance3_scale'] = self.widgets['guidance3_scale'].value() / 10.0 + full_inputs['switch_threshold'] = self.widgets['switch_threshold'].value() + full_inputs['NAG_scale'] = self.widgets['NAG_scale'].value() / 10.0 + full_inputs['NAG_tau'] = self.widgets['NAG_tau'].value() / 10.0 + full_inputs['NAG_alpha'] = self.widgets['NAG_alpha'].value() / 10.0 + full_inputs['sample_solver'] = self.widgets['sample_solver'].currentData() + 
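+ # Scaled integer sliders are converted back to floats (divided by 10 or 100) before the values are handed to wgp.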
full_inputs['flow_shift'] = self.widgets['flow_shift'].value() / 10.0 + full_inputs['audio_guidance_scale'] = self.widgets['audio_guidance_scale'].value() / 10.0 + full_inputs['repeat_generation'] = self.widgets['repeat_generation'].value() + full_inputs['multi_images_gen_type'] = self.widgets['multi_images_gen_type'].currentData() + selected_items = self.widgets['activated_loras'].selectedItems() + full_inputs['activated_loras'] = [self.lora_map[item.text()] for item in selected_items if item.text() in self.lora_map] + full_inputs['loras_multipliers'] = self.widgets['loras_multipliers'].toPlainText() + full_inputs['skip_steps_cache_type'] = self.widgets['skip_steps_cache_type'].currentData() + full_inputs['skip_steps_multiplier'] = self.widgets['skip_steps_multiplier'].currentData() + full_inputs['skip_steps_start_step_perc'] = self.widgets['skip_steps_start_step_perc'].value() + full_inputs['temporal_upsampling'] = self.widgets['temporal_upsampling'].currentData() + full_inputs['spatial_upsampling'] = self.widgets['spatial_upsampling'].currentData() + full_inputs['film_grain_intensity'] = self.widgets['film_grain_intensity'].value() / 100.0 + full_inputs['film_grain_saturation'] = self.widgets['film_grain_saturation'].value() / 100.0 + full_inputs['MMAudio_setting'] = self.widgets['MMAudio_setting'].currentData() + full_inputs['MMAudio_prompt'] = self.widgets['MMAudio_prompt'].text() + full_inputs['MMAudio_neg_prompt'] = self.widgets['MMAudio_neg_prompt'].text() + full_inputs['RIFLEx_setting'] = self.widgets['RIFLEx_setting'].currentData() + full_inputs['force_fps'] = self.widgets['force_fps'].currentData() + full_inputs['override_profile'] = self.widgets['override_profile'].currentData() + full_inputs['multi_prompts_gen_type'] = self.widgets['multi_prompts_gen_type'].currentData() + full_inputs['slg_switch'] = self.widgets['slg_switch'].currentData() + full_inputs['slg_start_perc'] = self.widgets['slg_start_perc'].value() + full_inputs['slg_end_perc'] = self.widgets['slg_end_perc'].value() + full_inputs['apg_switch'] = self.widgets['apg_switch'].currentData() + full_inputs['cfg_star_switch'] = self.widgets['cfg_star_switch'].currentData() + full_inputs['cfg_zero_step'] = self.widgets['cfg_zero_step'].value() + full_inputs['min_frames_if_references'] = self.widgets['min_frames_if_references'].currentData() + full_inputs['sliding_window_size'] = self.widgets['sliding_window_size'].value() + full_inputs['sliding_window_overlap'] = self.widgets['sliding_window_overlap'].value() + full_inputs['sliding_window_color_correction_strength'] = self.widgets['sliding_window_color_correction_strength'].value() / 100.0 + full_inputs['sliding_window_overlap_noise'] = self.widgets['sliding_window_overlap_noise'].value() + full_inputs['sliding_window_discard_last_frames'] = self.widgets['sliding_window_discard_last_frames'].value() + return full_inputs + + def _prepare_state_for_generation(self): + if 'gen' in self.state: + self.state['gen'].pop('abort', None) + self.state['gen'].pop('in_progress', None) + + def _on_generate(self): + try: + is_running = self.thread and self.thread.isRunning() + self._add_task_to_queue() + if not is_running: self.start_generation() + except Exception as e: + import traceback; traceback.print_exc() + + def _on_add_to_queue(self): + try: + self._add_task_to_queue() + except Exception as e: + import traceback; traceback.print_exc() + + def _add_task_to_queue(self): + all_inputs = self.collect_inputs() + for key in ['type', 'settings_version', 'is_image', 'video_quality', 
'image_quality', 'base_model_type']: all_inputs.pop(key, None) + all_inputs['state'] = self.state + self.wgp.set_model_settings(self.state, self.state['model_type'], all_inputs) + self.state["validate_success"] = 1 + self.wgp.process_prompt_and_add_tasks(self.state, self.state['model_type']) + self.update_queue_table() + + def start_generation(self): + if not self.state['gen']['queue']: return + self._prepare_state_for_generation() + self.generate_btn.setEnabled(False) + self.add_to_queue_btn.setEnabled(True) + self.thread = QThread() + self.worker = Worker(self.plugin, self.state) + self.worker.moveToThread(self.thread) + self.thread.started.connect(self.worker.run) + self.worker.finished.connect(self.thread.quit) + self.worker.finished.connect(self.worker.deleteLater) + self.thread.finished.connect(self.thread.deleteLater) + self.thread.finished.connect(self.on_generation_finished) + self.worker.status.connect(self.status_label.setText) + self.worker.progress.connect(self.update_progress) + self.worker.preview.connect(self.update_preview) + self.worker.output.connect(self.update_queue_and_results) + self.worker.error.connect(self.on_generation_error) + self.thread.start() + self.update_queue_table() + + def on_generation_finished(self): + time.sleep(0.1) + self.status_label.setText("Finished.") + self.progress_bar.setValue(0) + self.generate_btn.setEnabled(True) + self.add_to_queue_btn.setEnabled(False) + self.thread = None; self.worker = None + self.update_queue_table() + + def on_generation_error(self, err_msg): + QMessageBox.critical(self, "Generation Error", str(err_msg)) + self.on_generation_finished() + + def update_progress(self, data): + if len(data) > 1 and isinstance(data[0], tuple): + step, total = data[0] + self.progress_bar.setMaximum(total) + self.progress_bar.setValue(step) + self.status_label.setText(str(data[1])) + if step <= 1: self.update_queue_table() + elif len(data) > 1: self.status_label.setText(str(data[1])) + + def update_preview(self, pil_image): + if pil_image and self.widgets['preview_group'].isChecked(): + q_image = ImageQt(pil_image) + pixmap = QPixmap.fromImage(q_image) + self.preview_image.setPixmap(pixmap.scaled(self.preview_image.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)) + + def update_queue_and_results(self): + self.update_queue_table() + file_list = self.state.get('gen', {}).get('file_list', []) + for file_path in file_list: + if file_path not in self.processed_files: + self.add_result_item(file_path) + self.processed_files.add(file_path) + + def add_result_item(self, video_path): + item_widget = VideoResultItemWidget(video_path, self.plugin) + list_item = QListWidgetItem(self.results_list) + list_item.setSizeHint(item_widget.sizeHint()) + self.results_list.addItem(list_item) + self.results_list.setItemWidget(list_item, item_widget) + + def update_queue_table(self): + with self.wgp.lock: + queue = self.state.get('gen', {}).get('queue', []) + is_running = self.thread and self.thread.isRunning() + queue_to_display = queue if is_running else [None] + queue + table_data = self.wgp.get_queue_table(queue_to_display) + self.queue_table.setRowCount(0) + self.queue_table.setRowCount(len(table_data)) + self.queue_table.setColumnCount(4) + self.queue_table.setHorizontalHeaderLabels(["Qty", "Prompt", "Length", "Steps"]) + for row_idx, row_data in enumerate(table_data): + prompt_text = str(row_data[1]).split('>')[1].split('<')[0] if '>' in str(row_data[1]) else str(row_data[1]) + for col_idx, cell_data in 
enumerate([row_data[0], prompt_text, row_data[2], row_data[3]]): + self.queue_table.setItem(row_idx, col_idx, QTableWidgetItem(str(cell_data))) + self.queue_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + self.queue_table.resizeColumnsToContents() + + def _on_remove_selected_from_queue(self): + selected_row = self.queue_table.currentRow() + if selected_row < 0: return + with self.wgp.lock: + is_running = self.thread and self.thread.isRunning() + offset = 1 if is_running else 0 + queue = self.state.get('gen', {}).get('queue', []) + if len(queue) > selected_row + offset: queue.pop(selected_row + offset) + self.update_queue_table() + + def _on_queue_rows_moved(self, source_row, dest_row): + with self.wgp.lock: + queue = self.state.get('gen', {}).get('queue', []) + is_running = self.thread and self.thread.isRunning() + offset = 1 if is_running else 0 + real_source_idx = source_row + offset + real_dest_idx = dest_row + offset + moved_item = queue.pop(real_source_idx) + queue.insert(real_dest_idx, moved_item) + self.update_queue_table() + + def _on_clear_queue(self): + self.wgp.clear_queue_action(self.state) + self.update_queue_table() + + def _on_abort(self): + if self.worker: + self.wgp.abort_generation(self.state) + self.status_label.setText("Aborting...") + self.worker._is_running = False + + def _on_release_ram(self): + self.wgp.release_RAM() + QMessageBox.information(self, "RAM Released", "Models stored in RAM have been released.") + + def _on_apply_config_changes(self): + changes = {} + list_widget = self.widgets['config_transformer_types'] + changes['transformer_types_choices'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] + list_widget = self.widgets['config_preload_model_policy'] + changes['preload_model_policy_choice'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] + changes['model_hierarchy_type_choice'] = self.widgets['config_model_hierarchy_type'].currentData() + changes['checkpoints_paths'] = self.widgets['config_checkpoints_paths'].toPlainText() + for key in ["fit_canvas", "attention_mode", "metadata_type", "clear_file_list", "display_stats", "max_frames_multiplier", "UI_theme"]: + changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() + for key in ["transformer_quantization", "transformer_dtype_policy", "mixed_precision", "text_encoder_quantization", "vae_precision", "compile", "depth_anything_v2_variant", "vae_config", "boost", "profile"]: + changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() + changes['preload_in_VRAM_choice'] = self.widgets['config_preload_in_VRAM'].value() + for key in ["enhancer_enabled", "enhancer_mode", "mmaudio_enabled"]: + changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() + for key in ["video_output_codec", "image_output_codec", "save_path", "image_save_path"]: + widget = self.widgets[f'config_{key}'] + changes[f'{key}_choice'] = widget.currentData() if isinstance(widget, QComboBox) else widget.text() + changes['notification_sound_enabled_choice'] = self.widgets['config_notification_sound_enabled'].currentData() + changes['notification_sound_volume_choice'] = self.widgets['config_notification_sound_volume'].value() + changes['last_resolution_choice'] = self.widgets['resolution'].currentData() + try: + msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = 
self.wgp.apply_changes(self.state, **changes) + self.config_status_label.setText("Changes applied successfully. Some settings may require a restart.") + self.header_info.setText(self.wgp.generate_header(self.state['model_type'], self.wgp.compile, self.wgp.attention_mode)) + if family_mock.choices is not None or choice_mock.choices is not None: + self.update_model_dropdowns(self.wgp.transformer_type) + self.refresh_ui_from_model_change(self.wgp.transformer_type) + except Exception as e: + self.config_status_label.setText(f"Error applying changes: {e}") + import traceback; traceback.print_exc() + +class Plugin(VideoEditorPlugin): + def initialize(self): + self.name = "AI Generator" + self.description = "Uses the integrated Wan2GP library to generate video clips." + self.client_widget = WgpDesktopPluginWidget(self) + self.dock_widget = None + self.active_region = None; self.temp_dir = None + self.insert_on_new_track = False; self.start_frame_path = None; self.end_frame_path = None + + def enable(self): + if not self.dock_widget: self.dock_widget = self.app.add_dock_widget(self, self.client_widget, self.name) + self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu) + self.app.status_label.setText(f"{self.name}: Enabled.") + + def disable(self): + try: self.app.timeline_widget.context_menu_requested.disconnect(self.on_timeline_context_menu) + except TypeError: pass + self._cleanup_temp_dir() + if self.client_widget.worker: self.client_widget._on_abort() + self.app.status_label.setText(f"{self.name}: Disabled.") + + def _cleanup_temp_dir(self): + if self.temp_dir and os.path.exists(self.temp_dir): + shutil.rmtree(self.temp_dir) + self.temp_dir = None + + def _reset_state(self): + self.active_region = None; self.insert_on_new_track = False + self.start_frame_path = None; self.end_frame_path = None + self.client_widget.processed_files.clear() + self.client_widget.results_list.clear() + self.client_widget.widgets['image_start'].clear() + self.client_widget.widgets['image_end'].clear() + self.client_widget.widgets['video_source'].clear() + self._cleanup_temp_dir() + self.app.status_label.setText(f"{self.name}: Ready.") + + def on_timeline_context_menu(self, menu, event): + region = self.app.timeline_widget.get_region_at_pos(event.pos()) + if region: + menu.addSeparator() + start_sec, end_sec = region + start_data, _, _ = self.app.get_frame_data_at_time(start_sec) + end_data, _, _ = self.app.get_frame_data_at_time(end_sec) + if start_data and end_data: + join_action = menu.addAction("Join Frames With AI") + join_action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False)) + join_action_new_track = menu.addAction("Join Frames With AI (New Track)") + join_action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True)) + create_action = menu.addAction("Create Frames With AI") + create_action.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=False)) + create_action_new_track = menu.addAction("Create Frames With AI (New Track)") + create_action_new_track.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=True)) + + def setup_generator_for_region(self, region, on_new_track=False): + self._reset_state() + self.active_region = region + self.insert_on_new_track = on_new_track + + model_to_set = 'i2v_2_2' + dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types + _, _, all_models = 
self.wgp.get_sorted_dropdown(dropdown_types, None, None, False) + if any(model_to_set == m[1] for m in all_models): + if self.client_widget.state.get('model_type') != model_to_set: + self.client_widget.update_model_dropdowns(model_to_set) + self.client_widget._on_model_changed() + else: + print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.") + + start_sec, end_sec = region + start_data, w, h = self.app.get_frame_data_at_time(start_sec) + end_data, _, _ = self.app.get_frame_data_at_time(end_sec) + if not start_data or not end_data: + QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames.") + return + try: + self.temp_dir = tempfile.mkdtemp(prefix="wgp_plugin_") + self.start_frame_path = os.path.join(self.temp_dir, "start_frame.png") + self.end_frame_path = os.path.join(self.temp_dir, "end_frame.png") + QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path) + QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path) + + duration_sec = end_sec - start_sec + wgp = self.client_widget.wgp + model_type = self.client_widget.state['model_type'] + fps = wgp.get_model_fps(model_type) + video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) + + self.client_widget.widgets['video_length'].setValue(video_length_frames) + self.client_widget.widgets['mode_s'].setChecked(True) + self.client_widget.widgets['image_end_checkbox'].setChecked(True) + self.client_widget.widgets['image_start'].setText(self.start_frame_path) + self.client_widget.widgets['image_end'].setText(self.end_frame_path) + + self.client_widget._update_input_visibility() + + except Exception as e: + QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame images: {e}") + self._cleanup_temp_dir() + return + self.app.status_label.setText(f"Ready to join frames from {start_sec:.2f}s to {end_sec:.2f}s.") + self.dock_widget.show() + self.dock_widget.raise_() + + def setup_creator_for_region(self, region, on_new_track=False): + self._reset_state() + self.active_region = region + self.insert_on_new_track = on_new_track + + model_to_set = 't2v_2_2' + dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types + _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False) + if any(model_to_set == m[1] for m in all_models): + if self.client_widget.state.get('model_type') != model_to_set: + self.client_widget.update_model_dropdowns(model_to_set) + self.client_widget._on_model_changed() + else: + print(f"Warning: Default model '{model_to_set}' not found for AI Creator. 
Using current model.") + + start_sec, end_sec = region + duration_sec = end_sec - start_sec + wgp = self.client_widget.wgp + model_type = self.client_widget.state['model_type'] + fps = wgp.get_model_fps(model_type) + video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) + + self.client_widget.widgets['video_length'].setValue(video_length_frames) + self.client_widget.widgets['mode_t'].setChecked(True) + + self.app.status_label.setText(f"Ready to create video from {start_sec:.2f}s to {end_sec:.2f}s.") + self.dock_widget.show() + self.dock_widget.raise_() + + def insert_generated_clip(self, video_path): + from videoeditor import TimelineClip + if not self.active_region: + self.app.status_label.setText("Error: No active region to insert into."); return + if not os.path.exists(video_path): + self.app.status_label.setText(f"Error: Output file not found: {video_path}"); return + start_sec, end_sec = self.active_region + def complex_insertion_action(): + self.app._add_media_files_to_project([video_path]) + media_info = self.app.media_properties.get(video_path) + if not media_info: raise ValueError("Could not probe inserted clip.") + actual_duration, has_audio = media_info['duration'], media_info['has_audio'] + if self.insert_on_new_track: + self.app.timeline.num_video_tracks += 1 + video_track_index = self.app.timeline.num_video_tracks + audio_track_index = self.app.timeline.num_audio_tracks + 1 if has_audio else None + else: + for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec) + for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec) + clips_to_remove = [c for c in self.app.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_end_sec <= end_sec] + for clip in clips_to_remove: + if clip in self.app.timeline.clips: self.app.timeline.clips.remove(clip) + video_track_index, audio_track_index = 1, 1 if has_audio else None + group_id = str(uuid.uuid4()) + new_clip = TimelineClip(video_path, start_sec, 0, actual_duration, video_track_index, 'video', 'video', group_id) + self.app.timeline.add_clip(new_clip) + if audio_track_index: + if audio_track_index > self.app.timeline.num_audio_tracks: self.app.timeline.num_audio_tracks = audio_track_index + audio_clip = TimelineClip(video_path, start_sec, 0, actual_duration, audio_track_index, 'audio', 'video', group_id) + self.app.timeline.add_clip(audio_clip) + try: + self.app._perform_complex_timeline_change("Insert AI Clip", complex_insertion_action) + self.app.prune_empty_tracks() + self.app.status_label.setText("AI clip inserted successfully.") + for i in range(self.client_widget.results_list.count()): + widget = self.client_widget.results_list.itemWidget(self.client_widget.results_list.item(i)) + if widget and widget.video_path == video_path: + self.client_widget.results_list.takeItem(i); break + except Exception as e: + import traceback; traceback.print_exc() + self.app.status_label.setText(f"Error during clip insertion: {e}") \ No newline at end of file From 556280398351454626f953169970dff560131556 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sat, 11 Oct 2025 23:39:55 +1100 Subject: [PATCH 125/155] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index eff58e558..81c93be78 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ python videoeditor.py ``` ## Screenshots -image +image image image From 19a5cd85583d5f38edc2dfa4437818d0407c4d8c Mon Sep 17 00:00:00 2001 From: Chris 
Malone Date: Sun, 12 Oct 2025 10:08:29 +1100 Subject: [PATCH 126/155] allow region selection resizing, remove all region context menu options showing up for single regions and vice versa --- videoeditor.py | 105 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 88 insertions(+), 17 deletions(-) diff --git a/videoeditor.py b/videoeditor.py index 1c7466957..4a124b785 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -134,6 +134,10 @@ def __init__(self, timeline_model, settings, project_fps, parent=None): self.resize_edge = None self.resize_start_pos = QPoint() + self.resizing_selection_region = None + self.resize_selection_edge = None + self.resize_selection_start_values = None + self.highlighted_track_info = None self.highlighted_ghost_track_info = None self.add_video_track_btn_rect = QRect() @@ -522,6 +526,29 @@ def mousePressEvent(self, event: QMouseEvent): self.resizing_clip = None self.resize_edge = None self.drag_original_clip_states.clear() + self.resizing_selection_region = None + self.resize_selection_edge = None + self.resize_selection_start_values = None + + for region in self.selection_regions: + if not region: continue + x_start = self.sec_to_x(region[0]) + x_end = self.sec_to_x(region[1]) + if event.pos().y() > self.TIMESCALE_HEIGHT: + if abs(event.pos().x() - x_start) < self.RESIZE_HANDLE_WIDTH: + self.resizing_selection_region = region + self.resize_selection_edge = 'left' + break + elif abs(event.pos().x() - x_end) < self.RESIZE_HANDLE_WIDTH: + self.resizing_selection_region = region + self.resize_selection_edge = 'right' + break + + if self.resizing_selection_region: + self.resize_selection_start_values = tuple(self.resizing_selection_region) + self.drag_start_pos = event.pos() + self.update() + return for clip in reversed(self.timeline.clips): clip_rect = self.get_clip_rect(clip) @@ -595,6 +622,27 @@ def mousePressEvent(self, event: QMouseEvent): self.update() def mouseMoveEvent(self, event: QMouseEvent): + if self.resizing_selection_region: + current_sec = max(0, self.x_to_sec(event.pos().x())) + original_start, original_end = self.resize_selection_start_values + + if self.resize_selection_edge == 'left': + new_start = current_sec + new_end = original_end + else: # right + new_start = original_start + new_end = current_sec + + self.resizing_selection_region[0] = min(new_start, new_end) + self.resizing_selection_region[1] = max(new_start, new_end) + + if (self.resize_selection_edge == 'left' and new_start > new_end) or \ + (self.resize_selection_edge == 'right' and new_end < new_start): + self.resize_selection_edge = 'right' if self.resize_selection_edge == 'left' else 'left' + self.resize_selection_start_values = (original_end, original_start) + + self.update() + return if self.resizing_clip: linked_clip = next((c for c in self.timeline.clips if c.group_id == self.resizing_clip.group_id and c.id != self.resizing_clip.id), None) delta_x = event.pos().x() - self.resize_start_pos.x() @@ -686,6 +734,15 @@ def mouseMoveEvent(self, event: QMouseEvent): self.setCursor(Qt.CursorShape.SizeHorCursor) cursor_set = True + if not cursor_set and is_in_track_area: + for region in self.selection_regions: + x_start = self.sec_to_x(region[0]) + x_end = self.sec_to_x(region[1]) + if abs(event.pos().x() - x_start) < self.RESIZE_HANDLE_WIDTH or \ + abs(event.pos().x() - x_end) < self.RESIZE_HANDLE_WIDTH: + self.setCursor(Qt.CursorShape.SizeHorCursor) + cursor_set = True + break if not cursor_set: for clip in self.timeline.clips: clip_rect = self.get_clip_rect(clip) @@ -786,6 
+843,12 @@ def mouseMoveEvent(self, event: QMouseEvent): def mouseReleaseEvent(self, event: QMouseEvent): if event.button() == Qt.MouseButton.LeftButton: + if self.resizing_selection_region: + self.resizing_selection_region = None + self.resize_selection_edge = None + self.resize_selection_start_values = None + self.update() + return if self.resizing_clip: new_state = self.window()._get_current_timeline_state() command = TimelineStateChangeCommand("Resize Clip", self.timeline, *self.drag_start_state, *new_state) @@ -1074,23 +1137,31 @@ def contextMenuEvent(self, event: 'QContextMenuEvent'): region_at_pos = self.get_region_at_pos(event.pos()) if region_at_pos: - split_this_action = menu.addAction("Split This Region") - split_all_action = menu.addAction("Split All Regions") - join_this_action = menu.addAction("Join This Region") - join_all_action = menu.addAction("Join All Regions") - delete_this_action = menu.addAction("Delete This Region") - delete_all_action = menu.addAction("Delete All Regions") - menu.addSeparator() - clear_this_action = menu.addAction("Clear This Region") - clear_all_action = menu.addAction("Clear All Regions") - split_this_action.triggered.connect(lambda: self.split_region_requested.emit(region_at_pos)) - split_all_action.triggered.connect(lambda: self.split_all_regions_requested.emit(self.selection_regions)) - join_this_action.triggered.connect(lambda: self.join_region_requested.emit(region_at_pos)) - join_all_action.triggered.connect(lambda: self.join_all_regions_requested.emit(self.selection_regions)) - delete_this_action.triggered.connect(lambda: self.delete_region_requested.emit(region_at_pos)) - delete_all_action.triggered.connect(lambda: self.delete_all_regions_requested.emit(self.selection_regions)) - clear_this_action.triggered.connect(lambda: self.clear_region(region_at_pos)) - clear_all_action.triggered.connect(self.clear_all_regions) + num_regions = len(self.selection_regions) + + if num_regions == 1: + split_this_action = menu.addAction("Split Region") + join_this_action = menu.addAction("Join Region") + delete_this_action = menu.addAction("Delete Region") + menu.addSeparator() + clear_this_action = menu.addAction("Clear Region") + + split_this_action.triggered.connect(lambda: self.split_region_requested.emit(region_at_pos)) + join_this_action.triggered.connect(lambda: self.join_region_requested.emit(region_at_pos)) + delete_this_action.triggered.connect(lambda: self.delete_region_requested.emit(region_at_pos)) + clear_this_action.triggered.connect(lambda: self.clear_region(region_at_pos)) + + elif num_regions > 1: + split_all_action = menu.addAction("Split All Regions") + join_all_action = menu.addAction("Join All Regions") + delete_all_action = menu.addAction("Delete All Regions") + menu.addSeparator() + clear_all_action = menu.addAction("Clear All Regions") + + split_all_action.triggered.connect(lambda: self.split_all_regions_requested.emit(self.selection_regions)) + join_all_action.triggered.connect(lambda: self.join_all_regions_requested.emit(self.selection_regions)) + delete_all_action.triggered.connect(lambda: self.delete_all_regions_requested.emit(self.selection_regions)) + clear_all_action.triggered.connect(self.clear_all_regions) clip_at_pos = None for clip in self.timeline.clips: From e45796c6b977176293cbf44178d0c517bab74863 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 13:32:45 +1100 Subject: [PATCH 127/155] add preview video scaling --- videoeditor.py | 186 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 
131 insertions(+), 55 deletions(-) diff --git a/videoeditor.py b/videoeditor.py index 4a124b785..15e5f0d7d 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -1363,9 +1363,15 @@ def __init__(self, project_to_load=None): self.plugin_manager = PluginManager(self, wgp) self.plugin_manager.discover_and_load_plugins() + # Project settings with defaults self.project_fps = 25.0 self.project_width = 1280 self.project_height = 720 + + # Preview settings + self.scale_to_fit = True + self.current_preview_pixmap = None + self.playback_timer = QTimer(self) self.playback_process = None self.playback_clip = None @@ -1380,6 +1386,10 @@ def __init__(self, project_to_load=None): if not self.settings_file_was_loaded: self._save_settings() if project_to_load: QTimer.singleShot(100, lambda: self._load_project_from_path(project_to_load)) + def resizeEvent(self, event): + super().resizeEvent(event) + self._update_preview_display() + def _get_current_timeline_state(self): return ( copy.deepcopy(self.timeline.clips), @@ -1395,12 +1405,18 @@ def _setup_ui(self): self.splitter = QSplitter(Qt.Orientation.Vertical) + # --- PREVIEW WIDGET SETUP (CHANGED) --- + self.preview_scroll_area = QScrollArea() + self.preview_scroll_area.setWidgetResizable(False) # Important for 1:1 scaling + self.preview_scroll_area.setStyleSheet("background-color: black; border: 0px;") + self.preview_widget = QLabel() self.preview_widget.setAlignment(Qt.AlignmentFlag.AlignCenter) self.preview_widget.setMinimumSize(640, 360) - self.preview_widget.setFrameShape(QFrame.Shape.Box) - self.preview_widget.setStyleSheet("background-color: black; color: white;") - self.splitter.addWidget(self.preview_widget) + self.preview_widget.setContextMenuPolicy(Qt.ContextMenuPolicy.CustomContextMenu) + + self.preview_scroll_area.setWidget(self.preview_widget) + self.splitter.addWidget(self.preview_scroll_area) self.timeline_widget = TimelineWidget(self.timeline, self.settings, self.project_fps, self) self.timeline_scroll_area = QScrollArea() @@ -1412,6 +1428,9 @@ def _setup_ui(self): self.timeline_scroll_area.setMinimumHeight(250) self.splitter.addWidget(self.timeline_scroll_area) + self.splitter.setStretchFactor(0, 1) + self.splitter.setStretchFactor(1, 0) + container_widget = QWidget() main_layout = QVBoxLayout(container_widget) main_layout.setContentsMargins(0,0,0,0) @@ -1445,7 +1464,7 @@ def _setup_ui(self): self.setCentralWidget(container_widget) self.managed_widgets = { - 'preview': {'widget': self.preview_widget, 'name': 'Video Preview', 'action': None}, + 'preview': {'widget': self.preview_scroll_area, 'name': 'Video Preview', 'action': None}, 'timeline': {'widget': self.timeline_scroll_area, 'name': 'Timeline', 'action': None}, 'project_media': {'widget': self.media_dock, 'name': 'Project Media', 'action': None} } @@ -1459,6 +1478,7 @@ def _setup_ui(self): def _connect_signals(self): self.splitter.splitterMoved.connect(self.on_splitter_moved) + self.preview_widget.customContextMenuRequested.connect(self._show_preview_context_menu) self.timeline_widget.split_requested.connect(self.split_clip_at_playhead) self.timeline_widget.delete_clip_requested.connect(self.delete_clip) @@ -1487,6 +1507,18 @@ def _connect_signals(self): self.undo_stack.history_changed.connect(self.update_undo_redo_actions) self.undo_stack.timeline_changed.connect(self.on_timeline_changed_by_undo) + def _show_preview_context_menu(self, pos): + menu = QMenu(self) + scale_action = QAction("Scale to Fit", self, checkable=True) + scale_action.setChecked(self.scale_to_fit) + 
scale_action.toggled.connect(self._toggle_scale_to_fit) + menu.addAction(scale_action) + menu.exec(self.preview_widget.mapToGlobal(pos)) + + def _toggle_scale_to_fit(self, checked): + self.scale_to_fit = checked + self._update_preview_display() + def on_timeline_changed_by_undo(self): self.prune_empty_tracks() self.timeline_widget.update() @@ -1644,7 +1676,7 @@ def _create_menu_bar(self): self.windows_menu = menu_bar.addMenu("&Windows") for key, data in self.managed_widgets.items(): - if data['widget'] is self.preview_widget: continue + if data['widget'] is self.preview_scroll_area: continue action = QAction(data['name'], self, checkable=True) if hasattr(data['widget'], 'visibilityChanged'): action.toggled.connect(data['widget'].setVisible) @@ -1700,9 +1732,12 @@ def _get_media_properties(self, file_path): media_info = {} if file_ext in ['.png', '.jpg', '.jpeg']: + img = QImage(file_path) media_info['media_type'] = 'image' media_info['duration'] = 5.0 media_info['has_audio'] = False + media_info['width'] = img.width() + media_info['height'] = img.height() else: probe = ffmpeg.probe(file_path) video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None) @@ -1712,6 +1747,11 @@ def _get_media_properties(self, file_path): media_info['media_type'] = 'video' media_info['duration'] = float(video_stream.get('duration', probe['format'].get('duration', 0))) media_info['has_audio'] = audio_stream is not None + media_info['width'] = int(video_stream['width']) + media_info['height'] = int(video_stream['height']) + if 'r_frame_rate' in video_stream and video_stream['r_frame_rate'] != '0/0': + num, den = map(int, video_stream['r_frame_rate'].split('/')) + if den > 0: media_info['fps'] = num / den elif audio_stream: media_info['media_type'] = 'audio' media_info['duration'] = float(audio_stream.get('duration', probe['format'].get('duration', 0))) @@ -1724,20 +1764,37 @@ def _get_media_properties(self, file_path): print(f"Failed to probe file {os.path.basename(file_path)}: {e}") return None - def _set_project_properties_from_clip(self, source_path): + def _update_project_properties_from_clip(self, source_path): try: - probe = ffmpeg.probe(source_path) - video_stream = next((s for s in probe['streams'] if s['codec_type'] == 'video'), None) - if video_stream: - self.project_width = int(video_stream['width']); self.project_height = int(video_stream['height']) - if 'r_frame_rate' in video_stream and video_stream['r_frame_rate'] != '0/0': - num, den = map(int, video_stream['r_frame_rate'].split('/')) - if den > 0: - self.project_fps = num / den - self.timeline_widget.set_project_fps(self.project_fps) - print(f"Project properties set: {self.project_width}x{self.project_height} @ {self.project_fps:.2f} FPS") - return True - except Exception as e: print(f"Could not probe for project properties: {e}") + media_info = self._get_media_properties(source_path) + if not media_info or media_info['media_type'] not in ['video', 'image']: + return False + + new_w = media_info.get('width') + new_h = media_info.get('height') + new_fps = media_info.get('fps') + + if not new_w or not new_h: + return False + + is_first_video = not any(c.media_type in ['video', 'image'] for c in self.timeline.clips if c.source_path != source_path) + + if is_first_video: + self.project_width = new_w + self.project_height = new_h + if new_fps: self.project_fps = new_fps + self.timeline_widget.set_project_fps(self.project_fps) + print(f"Project properties set from first clip: {self.project_width}x{self.project_height} @ 
{self.project_fps:.2f} FPS") + else: + current_area = self.project_width * self.project_height + new_area = new_w * new_h + if new_area > current_area: + self.project_width = new_w + self.project_height = new_h + print(f"Project resolution updated to: {self.project_width}x{self.project_height}") + return True + except Exception as e: + print(f"Could not probe for project properties: {e}") return False def _probe_for_drag(self, file_path): @@ -1774,9 +1831,8 @@ def get_frame_data_at_time(self, time_sec): return (None, 0, 0) def get_frame_at_time(self, time_sec): - black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black")) clip_at_time = self._get_topmost_video_clip_at(time_sec) - if not clip_at_time: return black_pixmap + if not clip_at_time: return None try: w, h = self.project_width, self.project_height if clip_at_time.media_type == 'image': @@ -1797,16 +1853,36 @@ def get_frame_at_time(self, time_sec): image = QImage(out, self.project_width, self.project_height, QImage.Format.Format_RGB888) return QPixmap.fromImage(image) - except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return black_pixmap + except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return None def seek_preview(self, time_sec): self._stop_playback_stream() self.timeline_widget.playhead_pos_sec = time_sec self.timeline_widget.update() - frame_pixmap = self.get_frame_at_time(time_sec) - if frame_pixmap: - scaled_pixmap = frame_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) + self.current_preview_pixmap = self.get_frame_at_time(time_sec) + self._update_preview_display() + + def _update_preview_display(self): + pixmap_to_show = self.current_preview_pixmap + if not pixmap_to_show: + pixmap_to_show = QPixmap(self.project_width, self.project_height) + pixmap_to_show.fill(QColor("black")) + + if self.scale_to_fit: + # When fitting, we want the label to resize with the scroll area + self.preview_scroll_area.setWidgetResizable(True) + scaled_pixmap = pixmap_to_show.scaled( + self.preview_scroll_area.viewport().size(), + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation + ) self.preview_widget.setPixmap(scaled_pixmap) + else: + # For 1:1, the label takes the size of the pixmap, and the scroll area handles overflow + self.preview_scroll_area.setWidgetResizable(False) + self.preview_widget.setPixmap(pixmap_to_show) + self.preview_widget.adjustSize() + def toggle_playback(self): if self.playback_timer.isActive(): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play") @@ -1835,19 +1911,16 @@ def advance_playback_frame(self): if not clip_at_new_time: self._stop_playback_stream() - black_pixmap = QPixmap(self.project_width, self.project_height); black_pixmap.fill(QColor("black")) - scaled_pixmap = black_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) - self.preview_widget.setPixmap(scaled_pixmap) + self.current_preview_pixmap = None + self._update_preview_display() return if clip_at_new_time.media_type == 'image': if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: self._stop_playback_stream() self.playback_clip = clip_at_new_time - frame_pixmap = self.get_frame_at_time(new_time) - if frame_pixmap: - scaled_pixmap = frame_pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, 
Qt.TransformationMode.SmoothTransformation) - self.preview_widget.setPixmap(scaled_pixmap) + self.current_preview_pixmap = self.get_frame_at_time(new_time) + self._update_preview_display() return if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: @@ -1858,9 +1931,8 @@ def advance_playback_frame(self): frame_bytes = self.playback_process.stdout.read(frame_size) if len(frame_bytes) == frame_size: image = QImage(frame_bytes, self.project_width, self.project_height, QImage.Format.Format_RGB888) - pixmap = QPixmap.fromImage(image) - scaled_pixmap = pixmap.scaled(self.preview_widget.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation) - self.preview_widget.setPixmap(scaled_pixmap) + self.current_preview_pixmap = QPixmap.fromImage(image) + self._update_preview_display() else: self._stop_playback_stream() @@ -1892,13 +1964,15 @@ def _apply_loaded_settings(self): visibility_settings = self.settings.get("window_visibility", {}) for key, data in self.managed_widgets.items(): is_visible = visibility_settings.get(key, False if key == 'project_media' else True) - if data['widget'] is not self.preview_widget: + if data['widget'] is not self.preview_scroll_area: data['widget'].setVisible(is_visible) if data['action']: data['action'].setChecked(is_visible) splitter_state = self.settings.get("splitter_state") if splitter_state: self.splitter.restoreState(QByteArray.fromHex(splitter_state.encode('ascii'))) - def on_splitter_moved(self, pos, index): self.splitter_save_timer.start(500) + def on_splitter_moved(self, pos, index): + self.splitter_save_timer.start(500) + self._update_preview_display() def toggle_widget_visibility(self, key, checked): if self.is_shutting_down: return @@ -1916,6 +1990,8 @@ def new_project(self): self.media_pool.clear(); self.media_properties.clear(); self.project_media_widget.clear_list() self.current_project_path = None; self.stop_playback() self.project_fps = 25.0 + self.project_width = 1280 + self.project_height = 720 self.timeline_widget.set_project_fps(self.project_fps) self.timeline_widget.update() self.undo_stack = UndoStack() @@ -1929,7 +2005,13 @@ def save_project_as(self): project_data = { "media_pool": self.media_pool, "clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "media_type": c.media_type, "group_id": c.group_id} for c in self.timeline.clips], - "settings": {"num_video_tracks": self.timeline.num_video_tracks, "num_audio_tracks": self.timeline.num_audio_tracks} + "settings": { + "num_video_tracks": self.timeline.num_video_tracks, + "num_audio_tracks": self.timeline.num_audio_tracks, + "project_width": self.project_width, + "project_height": self.project_height, + "project_fps": self.project_fps + } } try: with open(path, "w") as f: json.dump(project_data, f, indent=4) @@ -1946,6 +2028,14 @@ def _load_project_from_path(self, path): with open(path, "r") as f: project_data = json.load(f) self.new_project() + project_settings = project_data.get("settings", {}) + self.timeline.num_video_tracks = project_settings.get("num_video_tracks", 1) + self.timeline.num_audio_tracks = project_settings.get("num_audio_tracks", 1) + self.project_width = project_settings.get("project_width", 1280) + self.project_height = project_settings.get("project_height", 720) + self.project_fps = project_settings.get("project_fps", 25.0) + 
self.timeline_widget.set_project_fps(self.project_fps) + self.media_pool = project_data.get("media_pool", []) for p in self.media_pool: self._add_media_to_pool(p) @@ -1962,14 +2052,7 @@ def _load_project_from_path(self, path): self.timeline.add_clip(TimelineClip(**clip_data)) - project_settings = project_data.get("settings", {}) - self.timeline.num_video_tracks = project_settings.get("num_video_tracks", 1) - self.timeline.num_audio_tracks = project_settings.get("num_audio_tracks", 1) - self.current_project_path = path - video_clips = [c for c in self.timeline.clips if c.media_type == 'video'] - if video_clips: - self._set_project_properties_from_clip(video_clips[0].source_path) self.prune_empty_tracks() self.timeline_widget.update(); self.stop_playback() self.status_label.setText(f"Project '{os.path.basename(path)}' loaded.") @@ -2030,20 +2113,10 @@ def _add_media_files_to_project(self, file_paths): return [] self.media_dock.show() - added_files = [] - first_video_added = any(c.media_type == 'video' for c in self.timeline.clips) for file_path in file_paths: - if not self.timeline.clips and not self.media_pool and not first_video_added: - ext = os.path.splitext(file_path)[1].lower() - if ext not in ['.png', '.jpg', '.jpeg', '.mp3', '.wav']: - if self._set_project_properties_from_clip(file_path): - first_video_added = True - else: - self.status_label.setText("Error: Could not determine video properties from file.") - continue - + self._update_project_properties_from_clip(file_path) if self._add_media_to_pool(file_path): added_files.append(file_path) @@ -2118,6 +2191,9 @@ def add_media_files(self): self._add_media_files_to_project(file_paths) def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, media_type, clip_start_sec=0.0, video_track_index=None, audio_track_index=None): + if media_type in ['video', 'image']: + self._update_project_properties_from_clip(source_path) + old_state = self._get_current_timeline_state() group_id = str(uuid.uuid4()) From e6fd0aa99fe56d2381fc71df997d7db23eab15e6 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 14:02:38 +1100 Subject: [PATCH 128/155] fix scroll centering, overlapping track labels --- videoeditor.py | 120 ++++++++++++++++++++++++++++--------------------- 1 file changed, 70 insertions(+), 50 deletions(-) diff --git a/videoeditor.py b/videoeditor.py index 15e5f0d7d..a05ff8506 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -106,7 +106,10 @@ def __init__(self, timeline_model, settings, project_fps, parent=None): self.timeline = timeline_model self.settings = settings self.playhead_pos_sec = 0.0 - self.scroll_area = None + self.view_start_sec = 0.0 + self.panning = False + self.pan_start_pos = QPoint() + self.pan_start_view_sec = 0.0 self.pixels_per_second = 50.0 self.max_pixels_per_second = 1.0 @@ -165,19 +168,19 @@ def set_project_fps(self, fps): self.pixels_per_second = min(self.pixels_per_second, self.max_pixels_per_second) self.update() - def sec_to_x(self, sec): return self.HEADER_WIDTH + int(sec * self.pixels_per_second) - def x_to_sec(self, x): return float(x - self.HEADER_WIDTH) / self.pixels_per_second if x > self.HEADER_WIDTH and self.pixels_per_second > 0 else 0.0 + def sec_to_x(self, sec): return self.HEADER_WIDTH + int((sec - self.view_start_sec) * self.pixels_per_second) + def x_to_sec(self, x): return self.view_start_sec + float(x - self.HEADER_WIDTH) / self.pixels_per_second if x > self.HEADER_WIDTH and self.pixels_per_second > 0 else self.view_start_sec def paintEvent(self, event): painter 
= QPainter(self) painter.setRenderHint(QPainter.RenderHint.Antialiasing) painter.fillRect(self.rect(), QColor("#333")) - h_offset = 0 - if self.scroll_area and self.scroll_area.horizontalScrollBar(): - h_offset = self.scroll_area.horizontalScrollBar().value() + self.draw_headers(painter) - self.draw_headers(painter, h_offset) + painter.save() + painter.setClipRect(self.HEADER_WIDTH, 0, self.width() - self.HEADER_WIDTH, self.height()) + self.draw_timescale(painter) self.draw_tracks_and_clips(painter) self.draw_selections(painter) @@ -201,17 +204,19 @@ def paintEvent(self, event): painter.drawRect(self.hover_preview_audio_rect) self.draw_playhead(painter) + + painter.restore() - total_width = self.sec_to_x(self.timeline.get_total_duration()) + 200 total_height = self.calculate_total_height() - self.setMinimumSize(max(self.parent().width(), total_width), total_height) + if self.minimumHeight() != total_height: + self.setMinimumHeight(total_height) def calculate_total_height(self): video_tracks_height = (self.timeline.num_video_tracks + 1) * self.TRACK_HEIGHT audio_tracks_height = (self.timeline.num_audio_tracks + 1) * self.TRACK_HEIGHT return self.TIMESCALE_HEIGHT + video_tracks_height + self.AUDIO_TRACKS_SEPARATOR_Y + audio_tracks_height + 20 - def draw_headers(self, painter, h_offset): + def draw_headers(self, painter): painter.save() painter.setPen(QColor("#AAA")) header_font = QFont("Arial", 9, QFont.Weight.Bold) @@ -220,9 +225,9 @@ def draw_headers(self, painter, h_offset): y_cursor = self.TIMESCALE_HEIGHT rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT) - painter.fillRect(rect.translated(h_offset, 0), QColor("#3a3a3a")) - painter.drawRect(rect.translated(h_offset, 0)) - self.add_video_track_btn_rect = QRect(h_offset + rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22) + painter.fillRect(rect, QColor("#3a3a3a")) + painter.drawRect(rect) + self.add_video_track_btn_rect = QRect(rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22) painter.setFont(button_font) painter.fillRect(self.add_video_track_btn_rect, QColor("#454")) painter.drawText(self.add_video_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "Add Track (+)") @@ -232,13 +237,13 @@ def draw_headers(self, painter, h_offset): for i in range(self.timeline.num_video_tracks): track_number = self.timeline.num_video_tracks - i rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT) - painter.fillRect(rect.translated(h_offset, 0), QColor("#444")) - painter.drawRect(rect.translated(h_offset, 0)) + painter.fillRect(rect, QColor("#444")) + painter.drawRect(rect) painter.setFont(header_font) - painter.drawText(rect.translated(h_offset, 0), Qt.AlignmentFlag.AlignCenter, f"Video {track_number}") + painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, f"Video {track_number}") if track_number == self.timeline.num_video_tracks and self.timeline.num_video_tracks > 1: - self.remove_video_track_btn_rect = QRect(h_offset + rect.right() - 25, rect.top() + 5, 20, 20) + self.remove_video_track_btn_rect = QRect(rect.right() - 25, rect.top() + 5, 20, 20) painter.setFont(button_font) painter.fillRect(self.remove_video_track_btn_rect, QColor("#833")) painter.drawText(self.remove_video_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "-") @@ -250,22 +255,22 @@ def draw_headers(self, painter, h_offset): for i in range(self.timeline.num_audio_tracks): track_number = i + 1 rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT) - painter.fillRect(rect.translated(h_offset, 0), 
QColor("#444")) - painter.drawRect(rect.translated(h_offset, 0)) + painter.fillRect(rect, QColor("#444")) + painter.drawRect(rect) painter.setFont(header_font) - painter.drawText(rect.translated(h_offset, 0), Qt.AlignmentFlag.AlignCenter, f"Audio {track_number}") + painter.drawText(rect, Qt.AlignmentFlag.AlignCenter, f"Audio {track_number}") if track_number == self.timeline.num_audio_tracks and self.timeline.num_audio_tracks > 1: - self.remove_audio_track_btn_rect = QRect(h_offset + rect.right() - 25, rect.top() + 5, 20, 20) + self.remove_audio_track_btn_rect = QRect(rect.right() - 25, rect.top() + 5, 20, 20) painter.setFont(button_font) painter.fillRect(self.remove_audio_track_btn_rect, QColor("#833")) painter.drawText(self.remove_audio_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "-") y_cursor += self.TRACK_HEIGHT rect = QRect(0, y_cursor, self.HEADER_WIDTH, self.TRACK_HEIGHT) - painter.fillRect(rect.translated(h_offset, 0), QColor("#3a3a3a")) - painter.drawRect(rect.translated(h_offset, 0)) - self.add_audio_track_btn_rect = QRect(h_offset + rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22) + painter.fillRect(rect, QColor("#3a3a3a")) + painter.drawRect(rect) + self.add_audio_track_btn_rect = QRect(rect.left() + 10, rect.top() + (rect.height() - 22)//2, self.HEADER_WIDTH - 20, 22) painter.setFont(button_font) painter.fillRect(self.add_audio_track_btn_rect, QColor("#454")) painter.drawText(self.add_audio_track_btn_rect, Qt.AlignmentFlag.AlignCenter, "Add Track (+)") @@ -476,39 +481,46 @@ def get_region_at_pos(self, pos: QPoint): return None def wheelEvent(self, event: QMouseEvent): - if not self.scroll_area or event.position().x() < self.HEADER_WIDTH: + if event.position().x() < self.HEADER_WIDTH: event.ignore() return - scrollbar = self.scroll_area.horizontalScrollBar() - mouse_x_abs = event.position().x() - - time_at_cursor = self.x_to_sec(mouse_x_abs) + mouse_x = event.position().x() + time_at_cursor = self.x_to_sec(mouse_x) delta = event.angleDelta().y() zoom_factor = 1.15 old_pps = self.pixels_per_second - if delta > 0: new_pps = old_pps * zoom_factor - else: new_pps = old_pps / zoom_factor - + if delta > 0: + new_pps = old_pps * zoom_factor + else: + new_pps = old_pps / zoom_factor + min_pps = 1 / (3600 * 10) new_pps = max(min_pps, min(new_pps, self.max_pixels_per_second)) - + if abs(new_pps - old_pps) < 1e-9: return - + self.pixels_per_second = new_pps - self.update() - new_mouse_x_abs = self.sec_to_x(time_at_cursor) - shift_amount = new_mouse_x_abs - mouse_x_abs - new_scroll_value = scrollbar.value() + shift_amount - scrollbar.setValue(int(new_scroll_value)) + new_view_start_sec = time_at_cursor - (mouse_x - self.HEADER_WIDTH) / self.pixels_per_second + self.view_start_sec = max(0.0, new_view_start_sec) + + self.update() event.accept() def mousePressEvent(self, event: QMouseEvent): - if event.pos().x() < self.HEADER_WIDTH + self.scroll_area.horizontalScrollBar().value(): + if event.button() == Qt.MouseButton.MiddleButton: + self.panning = True + self.pan_start_pos = event.pos() + self.pan_start_view_sec = self.view_start_sec + self.setCursor(Qt.CursorShape.ClosedHandCursor) + event.accept() + return + + if event.pos().x() < self.HEADER_WIDTH: if self.add_video_track_btn_rect.contains(event.pos()): self.add_track.emit('video') elif self.remove_video_track_btn_rect.contains(event.pos()): self.remove_track.emit('video') elif self.add_audio_track_btn_rect.contains(event.pos()): self.add_track.emit('audio') @@ -622,6 +634,14 @@ def 
mousePressEvent(self, event: QMouseEvent): self.update() def mouseMoveEvent(self, event: QMouseEvent): + if self.panning: + delta_x = event.pos().x() - self.pan_start_pos.x() + time_delta = delta_x / self.pixels_per_second + new_view_start = self.pan_start_view_sec - time_delta + self.view_start_sec = max(0.0, new_view_start) + self.update() + return + if self.resizing_selection_region: current_sec = max(0, self.x_to_sec(event.pos().x())) original_start, original_end = self.resize_selection_start_values @@ -842,6 +862,12 @@ def mouseMoveEvent(self, event: QMouseEvent): self.update() def mouseReleaseEvent(self, event: QMouseEvent): + if event.button() == Qt.MouseButton.MiddleButton and self.panning: + self.panning = False + self.unsetCursor() + event.accept() + return + if event.button() == Qt.MouseButton.LeftButton: if self.resizing_selection_region: self.resizing_selection_region = None @@ -1419,14 +1445,8 @@ def _setup_ui(self): self.splitter.addWidget(self.preview_scroll_area) self.timeline_widget = TimelineWidget(self.timeline, self.settings, self.project_fps, self) - self.timeline_scroll_area = QScrollArea() - self.timeline_widget.scroll_area = self.timeline_scroll_area - self.timeline_scroll_area.setWidgetResizable(False) - self.timeline_widget.setMinimumWidth(2000) - self.timeline_scroll_area.setWidget(self.timeline_widget) - self.timeline_scroll_area.setFrameShape(QFrame.Shape.NoFrame) - self.timeline_scroll_area.setMinimumHeight(250) - self.splitter.addWidget(self.timeline_scroll_area) + self.timeline_widget.setMinimumHeight(250) + self.splitter.addWidget(self.timeline_widget) self.splitter.setStretchFactor(0, 1) self.splitter.setStretchFactor(1, 0) @@ -1465,7 +1485,7 @@ def _setup_ui(self): self.managed_widgets = { 'preview': {'widget': self.preview_scroll_area, 'name': 'Video Preview', 'action': None}, - 'timeline': {'widget': self.timeline_scroll_area, 'name': 'Timeline', 'action': None}, + 'timeline': {'widget': self.timeline_widget, 'name': 'Timeline', 'action': None}, 'project_media': {'widget': self.media_dock, 'name': 'Project Media', 'action': None} } self.plugin_menu_actions = {} @@ -1676,7 +1696,7 @@ def _create_menu_bar(self): self.windows_menu = menu_bar.addMenu("&Windows") for key, data in self.managed_widgets.items(): - if data['widget'] is self.preview_scroll_area: continue + if data['widget'] is self.preview_scroll_area or data['widget'] is self.timeline_widget: continue action = QAction(data['name'], self, checkable=True) if hasattr(data['widget'], 'visibilityChanged'): action.toggled.connect(data['widget'].setVisible) From 46d2c6c4b45badcb10870eecdb7208dfeaee9c3c Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 14:09:17 +1100 Subject: [PATCH 129/155] fix 1st frame being rendered blank when adding video to timeline --- videoeditor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/videoeditor.py b/videoeditor.py index a05ff8506..885a58bef 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -1542,6 +1542,7 @@ def _toggle_scale_to_fit(self, checked): def on_timeline_changed_by_undo(self): self.prune_empty_tracks() self.timeline_widget.update() + self.seek_preview(self.timeline_widget.playhead_pos_sec) self.status_label.setText("Operation undone/redone.") def update_undo_redo_actions(self): From 0bd6a8d6c83ac2972c8eefd48475abf5f3c6a394 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 14:27:25 +1100 Subject: [PATCH 130/155] add zoom to 0 on timescale --- videoeditor.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 
deletions(-) diff --git a/videoeditor.py b/videoeditor.py index 885a58bef..a63d2ac99 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -481,13 +481,6 @@ def get_region_at_pos(self, pos: QPoint): return None def wheelEvent(self, event: QMouseEvent): - if event.position().x() < self.HEADER_WIDTH: - event.ignore() - return - - mouse_x = event.position().x() - time_at_cursor = self.x_to_sec(mouse_x) - delta = event.angleDelta().y() zoom_factor = 1.15 old_pps = self.pixels_per_second @@ -503,9 +496,16 @@ def wheelEvent(self, event: QMouseEvent): if abs(new_pps - old_pps) < 1e-9: return - self.pixels_per_second = new_pps + if event.position().x() < self.HEADER_WIDTH: + # Zoom centered on time 0. Keep screen position of time 0 constant. + new_view_start_sec = self.view_start_sec * (old_pps / new_pps) + else: + # Zoom centered on mouse cursor. + mouse_x = event.position().x() + time_at_cursor = self.x_to_sec(mouse_x) + new_view_start_sec = time_at_cursor - (mouse_x - self.HEADER_WIDTH) / new_pps - new_view_start_sec = time_at_cursor - (mouse_x - self.HEADER_WIDTH) / self.pixels_per_second + self.pixels_per_second = new_pps self.view_start_sec = max(0.0, new_view_start_sec) self.update() From 0ca8ba19cf7c40e275fca2e420c0cb18eb9a5c8f Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 14:33:07 +1100 Subject: [PATCH 131/155] fix track labels overlaying 0s --- videoeditor.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/videoeditor.py b/videoeditor.py index a63d2ac99..76a4ab495 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -349,7 +349,10 @@ def draw_ticks(interval, height): painter.drawLine(x, self.TIMESCALE_HEIGHT - 12, x, self.TIMESCALE_HEIGHT) label = self._format_timecode(t_sec) label_width = font_metrics.horizontalAdvance(label) - painter.drawText(x - label_width // 2, self.TIMESCALE_HEIGHT - 14, label) + label_x = x - label_width // 2 + if label_x < self.HEADER_WIDTH: + label_x = self.HEADER_WIDTH + painter.drawText(label_x, self.TIMESCALE_HEIGHT - 14, label) painter.restore() @@ -497,10 +500,8 @@ def wheelEvent(self, event: QMouseEvent): return if event.position().x() < self.HEADER_WIDTH: - # Zoom centered on time 0. Keep screen position of time 0 constant. new_view_start_sec = self.view_start_sec * (old_pps / new_pps) else: - # Zoom centered on mouse cursor. 
mouse_x = event.position().x() time_at_cursor = self.x_to_sec(mouse_x) new_view_start_sec = time_at_cursor - (mouse_x - self.HEADER_WIDTH) / new_pps From eef79e40fb20bd30f40271a2d51f28a0197527e8 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 14:53:22 +1100 Subject: [PATCH 132/155] fix add track button not working --- videoeditor.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/videoeditor.py b/videoeditor.py index 76a4ab495..b4cdb2164 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -1626,12 +1626,18 @@ def add_track(self, track_type): self.timeline.num_video_tracks += 1 elif track_type == 'audio': self.timeline.num_audio_tracks += 1 + else: + return + new_state = self._get_current_timeline_state() - command = TimelineStateChangeCommand(f"Add {track_type.capitalize()} Track", self.timeline, *old_state, *new_state) + + self.undo_stack.blockSignals(True) self.undo_stack.push(command) - command.undo() - self.undo_stack.push(command) + self.undo_stack.blockSignals(False) + + self.update_undo_redo_actions() + self.timeline_widget.update() def remove_track(self, track_type): old_state = self._get_current_timeline_state() From 088940b0e526c1335b75b87a158ef52529fa7751 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 15:10:34 +1100 Subject: [PATCH 133/155] snap timehead to single frames --- videoeditor.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/videoeditor.py b/videoeditor.py index b4cdb2164..6abb4091b 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -473,6 +473,13 @@ def y_to_track_info(self, y): return None + def _snap_time_if_needed(self, time_sec): + frame_duration = 1.0 / self.project_fps + if frame_duration > 0 and frame_duration * self.pixels_per_second > 4: + frame_number = round(time_sec / frame_duration) + return frame_number * frame_duration + return time_sec + def get_region_at_pos(self, pos: QPoint): if pos.y() <= self.TIMESCALE_HEIGHT or pos.x() <= self.HEADER_WIDTH: return None @@ -628,7 +635,8 @@ def mousePressEvent(self, event: QMouseEvent): self.selection_drag_start_sec = self.x_to_sec(event.pos().x()) self.selection_regions.append([self.selection_drag_start_sec, self.selection_drag_start_sec]) elif is_on_timescale: - self.playhead_pos_sec = max(0, self.x_to_sec(event.pos().x())) + time_sec = max(0, self.x_to_sec(event.pos().x())) + self.playhead_pos_sec = self._snap_time_if_needed(time_sec) self.playhead_moved.emit(self.playhead_pos_sec) self.dragging_playhead = True @@ -798,7 +806,8 @@ def mouseMoveEvent(self, event: QMouseEvent): return if self.dragging_playhead: - self.playhead_pos_sec = max(0, self.x_to_sec(event.pos().x())) + time_sec = max(0, self.x_to_sec(event.pos().x())) + self.playhead_pos_sec = self._snap_time_if_needed(time_sec) self.playhead_moved.emit(self.playhead_pos_sec) self.update() elif self.dragging_clip: From cc553e2e5a26f0e2fe35c3bc8645a5dbe774eb43 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 15:49:55 +1100 Subject: [PATCH 134/155] automatically set resolution from project resolution or clip resolution depending on mode --- plugins/wan2gp/main.py | 82 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 75 insertions(+), 7 deletions(-) diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py index 0b925227c..568c31fb9 100644 --- a/plugins/wan2gp/main.py +++ b/plugins/wan2gp/main.py @@ -1339,6 +1339,55 @@ def _on_resolution_group_changed(self): 
self.widgets['resolution'].setCurrentIndex(self.widgets['resolution'].findData(last_resolution)) self.widgets['resolution'].blockSignals(False) + def set_resolution_from_target(self, target_w, target_h): + if not self.full_resolution_choices: + print("Resolution choices not available for AI resolution matching.") + return + + target_pixels = target_w * target_h + target_ar = target_w / target_h if target_h > 0 else 1.0 + + best_res_value = None + min_dist = float('inf') + + for label, res_value in self.full_resolution_choices: + try: + w_str, h_str = res_value.split('x') + w, h = int(w_str), int(h_str) + except (ValueError, AttributeError): + continue + + pixels = w * h + ar = w / h if h > 0 else 1.0 + + pixel_dist = abs(target_pixels - pixels) / target_pixels + ar_dist = abs(target_ar - ar) / target_ar + + dist = pixel_dist * 0.8 + ar_dist * 0.2 + + if dist < min_dist: + min_dist = dist + best_res_value = res_value + + if best_res_value: + best_group = self.wgp.categorize_resolution(best_res_value) + + group_combo = self.widgets['resolution_group'] + group_combo.blockSignals(True) + group_index = group_combo.findText(best_group) + if group_index != -1: + group_combo.setCurrentIndex(group_index) + group_combo.blockSignals(False) + + self._on_resolution_group_changed() + + res_combo = self.widgets['resolution'] + res_index = res_combo.findData(best_res_value) + if res_index != -1: + res_combo.setCurrentIndex(res_index) + else: + print(f"Warning: Could not find resolution '{best_res_value}' in dropdown after group change.") + def collect_inputs(self): full_inputs = self.wgp.get_current_model_settings(self.state).copy() full_inputs['lset_name'] = "" @@ -1663,6 +1712,9 @@ def setup_generator_for_region(self, region, on_new_track=False): if not start_data or not end_data: QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames.") return + + self.client_widget.set_resolution_from_target(w, h) + try: self.temp_dir = tempfile.mkdtemp(prefix="wgp_plugin_") self.start_frame_path = os.path.join(self.temp_dir, "start_frame.png") @@ -1675,13 +1727,25 @@ def setup_generator_for_region(self, region, on_new_track=False): model_type = self.client_widget.state['model_type'] fps = wgp.get_model_fps(model_type) video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) - - self.client_widget.widgets['video_length'].setValue(video_length_frames) - self.client_widget.widgets['mode_s'].setChecked(True) - self.client_widget.widgets['image_end_checkbox'].setChecked(True) - self.client_widget.widgets['image_start'].setText(self.start_frame_path) - self.client_widget.widgets['image_end'].setText(self.end_frame_path) - + widgets = self.client_widget.widgets + widgets['mode_s'].blockSignals(True) + widgets['mode_t'].blockSignals(True) + widgets['mode_v'].blockSignals(True) + widgets['mode_l'].blockSignals(True) + widgets['image_end_checkbox'].blockSignals(True) + + widgets['video_length'].setValue(video_length_frames) + widgets['mode_s'].setChecked(True) + widgets['image_end_checkbox'].setChecked(True) + widgets['image_start'].setText(self.start_frame_path) + widgets['image_end'].setText(self.end_frame_path) + + widgets['mode_s'].blockSignals(False) + widgets['mode_t'].blockSignals(False) + widgets['mode_v'].blockSignals(False) + widgets['mode_l'].blockSignals(False) + widgets['image_end_checkbox'].blockSignals(False) + self.client_widget._update_input_visibility() except Exception as e: @@ -1707,6 +1771,10 @@ def setup_creator_for_region(self, region, 
on_new_track=False): else: print(f"Warning: Default model '{model_to_set}' not found for AI Creator. Using current model.") + target_w = self.app.project_width + target_h = self.app.project_height + self.client_widget.set_resolution_from_target(target_w, target_h) + start_sec, end_sec = region duration_sec = end_sec - start_sec wgp = self.client_widget.wgp From 80ecba27c962b8e5cda8f7a562eb5fb061f934aa Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 16:09:54 +1100 Subject: [PATCH 135/155] add join to start frame with ai and join to end frame with ai modes --- plugins/wan2gp/main.py | 156 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 146 insertions(+), 10 deletions(-) diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py index 568c31fb9..2c38d30fa 100644 --- a/plugins/wan2gp/main.py +++ b/plugins/wan2gp/main.py @@ -1681,11 +1681,23 @@ def on_timeline_context_menu(self, menu, event): start_sec, end_sec = region start_data, _, _ = self.app.get_frame_data_at_time(start_sec) end_data, _, _ = self.app.get_frame_data_at_time(end_sec) + if start_data and end_data: join_action = menu.addAction("Join Frames With AI") join_action.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=False)) join_action_new_track = menu.addAction("Join Frames With AI (New Track)") join_action_new_track.triggered.connect(lambda: self.setup_generator_for_region(region, on_new_track=True)) + elif start_data: + from_start_action = menu.addAction("Generate from Start Frame with AI") + from_start_action.triggered.connect(lambda: self.setup_generator_from_start(region, on_new_track=False)) + from_start_action_new_track = menu.addAction("Generate from Start Frame with AI (New Track)") + from_start_action_new_track.triggered.connect(lambda: self.setup_generator_from_start(region, on_new_track=True)) + elif end_data: + to_end_action = menu.addAction("Generate to End Frame with AI") + to_end_action.triggered.connect(lambda: self.setup_generator_to_end(region, on_new_track=False)) + to_end_action_new_track = menu.addAction("Generate to End Frame with AI (New Track)") + to_end_action_new_track.triggered.connect(lambda: self.setup_generator_to_end(region, on_new_track=True)) + create_action = menu.addAction("Create Frames With AI") create_action.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=False)) create_action_new_track = menu.addAction("Create Frames With AI (New Track)") @@ -1728,11 +1740,9 @@ def setup_generator_for_region(self, region, on_new_track=False): fps = wgp.get_model_fps(model_type) video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) widgets = self.client_widget.widgets - widgets['mode_s'].blockSignals(True) - widgets['mode_t'].blockSignals(True) - widgets['mode_v'].blockSignals(True) - widgets['mode_l'].blockSignals(True) - widgets['image_end_checkbox'].blockSignals(True) + + for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']: + widgets[w_name].blockSignals(True) widgets['video_length'].setValue(video_length_frames) widgets['mode_s'].setChecked(True) @@ -1740,11 +1750,8 @@ def setup_generator_for_region(self, region, on_new_track=False): widgets['image_start'].setText(self.start_frame_path) widgets['image_end'].setText(self.end_frame_path) - widgets['mode_s'].blockSignals(False) - widgets['mode_t'].blockSignals(False) - widgets['mode_v'].blockSignals(False) - widgets['mode_l'].blockSignals(False) - widgets['image_end_checkbox'].blockSignals(False) + for w_name in 
['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']: + widgets[w_name].blockSignals(False) self.client_widget._update_input_visibility() @@ -1756,6 +1763,135 @@ def setup_generator_for_region(self, region, on_new_track=False): self.dock_widget.show() self.dock_widget.raise_() + def setup_generator_from_start(self, region, on_new_track=False): + self._reset_state() + self.active_region = region + self.insert_on_new_track = on_new_track + + model_to_set = 'i2v_2_2' + dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types + _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False) + if any(model_to_set == m[1] for m in all_models): + if self.client_widget.state.get('model_type') != model_to_set: + self.client_widget.update_model_dropdowns(model_to_set) + self.client_widget._on_model_changed() + else: + print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.") + + start_sec, end_sec = region + start_data, w, h = self.app.get_frame_data_at_time(start_sec) + if not start_data: + QMessageBox.warning(self.app, "Frame Error", "Could not extract start frame.") + return + + self.client_widget.set_resolution_from_target(w, h) + + try: + self.temp_dir = tempfile.mkdtemp(prefix="wgp_plugin_") + self.start_frame_path = os.path.join(self.temp_dir, "start_frame.png") + QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path) + + duration_sec = end_sec - start_sec + wgp = self.client_widget.wgp + model_type = self.client_widget.state['model_type'] + fps = wgp.get_model_fps(model_type) + video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) + widgets = self.client_widget.widgets + + widgets['video_length'].setValue(video_length_frames) + + for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']: + widgets[w_name].blockSignals(True) + + widgets['mode_s'].setChecked(True) + widgets['image_end_checkbox'].setChecked(False) + widgets['image_start'].setText(self.start_frame_path) + widgets['image_end'].clear() + + for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']: + widgets[w_name].blockSignals(False) + + self.client_widget._update_input_visibility() + + except Exception as e: + QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame image: {e}") + self._cleanup_temp_dir() + return + + self.app.status_label.setText(f"Ready to generate from frame at {start_sec:.2f}s.") + self.dock_widget.show() + self.dock_widget.raise_() + + def setup_generator_to_end(self, region, on_new_track=False): + self._reset_state() + self.active_region = region + self.insert_on_new_track = on_new_track + + model_to_set = 'i2v_2_2' + dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types + _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False) + if any(model_to_set == m[1] for m in all_models): + if self.client_widget.state.get('model_type') != model_to_set: + self.client_widget.update_model_dropdowns(model_to_set) + self.client_widget._on_model_changed() + else: + print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. 
Using current model.") + + start_sec, end_sec = region + end_data, w, h = self.app.get_frame_data_at_time(end_sec) + if not end_data: + QMessageBox.warning(self.app, "Frame Error", "Could not extract end frame.") + return + + self.client_widget.set_resolution_from_target(w, h) + + try: + self.temp_dir = tempfile.mkdtemp(prefix="wgp_plugin_") + self.end_frame_path = os.path.join(self.temp_dir, "end_frame.png") + QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path) + + duration_sec = end_sec - start_sec + wgp = self.client_widget.wgp + model_type = self.client_widget.state['model_type'] + fps = wgp.get_model_fps(model_type) + video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) + widgets = self.client_widget.widgets + + widgets['video_length'].setValue(video_length_frames) + + model_def = self.client_widget.wgp.get_model_def(self.client_widget.state['model_type']) + allowed_modes = model_def.get("image_prompt_types_allowed", "") + + if "E" not in allowed_modes: + QMessageBox.warning(self.app, "Model Incompatible", "The current model does not support generating to an end frame.") + return + + if "S" not in allowed_modes: + QMessageBox.warning(self.app, "Model Incompatible", "The current model supports end frames, but not in a way compatible with this UI feature (missing 'Start with Image' mode).") + return + + for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']: + widgets[w_name].blockSignals(True) + + widgets['mode_s'].setChecked(True) + widgets['image_end_checkbox'].setChecked(True) + widgets['image_start'].clear() + widgets['image_end'].setText(self.end_frame_path) + + for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']: + widgets[w_name].blockSignals(False) + + self.client_widget._update_input_visibility() + + except Exception as e: + QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame image: {e}") + self._cleanup_temp_dir() + return + + self.app.status_label.setText(f"Ready to generate to frame at {end_sec:.2f}s.") + self.dock_widget.show() + self.dock_widget.raise_() + def setup_creator_for_region(self, region, on_new_track=False): self._reset_state() self.active_region = region From bbe9df1eccd10c387e3e536a7aa09cc5fbb7460d Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 16:36:41 +1100 Subject: [PATCH 136/155] make sure advanced options tabs dont need scroll --- plugins/wan2gp/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py index 2c38d30fa..12a4ede6e 100644 --- a/plugins/wan2gp/main.py +++ b/plugins/wan2gp/main.py @@ -423,6 +423,7 @@ def setup_generator_tab(self): self.tabs.addTab(gen_tab, "Video Generator") gen_layout = QHBoxLayout(gen_tab) left_panel = QWidget() + left_panel.setMinimumWidth(628) left_layout = QVBoxLayout(left_panel) gen_layout.addWidget(left_panel, 1) right_panel = QWidget() From f2324d42424d5a584388a5978b8da43dd0fa29f1 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 16:44:29 +1100 Subject: [PATCH 137/155] allow sliders to be text editable --- plugins/wan2gp/main.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py index 12a4ede6e..41edebd80 100644 --- a/plugins/wan2gp/main.py +++ b/plugins/wan2gp/main.py @@ -296,11 +296,35 @@ def _create_slider_with_label(self, name, min_val, max_val, initial_val, scale=1 slider = self.create_widget(QSlider, 
name, Qt.Orientation.Horizontal) slider.setRange(min_val, max_val) slider.setValue(int(initial_val * scale)) - value_label = self.create_widget(QLabel, f"{name}_label", f"{initial_val:.{precision}f}") - value_label.setMinimumWidth(50) - slider.valueChanged.connect(lambda v, lbl=value_label, s=scale, p=precision: lbl.setText(f"{v/s:.{p}f}")) + + value_edit = self.create_widget(QLineEdit, f"{name}_label", f"{initial_val:.{precision}f}") + value_edit.setFixedWidth(60) + value_edit.setAlignment(Qt.AlignmentFlag.AlignCenter) + + def sync_slider_from_text(): + try: + text_value = float(value_edit.text()) + slider_value = int(round(text_value * scale)) + + slider.blockSignals(True) + slider.setValue(slider_value) + slider.blockSignals(False) + + actual_slider_value = slider.value() + if actual_slider_value != slider_value: + value_edit.setText(f"{actual_slider_value / scale:.{precision}f}") + except (ValueError, TypeError): + value_edit.setText(f"{slider.value() / scale:.{precision}f}") + + def sync_text_from_slider(value): + if not value_edit.hasFocus(): + value_edit.setText(f"{value / scale:.{precision}f}") + + value_edit.editingFinished.connect(sync_slider_from_text) + slider.valueChanged.connect(sync_text_from_slider) + hbox.addWidget(slider) - hbox.addWidget(value_label) + hbox.addWidget(value_edit) return container def _create_file_input(self, name, label_text): From 2a867c2bfb54c3a52c2a976d33d2bf12a35fc723 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 16:53:26 +1100 Subject: [PATCH 138/155] save selected regions to project --- videoeditor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/videoeditor.py b/videoeditor.py index 6abb4091b..ce49757b1 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -2030,6 +2030,7 @@ def new_project(self): self.project_width = 1280 self.project_height = 720 self.timeline_widget.set_project_fps(self.project_fps) + self.timeline_widget.clear_all_regions() self.timeline_widget.update() self.undo_stack = UndoStack() self.undo_stack.history_changed.connect(self.update_undo_redo_actions) @@ -2042,6 +2043,7 @@ def save_project_as(self): project_data = { "media_pool": self.media_pool, "clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "media_type": c.media_type, "group_id": c.group_id} for c in self.timeline.clips], + "selection_regions": self.timeline_widget.selection_regions, "settings": { "num_video_tracks": self.timeline.num_video_tracks, "num_audio_tracks": self.timeline.num_audio_tracks, @@ -2073,6 +2075,8 @@ def _load_project_from_path(self, path): self.project_fps = project_settings.get("project_fps", 25.0) self.timeline_widget.set_project_fps(self.project_fps) + self.timeline_widget.selection_regions = project_data.get("selection_regions", []) + self.media_pool = project_data.get("media_pool", []) for p in self.media_pool: self._add_media_to_pool(p) From 4d79760e83e9579cd3eb2ec65c493c62a42fd02d Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 17:30:20 +1100 Subject: [PATCH 139/155] set minheight --- plugins/wan2gp/main.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py index 41edebd80..623d4afa7 100644 --- a/plugins/wan2gp/main.py +++ b/plugins/wan2gp/main.py @@ -469,11 +469,13 @@ def setup_generator_tab(self): model_layout.addWidget(self.widgets['model_choice'], 3) 
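The editable slider labels introduced in "allow sliders to be text editable" above hinge on one detail: the slider's signals are blocked while a value parsed from the line edit is applied, so the two widgets never re-trigger each other, and the text is only refreshed from the slider when the line edit does not have focus. A minimal standalone sketch of that pattern, assuming PyQt6; the widget names and scale factor are illustrative rather than taken from the plugin's `create_widget` helpers:

```python
# Sketch of the slider <-> line-edit sync pattern used in _create_slider_with_label.
# Names and the scale factor are illustrative; only the signal-blocking idea matters.
from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import QApplication, QHBoxLayout, QLineEdit, QSlider, QWidget

SCALE = 10.0  # slider stores value * 10, text shows one decimal place

app = QApplication([])
panel = QWidget()
layout = QHBoxLayout(panel)
slider = QSlider(Qt.Orientation.Horizontal)
slider.setRange(0, 100)
edit = QLineEdit("0.0")
layout.addWidget(slider)
layout.addWidget(edit)

def sync_slider_from_text():
    try:
        value = int(round(float(edit.text()) * SCALE))
    except ValueError:
        value = slider.value()            # bad input: fall back to the current value
    slider.blockSignals(True)             # don't bounce back into sync_text_from_slider
    slider.setValue(value)
    slider.blockSignals(False)
    edit.setText(f"{slider.value() / SCALE:.1f}")  # reflect any clamping by setRange

def sync_text_from_slider(value):
    if not edit.hasFocus():               # don't overwrite text the user is typing
        edit.setText(f"{value / SCALE:.1f}")

edit.editingFinished.connect(sync_slider_from_text)
slider.valueChanged.connect(sync_text_from_slider)

panel.show()
app.exec()
```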
options_layout.addLayout(model_layout) options_layout.addWidget(QLabel("Prompt:")) - self.create_widget(QTextEdit, 'prompt').setMinimumHeight(100) - options_layout.addWidget(self.widgets['prompt']) + prompt_edit = self.create_widget(QTextEdit, 'prompt') + prompt_edit.setMaximumHeight(prompt_edit.fontMetrics().lineSpacing() * 5 + 15) + options_layout.addWidget(prompt_edit) options_layout.addWidget(QLabel("Negative Prompt:")) - self.create_widget(QTextEdit, 'negative_prompt').setMinimumHeight(60) - options_layout.addWidget(self.widgets['negative_prompt']) + neg_prompt_edit = self.create_widget(QTextEdit, 'negative_prompt') + neg_prompt_edit.setMaximumHeight(neg_prompt_edit.fontMetrics().lineSpacing() * 3 + 15) + options_layout.addWidget(neg_prompt_edit) basic_group = QGroupBox("Basic Options") basic_layout = QFormLayout(basic_group) res_container = QWidget() @@ -528,7 +530,7 @@ def setup_generator_tab(self): self._setup_adv_tab_sliding_window(advanced_tabs) self._setup_adv_tab_misc(advanced_tabs) options_layout.addWidget(self.advanced_group) - + scroll_area.setMinimumHeight(options_widget.sizeHint().height()-260) btn_layout = QHBoxLayout() self.generate_btn = self.create_widget(QPushButton, 'generate_btn', "Generate") self.add_to_queue_btn = self.create_widget(QPushButton, 'add_to_queue_btn', "Add to Queue") From a8fd923e49d3e246e6c050b483061a35a2f5f160 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 19:09:10 +1100 Subject: [PATCH 140/155] add entire ffmpeg library of containers and formats to export options --- videoeditor.py | 470 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 430 insertions(+), 40 deletions(-) diff --git a/videoeditor.py b/videoeditor.py index ce49757b1..b6673c879 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -10,7 +10,9 @@ from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QScrollArea, QFrame, QProgressBar, QDialog, - QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget, QListWidget, QListWidgetItem, QMessageBox) + QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget, + QListWidget, QListWidgetItem, QMessageBox, QComboBox, + QFormLayout, QGroupBox, QLineEdit) from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction, QPixmap, QImage, QDrag, QCursor, QKeyEvent) from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread, @@ -19,6 +21,165 @@ from plugins import PluginManager, ManagePluginsDialog from undo import UndoStack, TimelineStateChangeCommand, MoveClipsCommand +CONTAINER_PRESETS = { + 'mp4': { + 'vcodec': 'libx264', 'acodec': 'aac', + 'allowed_vcodecs': ['libx264', 'libx265', 'mpeg4'], + 'allowed_acodecs': ['aac', 'libmp3lame'], + 'v_bitrate': '5M', 'a_bitrate': '192k' + }, + 'matroska': { + 'vcodec': 'libx264', 'acodec': 'aac', + 'allowed_vcodecs': ['libx264', 'libx265', 'libvpx-vp9'], + 'allowed_acodecs': ['aac', 'libopus', 'libvorbis', 'flac'], + 'v_bitrate': '5M', 'a_bitrate': '192k' + }, + 'mov': { + 'vcodec': 'libx264', 'acodec': 'aac', + 'allowed_vcodecs': ['libx264', 'prores_ks', 'mpeg4'], + 'allowed_acodecs': ['aac', 'pcm_s16le'], + 'v_bitrate': '8M', 'a_bitrate': '256k' + }, + 'avi': { + 'vcodec': 'mpeg4', 'acodec': 'libmp3lame', + 'allowed_vcodecs': ['mpeg4', 'msmpeg4'], + 'allowed_acodecs': ['libmp3lame'], + 'v_bitrate': '5M', 'a_bitrate': '192k' + }, + 'webm': { + 'vcodec': 'libvpx-vp9', 'acodec': 'libopus', + 'allowed_vcodecs': ['libvpx-vp9'], + 'allowed_acodecs': 
['libopus', 'libvorbis'], + 'v_bitrate': '4M', 'a_bitrate': '192k' + }, + 'wav': { + 'vcodec': None, 'acodec': 'pcm_s16le', + 'allowed_vcodecs': [], 'allowed_acodecs': ['pcm_s16le', 'pcm_s24le'], + 'v_bitrate': None, 'a_bitrate': None + }, + 'mp3': { + 'vcodec': None, 'acodec': 'libmp3lame', + 'allowed_vcodecs': [], 'allowed_acodecs': ['libmp3lame'], + 'v_bitrate': None, 'a_bitrate': '192k' + }, + 'flac': { + 'vcodec': None, 'acodec': 'flac', + 'allowed_vcodecs': [], 'allowed_acodecs': ['flac'], + 'v_bitrate': None, 'a_bitrate': None + }, + 'gif': { + 'vcodec': 'gif', 'acodec': None, + 'allowed_vcodecs': ['gif'], 'allowed_acodecs': [], + 'v_bitrate': None, 'a_bitrate': None + }, + 'oga': { # Using oga for ogg audio + 'vcodec': None, 'acodec': 'libvorbis', + 'allowed_vcodecs': [], 'allowed_acodecs': ['libvorbis', 'libopus'], + 'v_bitrate': None, 'a_bitrate': '192k' + } +} + +_cached_formats = None +_cached_video_codecs = None +_cached_audio_codecs = None + +def run_ffmpeg_command(args): + try: + startupinfo = None + if hasattr(subprocess, 'STARTUPINFO'): + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + startupinfo.wShowWindow = subprocess.SW_HIDE + + result = subprocess.run( + ['ffmpeg'] + args, + capture_output=True, text=True, encoding='utf-8', + errors='ignore', startupinfo=startupinfo + ) + if result.returncode != 0 and "Unrecognized option" not in result.stderr: + print(f"FFmpeg command failed: {' '.join(args)}\n{result.stderr}") + return "" + return result.stdout + except FileNotFoundError: + print("Error: ffmpeg not found. Please ensure it is in your system's PATH.") + return None + except Exception as e: + print(f"An error occurred while running ffmpeg: {e}") + return None + +def get_available_formats(): + global _cached_formats + if _cached_formats is not None: + return _cached_formats + + output = run_ffmpeg_command(['-formats']) + if not output: + _cached_formats = {} + return {} + + formats = {} + lines = output.split('\n') + header_found = False + for line in lines: + if "---" in line: + header_found = True + continue + if not header_found or not line.strip(): + continue + + if line[2] == 'E': + parts = line[4:].strip().split(None, 1) + if len(parts) == 2: + names, description = parts + primary_name = names.split(',')[0].strip() + formats[primary_name] = description.strip() + + _cached_formats = dict(sorted(formats.items())) + return _cached_formats + +def get_available_codecs(codec_type='video'): + global _cached_video_codecs, _cached_audio_codecs + + if codec_type == 'video' and _cached_video_codecs is not None: return _cached_video_codecs + if codec_type == 'audio' and _cached_audio_codecs is not None: return _cached_audio_codecs + + output = run_ffmpeg_command(['-encoders']) + if not output: + if codec_type == 'video': _cached_video_codecs = {} + else: _cached_audio_codecs = {} + return {} + + video_codecs = {} + audio_codecs = {} + lines = output.split('\n') + + header_found = False + for line in lines: + if "------" in line: + header_found = True + continue + + if not header_found or not line.strip(): + continue + + parts = line.strip().split(None, 2) + if len(parts) < 3: + continue + + flags, name, description = parts + type_flag = flags[0] + clean_description = re.sub(r'\s*\(codec .*\)$', '', description).strip() + + if type_flag == 'V': + video_codecs[name] = clean_description + elif type_flag == 'A': + audio_codecs[name] = clean_description + + _cached_video_codecs = dict(sorted(video_codecs.items())) + 
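The probing helpers above (`get_available_formats`, `get_available_codecs`) return plain `{name: description}` dicts and memoise them in module-level caches, so reopening the export dialog does not spawn ffmpeg again once a successful probe has run. A small usage sketch under that assumption; the concrete entries depend on which muxers and encoders the locally installed ffmpeg was built with, so the example values are illustrative:

```python
# Hypothetical usage of the probing helpers defined above; exact keys depend on
# the local ffmpeg build, so treat the example values as illustrative only.
formats = get_available_formats()             # e.g. {'mp4': 'MP4 (MPEG-4 Part 14)', ...}
video_codecs = get_available_codecs('video')  # e.g. {'libx264': '... H.264 ...', ...}
audio_codecs = get_available_codecs('audio')

print("mp4 muxer available:", "mp4" in formats)
print("libx264 encoder available:", "libx264" in video_codecs)

# A second call is served from the module-level cache rather than a new probe.
assert get_available_codecs('video') == video_codecs
```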
_cached_audio_codecs = dict(sorted(audio_codecs.items())) + + return _cached_video_codecs if codec_type == 'video' else _cached_audio_codecs + class TimelineClip: def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec, track_index, track_type, media_type, group_id): self.id = str(uuid.uuid4()) @@ -1259,19 +1420,235 @@ class SettingsDialog(QDialog): def __init__(self, parent_settings, parent=None): super().__init__(parent) self.setWindowTitle("Settings") - self.setMinimumWidth(350) + self.setMinimumWidth(450) layout = QVBoxLayout(self) self.confirm_on_exit_checkbox = QCheckBox("Confirm before exiting") self.confirm_on_exit_checkbox.setChecked(parent_settings.get("confirm_on_exit", True)) layout.addWidget(self.confirm_on_exit_checkbox) + + export_path_group = QGroupBox("Default Export Path (for new projects)") + export_path_layout = QHBoxLayout() + self.default_export_path_edit = QLineEdit() + self.default_export_path_edit.setPlaceholderText("Optional: e.g., C:/Users/YourUser/Videos/Exports") + self.default_export_path_edit.setText(parent_settings.get("default_export_path", "")) + browse_button = QPushButton("Browse...") + browse_button.clicked.connect(self.browse_default_export_path) + export_path_layout.addWidget(self.default_export_path_edit) + export_path_layout.addWidget(browse_button) + export_path_group.setLayout(export_path_layout) + layout.addWidget(export_path_group) + button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) button_box.accepted.connect(self.accept) button_box.rejected.connect(self.reject) layout.addWidget(button_box) + def browse_default_export_path(self): + path = QFileDialog.getExistingDirectory(self, "Select Default Export Folder", self.default_export_path_edit.text()) + if path: + self.default_export_path_edit.setText(path) + def get_settings(self): return { - "confirm_on_exit": self.confirm_on_exit_checkbox.isChecked() + "confirm_on_exit": self.confirm_on_exit_checkbox.isChecked(), + "default_export_path": self.default_export_path_edit.text(), + } + +class ExportDialog(QDialog): + def __init__(self, default_path, parent=None): + super().__init__(parent) + self.setWindowTitle("Export Settings") + self.setMinimumWidth(550) + + self.video_bitrate_options = ["500k", "1M", "2.5M", "5M", "8M", "15M", "Custom..."] + self.audio_bitrate_options = ["96k", "128k", "192k", "256k", "320k", "Custom..."] + self.display_ext_map = {'matroska': 'mkv', 'oga': 'ogg'} + + self.layout = QVBoxLayout(self) + self.formats = get_available_formats() + self.video_codecs = get_available_codecs('video') + self.audio_codecs = get_available_codecs('audio') + + self._setup_ui() + + self.path_edit.setText(default_path) + self.on_advanced_toggled(False) + + def _setup_ui(self): + output_group = QGroupBox("Output File") + output_layout = QHBoxLayout() + self.path_edit = QLineEdit() + browse_button = QPushButton("Browse...") + browse_button.clicked.connect(self.browse_output_path) + output_layout.addWidget(self.path_edit) + output_layout.addWidget(browse_button) + output_group.setLayout(output_layout) + self.layout.addWidget(output_group) + + self.container_combo = QComboBox() + self.container_combo.currentIndexChanged.connect(self.on_container_changed) + + self.advanced_formats_checkbox = QCheckBox("Show Advanced Options") + self.advanced_formats_checkbox.toggled.connect(self.on_advanced_toggled) + + container_layout = QFormLayout() + container_layout.addRow("Format Preset:", self.container_combo) + 
container_layout.addRow(self.advanced_formats_checkbox) + self.layout.addLayout(container_layout) + + self.video_group = QGroupBox("Video Settings") + video_layout = QFormLayout() + self.video_codec_combo = QComboBox() + video_layout.addRow("Video Codec:", self.video_codec_combo) + self.v_bitrate_combo = QComboBox() + self.v_bitrate_combo.addItems(self.video_bitrate_options) + self.v_bitrate_custom_edit = QLineEdit() + self.v_bitrate_custom_edit.setPlaceholderText("e.g., 6500k") + self.v_bitrate_custom_edit.hide() + self.v_bitrate_combo.currentTextChanged.connect(self.on_v_bitrate_changed) + video_layout.addRow("Video Bitrate:", self.v_bitrate_combo) + video_layout.addRow(self.v_bitrate_custom_edit) + self.video_group.setLayout(video_layout) + self.layout.addWidget(self.video_group) + + self.audio_group = QGroupBox("Audio Settings") + audio_layout = QFormLayout() + self.audio_codec_combo = QComboBox() + audio_layout.addRow("Audio Codec:", self.audio_codec_combo) + self.a_bitrate_combo = QComboBox() + self.a_bitrate_combo.addItems(self.audio_bitrate_options) + self.a_bitrate_custom_edit = QLineEdit() + self.a_bitrate_custom_edit.setPlaceholderText("e.g., 256k") + self.a_bitrate_custom_edit.hide() + self.a_bitrate_combo.currentTextChanged.connect(self.on_a_bitrate_changed) + audio_layout.addRow("Audio Bitrate:", self.a_bitrate_combo) + audio_layout.addRow(self.a_bitrate_custom_edit) + self.audio_group.setLayout(audio_layout) + self.layout.addWidget(self.audio_group) + + self.button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + self.button_box.accepted.connect(self.accept) + self.button_box.rejected.connect(self.reject) + self.layout.addWidget(self.button_box) + + def _populate_combo(self, combo, data_dict, filter_keys=None): + current_selection = combo.currentData() + combo.blockSignals(True) + combo.clear() + + keys_to_show = filter_keys if filter_keys is not None else data_dict.keys() + for codename in keys_to_show: + if codename in data_dict: + desc = data_dict[codename] + display_name = self.display_ext_map.get(codename, codename) if combo is self.container_combo else codename + combo.addItem(f"{desc} ({display_name})", codename) + + new_index = combo.findData(current_selection) + combo.setCurrentIndex(new_index if new_index != -1 else 0) + combo.blockSignals(False) + + def on_advanced_toggled(self, checked): + self._populate_combo(self.container_combo, self.formats, None if checked else CONTAINER_PRESETS.keys()) + + if not checked: + mp4_index = self.container_combo.findData("mp4") + if mp4_index != -1: self.container_combo.setCurrentIndex(mp4_index) + + self.on_container_changed(self.container_combo.currentIndex()) + + def apply_preset(self, container_codename): + preset = CONTAINER_PRESETS.get(container_codename, {}) + + vcodec = preset.get('vcodec') + self.video_group.setEnabled(vcodec is not None) + if vcodec: + vcodec_idx = self.video_codec_combo.findData(vcodec) + self.video_codec_combo.setCurrentIndex(vcodec_idx if vcodec_idx != -1 else 0) + + v_bitrate = preset.get('v_bitrate') + self.v_bitrate_combo.setEnabled(v_bitrate is not None) + if v_bitrate: + v_bitrate_idx = self.v_bitrate_combo.findText(v_bitrate) + if v_bitrate_idx != -1: self.v_bitrate_combo.setCurrentIndex(v_bitrate_idx) + else: + self.v_bitrate_combo.setCurrentText("Custom...") + self.v_bitrate_custom_edit.setText(v_bitrate) + + acodec = preset.get('acodec') + self.audio_group.setEnabled(acodec is not None) + if acodec: + acodec_idx = 
self.audio_codec_combo.findData(acodec) + self.audio_codec_combo.setCurrentIndex(acodec_idx if acodec_idx != -1 else 0) + + a_bitrate = preset.get('a_bitrate') + self.a_bitrate_combo.setEnabled(a_bitrate is not None) + if a_bitrate: + a_bitrate_idx = self.a_bitrate_combo.findText(a_bitrate) + if a_bitrate_idx != -1: self.a_bitrate_combo.setCurrentIndex(a_bitrate_idx) + else: + self.a_bitrate_combo.setCurrentText("Custom...") + self.a_bitrate_custom_edit.setText(a_bitrate) + + def on_v_bitrate_changed(self, text): + self.v_bitrate_custom_edit.setVisible(text == "Custom...") + + def on_a_bitrate_changed(self, text): + self.a_bitrate_custom_edit.setVisible(text == "Custom...") + + def browse_output_path(self): + container_codename = self.container_combo.currentData() + all_formats_desc = [f"{desc} (*.{name})" for name, desc in self.formats.items()] + filter_str = ";;".join(all_formats_desc) + current_desc = self.formats.get(container_codename, "Custom Format") + specific_filter = f"{current_desc} (*.{container_codename})" + final_filter = f"{specific_filter};;{filter_str};;All Files (*)" + + path, _ = QFileDialog.getSaveFileName(self, "Save Video As", self.path_edit.text(), final_filter) + if path: self.path_edit.setText(path) + + def on_container_changed(self, index): + if index == -1: return + new_container_codename = self.container_combo.itemData(index) + + is_advanced = self.advanced_formats_checkbox.isChecked() + preset = CONTAINER_PRESETS.get(new_container_codename) + + if not is_advanced and preset: + v_filter = preset.get('allowed_vcodecs') + a_filter = preset.get('allowed_acodecs') + self._populate_combo(self.video_codec_combo, self.video_codecs, v_filter) + self._populate_combo(self.audio_codec_combo, self.audio_codecs, a_filter) + else: + self._populate_combo(self.video_codec_combo, self.video_codecs) + self._populate_combo(self.audio_codec_combo, self.audio_codecs) + + if new_container_codename: + self.update_output_path_extension(new_container_codename) + self.apply_preset(new_container_codename) + + def update_output_path_extension(self, new_container_codename): + current_path = self.path_edit.text() + if not current_path: return + directory, filename = os.path.split(current_path) + basename, _ = os.path.splitext(filename) + ext = self.display_ext_map.get(new_container_codename, new_container_codename) + new_path = os.path.join(directory, f"{basename}.{ext}") + self.path_edit.setText(new_path) + + def get_export_settings(self): + v_bitrate = self.v_bitrate_combo.currentText() + if v_bitrate == "Custom...": v_bitrate = self.v_bitrate_custom_edit.text() + + a_bitrate = self.a_bitrate_combo.currentText() + if a_bitrate == "Custom...": a_bitrate = self.a_bitrate_custom_edit.text() + + return { + "output_path": self.path_edit.text(), + "container": self.container_combo.currentData(), + "vcodec": self.video_codec_combo.currentData() if self.video_group.isEnabled() else None, + "v_bitrate": v_bitrate if self.v_bitrate_combo.isEnabled() else None, + "acodec": self.audio_codec_combo.currentData() if self.audio_group.isEnabled() else None, + "a_bitrate": a_bitrate if self.a_bitrate_combo.isEnabled() else None, } class MediaListWidget(QListWidget): @@ -1391,6 +1768,7 @@ def __init__(self, project_to_load=None): self.export_thread = None self.export_worker = None self.current_project_path = None + self.last_export_path = None self.settings = {} self.settings_file = "settings.json" self.is_shutting_down = False @@ -1975,7 +2353,7 @@ def advance_playback_frame(self): def _load_settings(self): 
self.settings_file_was_loaded = False - defaults = {"window_visibility": {"project_media": False}, "splitter_state": None, "enabled_plugins": [], "recent_files": [], "confirm_on_exit": True} + defaults = {"window_visibility": {"project_media": False}, "splitter_state": None, "enabled_plugins": [], "recent_files": [], "confirm_on_exit": True, "default_export_path": ""} if os.path.exists(self.settings_file): try: with open(self.settings_file, "r") as f: self.settings = json.load(f) @@ -2026,6 +2404,7 @@ def new_project(self): self.timeline.clips.clear(); self.timeline.num_video_tracks = 1; self.timeline.num_audio_tracks = 1 self.media_pool.clear(); self.media_properties.clear(); self.project_media_widget.clear_list() self.current_project_path = None; self.stop_playback() + self.last_export_path = None self.project_fps = 25.0 self.project_width = 1280 self.project_height = 720 @@ -2044,6 +2423,7 @@ def save_project_as(self): "media_pool": self.media_pool, "clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "media_type": c.media_type, "group_id": c.group_id} for c in self.timeline.clips], "selection_regions": self.timeline_widget.selection_regions, + "last_export_path": self.last_export_path, "settings": { "num_video_tracks": self.timeline.num_video_tracks, "num_audio_tracks": self.timeline.num_audio_tracks, @@ -2074,11 +2454,11 @@ def _load_project_from_path(self, path): self.project_height = project_settings.get("project_height", 720) self.project_fps = project_settings.get("project_fps", 25.0) self.timeline_widget.set_project_fps(self.project_fps) - + self.last_export_path = project_data.get("last_export_path") self.timeline_widget.selection_regions = project_data.get("selection_regions", []) - self.media_pool = project_data.get("media_pool", []) - for p in self.media_pool: self._add_media_to_pool(p) + media_pool_paths = project_data.get("media_pool", []) + for p in media_pool_paths: self._add_media_to_pool(p) for clip_data in project_data["clips"]: if not os.path.exists(clip_data["source_path"]): @@ -2520,9 +2900,41 @@ def action(): def export_video(self): - if not self.timeline.clips: self.status_label.setText("Timeline is empty."); return - output_path, _ = QFileDialog.getSaveFileName(self, "Save Video As", "", "MP4 Files (*.mp4)") - if not output_path: return + if not self.timeline.clips: + self.status_label.setText("Timeline is empty.") + return + + default_path = "" + if self.last_export_path and os.path.isdir(os.path.dirname(self.last_export_path)): + default_path = self.last_export_path + elif self.settings.get("default_export_path") and os.path.isdir(self.settings.get("default_export_path")): + proj_basename = "output" + if self.current_project_path: + _, proj_file = os.path.split(self.current_project_path) + proj_basename, _ = os.path.splitext(proj_file) + + default_path = os.path.join(self.settings["default_export_path"], f"{proj_basename}_export.mp4") + elif self.current_project_path: + proj_dir, proj_file = os.path.split(self.current_project_path) + proj_basename, _ = os.path.splitext(proj_file) + default_path = os.path.join(proj_dir, f"{proj_basename}_export.mp4") + else: + default_path = "output.mp4" + + default_path = os.path.normpath(default_path) + + dialog = ExportDialog(default_path, self) + if dialog.exec() != QDialog.DialogCode.Accepted: + self.status_label.setText("Export canceled.") + return + + settings = 
dialog.get_export_settings() + output_path = settings["output_path"] + if not output_path: + self.status_label.setText("Export failed: No output path specified.") + return + + self.last_export_path = output_path w, h, fr_str, total_dur = self.project_width, self.project_height, str(self.project_fps), self.timeline.get_total_duration() sample_rate, channel_layout = '44100', 'stereo' @@ -2533,9 +2945,7 @@ def export_video(self): [c for c in self.timeline.clips if c.track_type == 'video'], key=lambda c: c.track_index ) - input_nodes = {} - for clip in all_video_clips: if clip.source_path not in input_nodes: if clip.media_type == 'image': @@ -2544,31 +2954,13 @@ def export_video(self): input_nodes[clip.source_path] = ffmpeg.input(clip.source_path) clip_source_node = input_nodes[clip.source_path] - if clip.media_type == 'image': - segment_stream = ( - clip_source_node.video - .trim(duration=clip.duration_sec) - .setpts('PTS-STARTPTS') - ) + segment_stream = (clip_source_node.video.trim(duration=clip.duration_sec).setpts('PTS-STARTPTS')) else: - segment_stream = ( - clip_source_node.video - .trim(start=clip.clip_start_sec, duration=clip.duration_sec) - .setpts('PTS-STARTPTS') - ) + segment_stream = (clip_source_node.video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS')) - processed_segment = ( - segment_stream - .filter('scale', w, h, force_original_aspect_ratio='decrease') - .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') - ) - - video_stream = ffmpeg.overlay( - video_stream, - processed_segment, - enable=f'between(t,{clip.timeline_start_sec},{clip.timeline_end_sec})' - ) + processed_segment = (segment_stream.filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')) + video_stream = ffmpeg.overlay(video_stream, processed_segment, enable=f'between(t,{clip.timeline_start_sec},{clip.timeline_end_sec})') final_video = video_stream.filter('format', pix_fmts='yuv420p').filter('fps', fps=self.project_fps) @@ -2576,22 +2968,17 @@ def export_video(self): for i in range(self.timeline.num_audio_tracks): track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i + 1], key=lambda c: c.timeline_start_sec) if not track_clips: continue - track_segments = [] last_end = track_clips[0].timeline_start_sec - for clip in track_clips: if clip.source_path not in input_nodes: input_nodes[clip.source_path] = ffmpeg.input(clip.source_path) - gap = clip.timeline_start_sec - last_end if gap > 0.01: track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap}', f='lavfi')) - a_seg = input_nodes[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS') track_segments.append(a_seg) last_end = clip.timeline_end_sec - track_audio_streams.append(ffmpeg.concat(*track_segments, v=0, a=1).filter('adelay', f'{int(track_clips[0].timeline_start_sec * 1000)}ms', all=True)) if track_audio_streams: @@ -2599,7 +2986,10 @@ def export_video(self): else: final_audio = ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={total_dur}', f='lavfi') - output_args = {'vcodec': 'libx264', 'acodec': 'aac', 'pix_fmt': 'yuv420p', 'b:v': '5M'} + output_args = {'vcodec': settings['vcodec'], 'acodec': settings['acodec'], 'pix_fmt': 'yuv420p'} + if settings['v_bitrate']: output_args['b:v'] = settings['v_bitrate'] + if settings['a_bitrate']: output_args['b:a'] = settings['a_bitrate'] + try: ffmpeg_cmd = 
ffmpeg.output(final_video, final_audio, output_path, **output_args).overwrite_output().compile() self.progress_bar.setVisible(True); self.progress_bar.setValue(0); self.status_label.setText("Exporting...") From cdcde06778a07ee55be95f6bc64dd9d1e9e4bfd3 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 21:26:42 +1100 Subject: [PATCH 141/155] use ms instead of sec globally --- plugins/wan2gp/main.py | 67 ++--- videoeditor.py | 612 +++++++++++++++++++++-------------------- 2 files changed, 347 insertions(+), 332 deletions(-) diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py index 623d4afa7..473baf2e4 100644 --- a/plugins/wan2gp/main.py +++ b/plugins/wan2gp/main.py @@ -124,11 +124,12 @@ def enterEvent(self, event): super().enterEvent(event) self.media_player.play() if not self.plugin.active_region or self.duration == 0: return - start_sec, _ = self.plugin.active_region + start_ms, _ = self.plugin.active_region timeline = self.app.timeline_widget video_rect, audio_rect = None, None - x = timeline.sec_to_x(start_sec) - w = int(self.duration * timeline.pixels_per_second) + x = timeline.ms_to_x(start_ms) + duration_ms = int(self.duration * 1000) + w = int(duration_ms * timeline.pixels_per_ms) if self.plugin.insert_on_new_track: video_y = timeline.TIMESCALE_HEIGHT video_rect = QRectF(x, video_y, w, timeline.TRACK_HEIGHT) @@ -1705,9 +1706,9 @@ def on_timeline_context_menu(self, menu, event): region = self.app.timeline_widget.get_region_at_pos(event.pos()) if region: menu.addSeparator() - start_sec, end_sec = region - start_data, _, _ = self.app.get_frame_data_at_time(start_sec) - end_data, _, _ = self.app.get_frame_data_at_time(end_sec) + start_ms, end_ms = region + start_data, _, _ = self.app.get_frame_data_at_time(start_ms) + end_data, _, _ = self.app.get_frame_data_at_time(end_ms) if start_data and end_data: join_action = menu.addAction("Join Frames With AI") @@ -1745,9 +1746,9 @@ def setup_generator_for_region(self, region, on_new_track=False): else: print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. 
Using current model.") - start_sec, end_sec = region - start_data, w, h = self.app.get_frame_data_at_time(start_sec) - end_data, _, _ = self.app.get_frame_data_at_time(end_sec) + start_ms, end_ms = region + start_data, w, h = self.app.get_frame_data_at_time(start_ms) + end_data, _, _ = self.app.get_frame_data_at_time(end_ms) if not start_data or not end_data: QMessageBox.warning(self.app, "Frame Error", "Could not extract start and/or end frames.") return @@ -1761,11 +1762,11 @@ def setup_generator_for_region(self, region, on_new_track=False): QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path) QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path) - duration_sec = end_sec - start_sec + duration_ms = end_ms - start_ms wgp = self.client_widget.wgp model_type = self.client_widget.state['model_type'] fps = wgp.get_model_fps(model_type) - video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) + video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16) widgets = self.client_widget.widgets for w_name in ['mode_s', 'mode_t', 'mode_v', 'mode_l', 'image_end_checkbox']: @@ -1786,7 +1787,7 @@ def setup_generator_for_region(self, region, on_new_track=False): QMessageBox.critical(self.app, "File Error", f"Could not save temporary frame images: {e}") self._cleanup_temp_dir() return - self.app.status_label.setText(f"Ready to join frames from {start_sec:.2f}s to {end_sec:.2f}s.") + self.app.status_label.setText(f"Ready to join frames from {start_ms / 1000.0:.2f}s to {end_ms / 1000.0:.2f}s.") self.dock_widget.show() self.dock_widget.raise_() @@ -1805,8 +1806,8 @@ def setup_generator_from_start(self, region, on_new_track=False): else: print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. Using current model.") - start_sec, end_sec = region - start_data, w, h = self.app.get_frame_data_at_time(start_sec) + start_ms, end_ms = region + start_data, w, h = self.app.get_frame_data_at_time(start_ms) if not start_data: QMessageBox.warning(self.app, "Frame Error", "Could not extract start frame.") return @@ -1818,11 +1819,11 @@ def setup_generator_from_start(self, region, on_new_track=False): self.start_frame_path = os.path.join(self.temp_dir, "start_frame.png") QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path) - duration_sec = end_sec - start_sec + duration_ms = end_ms - start_ms wgp = self.client_widget.wgp model_type = self.client_widget.state['model_type'] fps = wgp.get_model_fps(model_type) - video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) + video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16) widgets = self.client_widget.widgets widgets['video_length'].setValue(video_length_frames) @@ -1845,7 +1846,7 @@ def setup_generator_from_start(self, region, on_new_track=False): self._cleanup_temp_dir() return - self.app.status_label.setText(f"Ready to generate from frame at {start_sec:.2f}s.") + self.app.status_label.setText(f"Ready to generate from frame at {start_ms / 1000.0:.2f}s.") self.dock_widget.show() self.dock_widget.raise_() @@ -1864,8 +1865,8 @@ def setup_generator_to_end(self, region, on_new_track=False): else: print(f"Warning: Default model '{model_to_set}' not found for AI Joiner. 
Using current model.") - start_sec, end_sec = region - end_data, w, h = self.app.get_frame_data_at_time(end_sec) + start_ms, end_ms = region + end_data, w, h = self.app.get_frame_data_at_time(end_ms) if not end_data: QMessageBox.warning(self.app, "Frame Error", "Could not extract end frame.") return @@ -1877,11 +1878,11 @@ def setup_generator_to_end(self, region, on_new_track=False): self.end_frame_path = os.path.join(self.temp_dir, "end_frame.png") QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path) - duration_sec = end_sec - start_sec + duration_ms = end_ms - start_ms wgp = self.client_widget.wgp model_type = self.client_widget.state['model_type'] fps = wgp.get_model_fps(model_type) - video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) + video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16) widgets = self.client_widget.widgets widgets['video_length'].setValue(video_length_frames) @@ -1915,7 +1916,7 @@ def setup_generator_to_end(self, region, on_new_track=False): self._cleanup_temp_dir() return - self.app.status_label.setText(f"Ready to generate to frame at {end_sec:.2f}s.") + self.app.status_label.setText(f"Ready to generate to frame at {end_ms / 1000.0:.2f}s.") self.dock_widget.show() self.dock_widget.raise_() @@ -1938,17 +1939,17 @@ def setup_creator_for_region(self, region, on_new_track=False): target_h = self.app.project_height self.client_widget.set_resolution_from_target(target_w, target_h) - start_sec, end_sec = region - duration_sec = end_sec - start_sec + start_ms, end_ms = region + duration_ms = end_ms - start_ms wgp = self.client_widget.wgp model_type = self.client_widget.state['model_type'] fps = wgp.get_model_fps(model_type) - video_length_frames = int(duration_sec * fps) if fps > 0 else int(duration_sec * 16) + video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16) self.client_widget.widgets['video_length'].setValue(video_length_frames) self.client_widget.widgets['mode_t'].setChecked(True) - self.app.status_label.setText(f"Ready to create video from {start_sec:.2f}s to {end_sec:.2f}s.") + self.app.status_label.setText(f"Ready to create video from {start_ms / 1000.0:.2f}s to {end_ms / 1000.0:.2f}s.") self.dock_widget.show() self.dock_widget.raise_() @@ -1958,29 +1959,29 @@ def insert_generated_clip(self, video_path): self.app.status_label.setText("Error: No active region to insert into."); return if not os.path.exists(video_path): self.app.status_label.setText(f"Error: Output file not found: {video_path}"); return - start_sec, end_sec = self.active_region + start_ms, end_ms = self.active_region def complex_insertion_action(): self.app._add_media_files_to_project([video_path]) media_info = self.app.media_properties.get(video_path) if not media_info: raise ValueError("Could not probe inserted clip.") - actual_duration, has_audio = media_info['duration'], media_info['has_audio'] + actual_duration_ms, has_audio = media_info['duration_ms'], media_info['has_audio'] if self.insert_on_new_track: self.app.timeline.num_video_tracks += 1 video_track_index = self.app.timeline.num_video_tracks audio_track_index = self.app.timeline.num_audio_tracks + 1 if has_audio else None else: - for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_sec) - for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_sec) - clips_to_remove = [c for c in self.app.timeline.clips if c.timeline_start_sec >= 
start_sec and c.timeline_end_sec <= end_sec] + for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, start_ms) + for clip in list(self.app.timeline.clips): self.app._split_at_time(clip, end_ms) + clips_to_remove = [c for c in self.app.timeline.clips if c.timeline_start_ms >= start_ms and c.timeline_end_ms <= end_ms] for clip in clips_to_remove: if clip in self.app.timeline.clips: self.app.timeline.clips.remove(clip) video_track_index, audio_track_index = 1, 1 if has_audio else None group_id = str(uuid.uuid4()) - new_clip = TimelineClip(video_path, start_sec, 0, actual_duration, video_track_index, 'video', 'video', group_id) + new_clip = TimelineClip(video_path, start_ms, 0, actual_duration_ms, video_track_index, 'video', 'video', group_id) self.app.timeline.add_clip(new_clip) if audio_track_index: if audio_track_index > self.app.timeline.num_audio_tracks: self.app.timeline.num_audio_tracks = audio_track_index - audio_clip = TimelineClip(video_path, start_sec, 0, actual_duration, audio_track_index, 'audio', 'video', group_id) + audio_clip = TimelineClip(video_path, start_ms, 0, actual_duration_ms, audio_track_index, 'audio', 'video', group_id) self.app.timeline.add_clip(audio_clip) try: self.app._perform_complex_timeline_change("Insert AI Clip", complex_insertion_action) diff --git a/videoeditor.py b/videoeditor.py index b6673c879..5126b79b5 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -181,20 +181,20 @@ def get_available_codecs(codec_type='video'): return _cached_video_codecs if codec_type == 'video' else _cached_audio_codecs class TimelineClip: - def __init__(self, source_path, timeline_start_sec, clip_start_sec, duration_sec, track_index, track_type, media_type, group_id): + def __init__(self, source_path, timeline_start_ms, clip_start_ms, duration_ms, track_index, track_type, media_type, group_id): self.id = str(uuid.uuid4()) self.source_path = source_path - self.timeline_start_sec = timeline_start_sec - self.clip_start_sec = clip_start_sec - self.duration_sec = duration_sec + self.timeline_start_ms = int(timeline_start_ms) + self.clip_start_ms = int(clip_start_ms) + self.duration_ms = int(duration_ms) self.track_index = track_index self.track_type = track_type self.media_type = media_type self.group_id = group_id @property - def timeline_end_sec(self): - return self.timeline_start_sec + self.duration_sec + def timeline_end_ms(self): + return self.timeline_start_ms + self.duration_ms class Timeline: def __init__(self): @@ -204,19 +204,20 @@ def __init__(self): def add_clip(self, clip): self.clips.append(clip) - self.clips.sort(key=lambda c: c.timeline_start_sec) + self.clips.sort(key=lambda c: c.timeline_start_ms) def get_total_duration(self): if not self.clips: return 0 - return max(c.timeline_end_sec for c in self.clips) + return max(c.timeline_end_ms for c in self.clips) class ExportWorker(QObject): progress = pyqtSignal(int) finished = pyqtSignal(str) - def __init__(self, ffmpeg_cmd, total_duration): + def __init__(self, ffmpeg_cmd, total_duration_ms): super().__init__() self.ffmpeg_cmd = ffmpeg_cmd - self.total_duration_secs = total_duration + self.total_duration_ms = total_duration_ms + def run_export(self): try: process = subprocess.Popen(self.ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True, encoding="utf-8") @@ -224,10 +225,10 @@ def run_export(self): for line in iter(process.stdout.readline, ""): match = time_pattern.search(line) if match: - hours, minutes, seconds = [int(g) for g in match.groups()[:3]] - 
processed_time = hours * 3600 + minutes * 60 + seconds - if self.total_duration_secs > 0: - percentage = int((processed_time / self.total_duration_secs) * 100) + h, m, s, cs = [int(g) for g in match.groups()] + processed_ms = (h * 3600 + m * 60 + s) * 1000 + cs * 10 + if self.total_duration_ms > 0: + percentage = int((processed_ms / self.total_duration_ms) * 100) self.progress.emit(percentage) process.stdout.close() return_code = process.wait() @@ -250,7 +251,7 @@ class TimelineWidget(QWidget): split_requested = pyqtSignal(object) delete_clip_requested = pyqtSignal(object) delete_clips_requested = pyqtSignal(list) - playhead_moved = pyqtSignal(float) + playhead_moved = pyqtSignal(int) split_region_requested = pyqtSignal(list) split_all_regions_requested = pyqtSignal(list) join_region_requested = pyqtSignal(list) @@ -266,14 +267,14 @@ def __init__(self, timeline_model, settings, project_fps, parent=None): super().__init__(parent) self.timeline = timeline_model self.settings = settings - self.playhead_pos_sec = 0.0 - self.view_start_sec = 0.0 + self.playhead_pos_ms = 0 + self.view_start_ms = 0 self.panning = False self.pan_start_pos = QPoint() - self.pan_start_view_sec = 0.0 + self.pan_start_view_ms = 0 - self.pixels_per_second = 50.0 - self.max_pixels_per_second = 1.0 + self.pixels_per_ms = 0.05 + self.max_pixels_per_ms = 1.0 self.project_fps = 25.0 self.set_project_fps(project_fps) @@ -290,7 +291,7 @@ def __init__(self, timeline_model, settings, project_fps, parent=None): self.dragging_selection_region = None self.drag_start_pos = QPoint() self.drag_original_clip_states = {} - self.selection_drag_start_sec = 0.0 + self.selection_drag_start_ms = 0 self.drag_selection_start_values = None self.drag_start_state = None @@ -325,12 +326,12 @@ def set_hover_preview_rects(self, video_rect, audio_rect): def set_project_fps(self, fps): self.project_fps = fps if fps > 0 else 25.0 - self.max_pixels_per_second = self.project_fps * 20 - self.pixels_per_second = min(self.pixels_per_second, self.max_pixels_per_second) + self.max_pixels_per_ms = (self.project_fps * 20) / 1000.0 + self.pixels_per_ms = min(self.pixels_per_ms, self.max_pixels_per_ms) self.update() - def sec_to_x(self, sec): return self.HEADER_WIDTH + int((sec - self.view_start_sec) * self.pixels_per_second) - def x_to_sec(self, x): return self.view_start_sec + float(x - self.HEADER_WIDTH) / self.pixels_per_second if x > self.HEADER_WIDTH and self.pixels_per_second > 0 else self.view_start_sec + def ms_to_x(self, ms): return self.HEADER_WIDTH + int((ms - self.view_start_ms) * self.pixels_per_ms) + def x_to_ms(self, x): return self.view_start_ms + int(float(x - self.HEADER_WIDTH) / self.pixels_per_ms) if x > self.HEADER_WIDTH and self.pixels_per_ms > 0 else self.view_start_ms def paintEvent(self, event): painter = QPainter(self) @@ -438,11 +439,12 @@ def draw_headers(self, painter): painter.restore() - def _format_timecode(self, seconds): - if abs(seconds) < 1e-9: seconds = 0 - sign = "-" if seconds < 0 else "" - seconds = abs(seconds) + def _format_timecode(self, total_ms): + if abs(total_ms) < 1: total_ms = 0 + sign = "-" if total_ms < 0 else "" + total_ms = abs(total_ms) + seconds = total_ms / 1000.0 h = int(seconds / 3600) m = int((seconds % 3600) / 60) s = seconds % 60 @@ -464,51 +466,51 @@ def draw_timescale(self, painter): painter.fillRect(QRect(self.HEADER_WIDTH, 0, self.width() - self.HEADER_WIDTH, self.TIMESCALE_HEIGHT), QColor("#222")) painter.drawLine(self.HEADER_WIDTH, self.TIMESCALE_HEIGHT - 1, self.width(), self.TIMESCALE_HEIGHT - 
1) - frame_dur = 1.0 / self.project_fps - intervals = [ - frame_dur, 2*frame_dur, 5*frame_dur, 10*frame_dur, - 0.1, 0.2, 0.5, 1, 2, 5, 10, 15, 30, - 60, 120, 300, 600, 900, 1800, - 3600, 2*3600, 5*3600, 10*3600 + frame_dur_ms = 1000.0 / self.project_fps + intervals_ms = [ + int(frame_dur_ms), int(2*frame_dur_ms), int(5*frame_dur_ms), int(10*frame_dur_ms), + 100, 200, 500, 1000, 2000, 5000, 10000, 15000, 30000, + 60000, 120000, 300000, 600000, 900000, 1800000, + 3600000, 2*3600000, 5*3600000, 10*3600000 ] min_pixel_dist = 70 - major_interval = next((i for i in intervals if i * self.pixels_per_second > min_pixel_dist), intervals[-1]) + major_interval = next((i for i in intervals_ms if i * self.pixels_per_ms > min_pixel_dist), intervals_ms[-1]) minor_interval = 0 for divisor in [5, 4, 2]: - if (major_interval / divisor) * self.pixels_per_second > 10: + if (major_interval / divisor) * self.pixels_per_ms > 10: minor_interval = major_interval / divisor break - start_sec = self.x_to_sec(self.HEADER_WIDTH) - end_sec = self.x_to_sec(self.width()) + start_ms = self.x_to_ms(self.HEADER_WIDTH) + end_ms = self.x_to_ms(self.width()) - def draw_ticks(interval, height): - if interval <= 1e-6: return - start_tick_num = int(start_sec / interval) - end_tick_num = int(end_sec / interval) + 1 + def draw_ticks(interval_ms, height): + if interval_ms < 1: return + start_tick_num = int(start_ms / interval_ms) + end_tick_num = int(end_ms / interval_ms) + 1 for i in range(start_tick_num, end_tick_num + 1): - t_sec = i * interval - x = self.sec_to_x(t_sec) + t_ms = i * interval_ms + x = self.ms_to_x(t_ms) if x > self.width(): break if x >= self.HEADER_WIDTH: painter.drawLine(x, self.TIMESCALE_HEIGHT - height, x, self.TIMESCALE_HEIGHT) - if frame_dur * self.pixels_per_second > 4: - draw_ticks(frame_dur, 3) + if frame_dur_ms * self.pixels_per_ms > 4: + draw_ticks(frame_dur_ms, 3) if minor_interval > 0: draw_ticks(minor_interval, 6) - start_major_tick = int(start_sec / major_interval) - end_major_tick = int(end_sec / major_interval) + 1 + start_major_tick = int(start_ms / major_interval) + end_major_tick = int(end_ms / major_interval) + 1 for i in range(start_major_tick, end_major_tick + 1): - t_sec = i * major_interval - x = self.sec_to_x(t_sec) + t_ms = i * major_interval + x = self.ms_to_x(t_ms) if x > self.width() + 50: break if x >= self.HEADER_WIDTH - 50: painter.drawLine(x, self.TIMESCALE_HEIGHT - 12, x, self.TIMESCALE_HEIGHT) - label = self._format_timecode(t_sec) + label = self._format_timecode(t_ms) label_width = font_metrics.horizontalAdvance(label) label_x = x - label_width // 2 if label_x < self.HEADER_WIDTH: @@ -525,8 +527,8 @@ def get_clip_rect(self, clip): visual_index = clip.track_index - 1 y = self.audio_tracks_y_start + visual_index * self.TRACK_HEIGHT - x = self.sec_to_x(clip.timeline_start_sec) - w = int(clip.duration_sec * self.pixels_per_second) + x = self.ms_to_x(clip.timeline_start_ms) + w = int(clip.duration_ms * self.pixels_per_ms) clip_height = self.TRACK_HEIGHT - 10 y += (self.TRACK_HEIGHT - clip_height) / 2 return QRectF(x, y, w, clip_height) @@ -598,16 +600,16 @@ def draw_tracks_and_clips(self, painter): painter.restore() def draw_selections(self, painter): - for start_sec, end_sec in self.selection_regions: - x = self.sec_to_x(start_sec) - w = int((end_sec - start_sec) * self.pixels_per_second) + for start_ms, end_ms in self.selection_regions: + x = self.ms_to_x(start_ms) + w = int((end_ms - start_ms) * self.pixels_per_ms) selection_rect = QRectF(x, self.TIMESCALE_HEIGHT, w, self.height() 
- self.TIMESCALE_HEIGHT) painter.fillRect(selection_rect, QColor(100, 100, 255, 80)) painter.setPen(QColor(150, 150, 255, 150)) painter.drawRect(selection_rect) def draw_playhead(self, painter): - playhead_x = self.sec_to_x(self.playhead_pos_sec) + playhead_x = self.ms_to_x(self.playhead_pos_ms) painter.setPen(QPen(QColor("red"), 2)) painter.drawLine(playhead_x, 0, playhead_x, self.height()) @@ -634,48 +636,48 @@ def y_to_track_info(self, y): return None - def _snap_time_if_needed(self, time_sec): - frame_duration = 1.0 / self.project_fps - if frame_duration > 0 and frame_duration * self.pixels_per_second > 4: - frame_number = round(time_sec / frame_duration) - return frame_number * frame_duration - return time_sec + def _snap_time_if_needed(self, time_ms): + frame_duration_ms = 1000.0 / self.project_fps + if frame_duration_ms > 0 and frame_duration_ms * self.pixels_per_ms > 4: + frame_number = round(time_ms / frame_duration_ms) + return int(frame_number * frame_duration_ms) + return int(time_ms) def get_region_at_pos(self, pos: QPoint): if pos.y() <= self.TIMESCALE_HEIGHT or pos.x() <= self.HEADER_WIDTH: return None - clicked_sec = self.x_to_sec(pos.x()) + clicked_ms = self.x_to_ms(pos.x()) for region in reversed(self.selection_regions): - if region[0] <= clicked_sec <= region[1]: + if region[0] <= clicked_ms <= region[1]: return region return None def wheelEvent(self, event: QMouseEvent): delta = event.angleDelta().y() zoom_factor = 1.15 - old_pps = self.pixels_per_second + old_pps = self.pixels_per_ms if delta > 0: new_pps = old_pps * zoom_factor else: new_pps = old_pps / zoom_factor - min_pps = 1 / (3600 * 10) - new_pps = max(min_pps, min(new_pps, self.max_pixels_per_second)) + min_pps = 1 / (3600 * 10 * 1000) + new_pps = max(min_pps, min(new_pps, self.max_pixels_per_ms)) if abs(new_pps - old_pps) < 1e-9: return if event.position().x() < self.HEADER_WIDTH: - new_view_start_sec = self.view_start_sec * (old_pps / new_pps) + new_view_start_ms = self.view_start_ms * (old_pps / new_pps) else: mouse_x = event.position().x() - time_at_cursor = self.x_to_sec(mouse_x) - new_view_start_sec = time_at_cursor - (mouse_x - self.HEADER_WIDTH) / new_pps + time_at_cursor = self.x_to_ms(mouse_x) + new_view_start_ms = time_at_cursor - (mouse_x - self.HEADER_WIDTH) / new_pps - self.pixels_per_second = new_pps - self.view_start_sec = max(0.0, new_view_start_sec) + self.pixels_per_ms = new_pps + self.view_start_ms = int(max(0, new_view_start_ms)) self.update() event.accept() @@ -684,7 +686,7 @@ def mousePressEvent(self, event: QMouseEvent): if event.button() == Qt.MouseButton.MiddleButton: self.panning = True self.pan_start_pos = event.pos() - self.pan_start_view_sec = self.view_start_sec + self.pan_start_view_ms = self.view_start_ms self.setCursor(Qt.CursorShape.ClosedHandCursor) event.accept() return @@ -713,8 +715,8 @@ def mousePressEvent(self, event: QMouseEvent): for region in self.selection_regions: if not region: continue - x_start = self.sec_to_x(region[0]) - x_end = self.sec_to_x(region[1]) + x_start = self.ms_to_x(region[0]) + x_end = self.ms_to_x(region[1]) if event.pos().y() > self.TIMESCALE_HEIGHT: if abs(event.pos().x() - x_start) < self.RESIZE_HANDLE_WIDTH: self.resizing_selection_region = region @@ -768,12 +770,12 @@ def mousePressEvent(self, event: QMouseEvent): if clicked_clip.id in self.selected_clips: self.dragging_clip = clicked_clip self.drag_start_state = self.window()._get_current_timeline_state() - self.drag_original_clip_states[clicked_clip.id] = (clicked_clip.timeline_start_sec, 
clicked_clip.track_index) + self.drag_original_clip_states[clicked_clip.id] = (clicked_clip.timeline_start_ms, clicked_clip.track_index) self.dragging_linked_clip = next((c for c in self.timeline.clips if c.group_id == clicked_clip.group_id and c.id != clicked_clip.id), None) if self.dragging_linked_clip: self.drag_original_clip_states[self.dragging_linked_clip.id] = \ - (self.dragging_linked_clip.timeline_start_sec, self.dragging_linked_clip.track_index) + (self.dragging_linked_clip.timeline_start_ms, self.dragging_linked_clip.track_index) self.drag_start_pos = event.pos() else: @@ -789,16 +791,16 @@ def mousePressEvent(self, event: QMouseEvent): if is_in_track_area: self.creating_selection_region = True - playhead_x = self.sec_to_x(self.playhead_pos_sec) + playhead_x = self.ms_to_x(self.playhead_pos_ms) if abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS: - self.selection_drag_start_sec = self.playhead_pos_sec + self.selection_drag_start_ms = self.playhead_pos_ms else: - self.selection_drag_start_sec = self.x_to_sec(event.pos().x()) - self.selection_regions.append([self.selection_drag_start_sec, self.selection_drag_start_sec]) + self.selection_drag_start_ms = self.x_to_ms(event.pos().x()) + self.selection_regions.append([self.selection_drag_start_ms, self.selection_drag_start_ms]) elif is_on_timescale: - time_sec = max(0, self.x_to_sec(event.pos().x())) - self.playhead_pos_sec = self._snap_time_if_needed(time_sec) - self.playhead_moved.emit(self.playhead_pos_sec) + time_ms = max(0, self.x_to_ms(event.pos().x())) + self.playhead_pos_ms = self._snap_time_if_needed(time_ms) + self.playhead_moved.emit(self.playhead_pos_ms) self.dragging_playhead = True self.update() @@ -806,22 +808,22 @@ def mousePressEvent(self, event: QMouseEvent): def mouseMoveEvent(self, event: QMouseEvent): if self.panning: delta_x = event.pos().x() - self.pan_start_pos.x() - time_delta = delta_x / self.pixels_per_second - new_view_start = self.pan_start_view_sec - time_delta - self.view_start_sec = max(0.0, new_view_start) + time_delta = delta_x / self.pixels_per_ms + new_view_start = self.pan_start_view_ms - time_delta + self.view_start_ms = int(max(0, new_view_start)) self.update() return if self.resizing_selection_region: - current_sec = max(0, self.x_to_sec(event.pos().x())) + current_ms = max(0, self.x_to_ms(event.pos().x())) original_start, original_end = self.resize_selection_start_values if self.resize_selection_edge == 'left': - new_start = current_sec + new_start = current_ms new_end = original_end else: # right new_start = original_start - new_end = current_sec + new_end = current_ms self.resizing_selection_region[0] = min(new_start, new_end) self.resizing_selection_region[1] = max(new_start, new_end) @@ -836,60 +838,60 @@ def mouseMoveEvent(self, event: QMouseEvent): if self.resizing_clip: linked_clip = next((c for c in self.timeline.clips if c.group_id == self.resizing_clip.group_id and c.id != self.resizing_clip.id), None) delta_x = event.pos().x() - self.resize_start_pos.x() - time_delta = delta_x / self.pixels_per_second - min_duration = 1.0 / self.project_fps - snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_second + time_delta = delta_x / self.pixels_per_ms + min_duration_ms = int(1000 / self.project_fps) + snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_ms - snap_points = [self.playhead_pos_sec] + snap_points = [self.playhead_pos_ms] for clip in self.timeline.clips: if clip.id == self.resizing_clip.id: continue if linked_clip and clip.id == linked_clip.id: 
continue - snap_points.append(clip.timeline_start_sec) - snap_points.append(clip.timeline_end_sec) + snap_points.append(clip.timeline_start_ms) + snap_points.append(clip.timeline_end_ms) media_props = self.window().media_properties.get(self.resizing_clip.source_path) - source_duration = media_props['duration'] if media_props else float('inf') + source_duration_ms = media_props['duration_ms'] if media_props else float('inf') if self.resize_edge == 'left': - original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_sec - original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec - original_clip_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].clip_start_sec - true_new_start_sec = original_start + time_delta + original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_ms + original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_ms + original_clip_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].clip_start_ms + true_new_start_ms = original_start + time_delta - new_start_sec = true_new_start_sec + new_start_ms = true_new_start_ms for snap_point in snap_points: - if abs(true_new_start_sec - snap_point) < snap_time_delta: - new_start_sec = snap_point + if abs(true_new_start_ms - snap_point) < snap_time_delta: + new_start_ms = snap_point break - if new_start_sec > original_start + original_duration - min_duration: - new_start_sec = original_start + original_duration - min_duration + if new_start_ms > original_start + original_duration - min_duration_ms: + new_start_ms = original_start + original_duration - min_duration_ms - new_start_sec = max(0, new_start_sec) + new_start_ms = max(0, new_start_ms) if self.resizing_clip.media_type != 'image': - if new_start_sec < original_start - original_clip_start: - new_start_sec = original_start - original_clip_start + if new_start_ms < original_start - original_clip_start: + new_start_ms = original_start - original_clip_start - new_duration = (original_start + original_duration) - new_start_sec - new_clip_start = original_clip_start + (new_start_sec - original_start) + new_duration = (original_start + original_duration) - new_start_ms + new_clip_start = original_clip_start + (new_start_ms - original_start) - if new_duration < min_duration: - new_duration = min_duration - new_start_sec = (original_start + original_duration) - new_duration - new_clip_start = original_clip_start + (new_start_sec - original_start) - - self.resizing_clip.timeline_start_sec = new_start_sec - self.resizing_clip.duration_sec = new_duration - self.resizing_clip.clip_start_sec = new_clip_start + if new_duration < min_duration_ms: + new_duration = min_duration_ms + new_start_ms = (original_start + original_duration) - new_duration + new_clip_start = original_clip_start + (new_start_ms - original_start) + + self.resizing_clip.timeline_start_ms = int(new_start_ms) + self.resizing_clip.duration_ms = int(new_duration) + self.resizing_clip.clip_start_ms = int(new_clip_start) if linked_clip: - linked_clip.timeline_start_sec = new_start_sec - linked_clip.duration_sec = new_duration - linked_clip.clip_start_sec = new_clip_start + linked_clip.timeline_start_ms = int(new_start_ms) + linked_clip.duration_ms = 
int(new_duration) + linked_clip.clip_start_ms = int(new_clip_start) elif self.resize_edge == 'right': - original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_sec - original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_sec + original_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].timeline_start_ms + original_duration = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].duration_ms true_new_duration = original_duration + time_delta true_new_end_time = original_start + true_new_duration @@ -902,23 +904,23 @@ def mouseMoveEvent(self, event: QMouseEvent): new_duration = new_end_time - original_start - if new_duration < min_duration: - new_duration = min_duration + if new_duration < min_duration_ms: + new_duration = min_duration_ms if self.resizing_clip.media_type != 'image': - if self.resizing_clip.clip_start_sec + new_duration > source_duration: - new_duration = source_duration - self.resizing_clip.clip_start_sec + if self.resizing_clip.clip_start_ms + new_duration > source_duration_ms: + new_duration = source_duration_ms - self.resizing_clip.clip_start_ms - self.resizing_clip.duration_sec = new_duration + self.resizing_clip.duration_ms = int(new_duration) if linked_clip: - linked_clip.duration_sec = new_duration + linked_clip.duration_ms = int(new_duration) self.update() return if not self.dragging_clip and not self.dragging_playhead and not self.creating_selection_region: cursor_set = False - playhead_x = self.sec_to_x(self.playhead_pos_sec) + playhead_x = self.ms_to_x(self.playhead_pos_ms) is_in_track_area = event.pos().y() > self.TIMESCALE_HEIGHT and event.pos().x() > self.HEADER_WIDTH if is_in_track_area and abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS: self.setCursor(Qt.CursorShape.SizeHorCursor) @@ -926,8 +928,8 @@ def mouseMoveEvent(self, event: QMouseEvent): if not cursor_set and is_in_track_area: for region in self.selection_regions: - x_start = self.sec_to_x(region[0]) - x_end = self.sec_to_x(region[1]) + x_start = self.ms_to_x(region[0]) + x_end = self.ms_to_x(region[1]) if abs(event.pos().x() - x_start) < self.RESIZE_HANDLE_WIDTH or \ abs(event.pos().x() - x_end) < self.RESIZE_HANDLE_WIDTH: self.setCursor(Qt.CursorShape.SizeHorCursor) @@ -945,16 +947,16 @@ def mouseMoveEvent(self, event: QMouseEvent): self.unsetCursor() if self.creating_selection_region: - current_sec = self.x_to_sec(event.pos().x()) - start = min(self.selection_drag_start_sec, current_sec) - end = max(self.selection_drag_start_sec, current_sec) + current_ms = self.x_to_ms(event.pos().x()) + start = min(self.selection_drag_start_ms, current_ms) + end = max(self.selection_drag_start_ms, current_ms) self.selection_regions[-1] = [start, end] self.update() return if self.dragging_selection_region: delta_x = event.pos().x() - self.drag_start_pos.x() - time_delta = delta_x / self.pixels_per_second + time_delta = int(delta_x / self.pixels_per_ms) original_start, original_end = self.drag_selection_start_values duration = original_end - original_start @@ -967,16 +969,16 @@ def mouseMoveEvent(self, event: QMouseEvent): return if self.dragging_playhead: - time_sec = max(0, self.x_to_sec(event.pos().x())) - self.playhead_pos_sec = self._snap_time_if_needed(time_sec) - self.playhead_moved.emit(self.playhead_pos_sec) + time_ms = max(0, self.x_to_ms(event.pos().x())) + 
self.playhead_pos_ms = self._snap_time_if_needed(time_ms) + self.playhead_moved.emit(self.playhead_pos_ms) self.update() elif self.dragging_clip: self.highlighted_track_info = None self.highlighted_ghost_track_info = None new_track_info = self.y_to_track_info(event.pos().y()) - original_start_sec, _ = self.drag_original_clip_states[self.dragging_clip.id] + original_start_ms, _ = self.drag_original_clip_states[self.dragging_clip.id] if new_track_info: new_track_type, new_track_index = new_track_info @@ -993,19 +995,19 @@ def mouseMoveEvent(self, event: QMouseEvent): self.dragging_clip.track_index = new_track_index delta_x = event.pos().x() - self.drag_start_pos.x() - time_delta = delta_x / self.pixels_per_second - true_new_start_time = original_start_sec + time_delta + time_delta = delta_x / self.pixels_per_ms + true_new_start_time = original_start_ms + time_delta - playhead_time = self.playhead_pos_sec - snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_second + playhead_time = self.playhead_pos_ms + snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_ms new_start_time = true_new_start_time - true_new_end_time = true_new_start_time + self.dragging_clip.duration_sec + true_new_end_time = true_new_start_time + self.dragging_clip.duration_ms if abs(true_new_start_time - playhead_time) < snap_time_delta: new_start_time = playhead_time elif abs(true_new_end_time - playhead_time) < snap_time_delta: - new_start_time = playhead_time - self.dragging_clip.duration_sec + new_start_time = playhead_time - self.dragging_clip.duration_ms for other_clip in self.timeline.clips: if other_clip.id == self.dragging_clip.id: continue @@ -1014,21 +1016,21 @@ def mouseMoveEvent(self, event: QMouseEvent): other_clip.track_index != self.dragging_clip.track_index): continue - is_overlapping = (new_start_time < other_clip.timeline_end_sec and - new_start_time + self.dragging_clip.duration_sec > other_clip.timeline_start_sec) + is_overlapping = (new_start_time < other_clip.timeline_end_ms and + new_start_time + self.dragging_clip.duration_ms > other_clip.timeline_start_ms) if is_overlapping: - movement_direction = true_new_start_time - original_start_sec + movement_direction = true_new_start_time - original_start_ms if movement_direction > 0: - new_start_time = other_clip.timeline_start_sec - self.dragging_clip.duration_sec + new_start_time = other_clip.timeline_start_ms - self.dragging_clip.duration_ms else: - new_start_time = other_clip.timeline_end_sec + new_start_time = other_clip.timeline_end_ms break final_start_time = max(0, new_start_time) - self.dragging_clip.timeline_start_sec = final_start_time + self.dragging_clip.timeline_start_ms = int(final_start_time) if self.dragging_linked_clip: - self.dragging_linked_clip.timeline_start_sec = final_start_time + self.dragging_linked_clip.timeline_start_ms = int(final_start_time) self.update() @@ -1061,7 +1063,7 @@ def mouseReleaseEvent(self, event: QMouseEvent): self.creating_selection_region = False if self.selection_regions: start, end = self.selection_regions[-1] - if (end - start) * self.pixels_per_second < 2: + if (end - start) * self.pixels_per_ms < 2: self.clear_all_regions() if self.dragging_selection_region: @@ -1071,18 +1073,18 @@ def mouseReleaseEvent(self, event: QMouseEvent): self.dragging_playhead = False if self.dragging_clip: orig_start, orig_track = self.drag_original_clip_states[self.dragging_clip.id] - moved = (orig_start != self.dragging_clip.timeline_start_sec or + moved = (orig_start != 
self.dragging_clip.timeline_start_ms or orig_track != self.dragging_clip.track_index) if self.dragging_linked_clip: orig_start_link, orig_track_link = self.drag_original_clip_states[self.dragging_linked_clip.id] - moved = moved or (orig_start_link != self.dragging_linked_clip.timeline_start_sec or + moved = moved or (orig_start_link != self.dragging_linked_clip.timeline_start_ms or orig_track_link != self.dragging_linked_clip.track_index) if moved: self.window().finalize_clip_drag(self.drag_start_state) - self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) + self.timeline.clips.sort(key=lambda c: c.timeline_start_ms) self.highlighted_track_info = None self.highlighted_ghost_track_info = None self.operation_finished.emit() @@ -1162,12 +1164,12 @@ def dragMoveEvent(self, event): event.acceptProposedAction() - duration = media_props['duration'] + duration_ms = media_props['duration_ms'] media_type = media_props['media_type'] has_audio = media_props['has_audio'] pos = event.position() - start_sec = self.x_to_sec(pos.x()) + start_ms = self.x_to_ms(pos.x()) track_info = self.y_to_track_info(pos.y()) self.drag_over_rect = QRectF() @@ -1187,8 +1189,8 @@ def dragMoveEvent(self, event): else: self.highlighted_track_info = track_info - width = int(duration * self.pixels_per_second) - x = self.sec_to_x(start_sec) + width = int(duration_ms * self.pixels_per_ms) + x = self.ms_to_x(start_ms) video_y, audio_y = -1, -1 @@ -1232,19 +1234,19 @@ def dropEvent(self, event): return pos = event.position() - start_sec = self.x_to_sec(pos.x()) + start_ms = self.x_to_ms(pos.x()) track_info = self.y_to_track_info(pos.y()) if not track_info: event.ignore() return - current_timeline_pos = start_sec + current_timeline_pos = start_ms for file_path in added_files: media_info = main_window.media_properties.get(file_path) if not media_info: continue - duration = media_info['duration'] + duration_ms = media_info['duration_ms'] has_audio = media_info['has_audio'] media_type = media_info['media_type'] @@ -1269,14 +1271,14 @@ def dropEvent(self, event): main_window._add_clip_to_timeline( source_path=file_path, - timeline_start_sec=current_timeline_pos, - duration_sec=duration, + timeline_start_ms=current_timeline_pos, + duration_ms=duration_ms, media_type=media_type, - clip_start_sec=0, + clip_start_ms=0, video_track_index=video_track_idx, audio_track_index=audio_track_idx ) - current_timeline_pos += duration + current_timeline_pos += duration_ms event.acceptProposedAction() return @@ -1286,12 +1288,12 @@ def dropEvent(self, event): json_data = json.loads(mime_data.data('application/x-vnd.video.filepath').data().decode('utf-8')) file_path = json_data['path'] - duration = json_data['duration'] + duration_ms = json_data['duration_ms'] has_audio = json_data['has_audio'] media_type = json_data['media_type'] pos = event.position() - start_sec = self.x_to_sec(pos.x()) + start_ms = self.x_to_ms(pos.x()) track_info = self.y_to_track_info(pos.y()) if not track_info: @@ -1321,10 +1323,10 @@ def dropEvent(self, event): main_window = self.window() main_window._add_clip_to_timeline( source_path=file_path, - timeline_start_sec=start_sec, - duration_sec=duration, + timeline_start_ms=start_ms, + duration_ms=duration_ms, media_type=media_type, - clip_start_sec=0, + clip_start_ms=0, video_track_index=video_track_idx, audio_track_index=audio_track_idx ) @@ -1382,8 +1384,8 @@ def contextMenuEvent(self, event: 'QContextMenuEvent'): split_action = menu.addAction("Split Clip") delete_action = menu.addAction("Delete Clip") - playhead_time = 
self.playhead_pos_sec - is_playhead_over_clip = (clip_at_pos.timeline_start_sec < playhead_time < clip_at_pos.timeline_end_sec) + playhead_time = self.playhead_pos_ms + is_playhead_over_clip = (clip_at_pos.timeline_start_ms < playhead_time < clip_at_pos.timeline_end_ms) split_action.setEnabled(is_playhead_over_clip) split_action.triggered.connect(lambda: self.split_requested.emit(clip_at_pos)) delete_action.triggered.connect(lambda: self.delete_clip_requested.emit(clip_at_pos)) @@ -1672,7 +1674,7 @@ def startDrag(self, supportedActions): payload = { "path": path, - "duration": media_info['duration'], + "duration_ms": media_info['duration_ms'], "has_audio": media_info['has_audio'], "media_type": media_info['media_type'] } @@ -1930,7 +1932,7 @@ def _toggle_scale_to_fit(self, checked): def on_timeline_changed_by_undo(self): self.prune_empty_tracks() self.timeline_widget.update() - self.seek_preview(self.timeline_widget.playhead_pos_sec) + self.seek_preview(self.timeline_widget.playhead_pos_ms) self.status_label.setText("Operation undone/redone.") def update_undo_redo_actions(self): @@ -1964,8 +1966,8 @@ def on_add_to_timeline_at_playhead(self, file_path): self.status_label.setText(f"Error: Could not find properties for {os.path.basename(file_path)}") return - playhead_pos = self.timeline_widget.playhead_pos_sec - duration = media_info['duration'] + playhead_pos = self.timeline_widget.playhead_pos_ms + duration_ms = media_info['duration_ms'] has_audio = media_info['has_audio'] media_type = media_info['media_type'] @@ -1974,10 +1976,10 @@ def on_add_to_timeline_at_playhead(self, file_path): self._add_clip_to_timeline( source_path=file_path, - timeline_start_sec=playhead_pos, - duration_sec=duration, + timeline_start_ms=playhead_pos, + duration_ms=duration_ms, media_type=media_type, - clip_start_sec=0.0, + clip_start_ms=0, video_track_index=video_track, audio_track_index=audio_track ) @@ -2102,25 +2104,25 @@ def _create_menu_bar(self): data['action'] = action self.windows_menu.addAction(action) - def _get_topmost_video_clip_at(self, time_sec): + def _get_topmost_video_clip_at(self, time_ms): """Finds the video clip on the highest track at a specific time.""" top_clip = None for c in self.timeline.clips: - if c.track_type == 'video' and c.timeline_start_sec <= time_sec < c.timeline_end_sec: + if c.track_type == 'video' and c.timeline_start_ms <= time_ms < c.timeline_end_ms: if top_clip is None or c.track_index > top_clip.track_index: top_clip = c return top_clip - def _start_playback_stream_at(self, time_sec): + def _start_playback_stream_at(self, time_ms): self._stop_playback_stream() - clip = self._get_topmost_video_clip_at(time_sec) + clip = self._get_topmost_video_clip_at(time_ms) if not clip: return self.playback_clip = clip - clip_time = time_sec - clip.timeline_start_sec + clip.clip_start_sec + clip_time_sec = (time_ms - clip.timeline_start_ms + clip.clip_start_ms) / 1000.0 w, h = self.project_width, self.project_height try: - args = (ffmpeg.input(self.playback_clip.source_path, ss=clip_time) + args = (ffmpeg.input(self.playback_clip.source_path, ss=f"{clip_time_sec:.6f}") .filter('scale', w, h, force_original_aspect_ratio='decrease') .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile()) @@ -2149,7 +2151,7 @@ def _get_media_properties(self, file_path): if file_ext in ['.png', '.jpg', '.jpeg']: img = QImage(file_path) media_info['media_type'] = 'image' - media_info['duration'] = 5.0 + 
media_info['duration_ms'] = 5000 media_info['has_audio'] = False media_info['width'] = img.width() media_info['height'] = img.height() @@ -2160,7 +2162,8 @@ def _get_media_properties(self, file_path): if video_stream: media_info['media_type'] = 'video' - media_info['duration'] = float(video_stream.get('duration', probe['format'].get('duration', 0))) + duration_sec = float(video_stream.get('duration', probe['format'].get('duration', 0))) + media_info['duration_ms'] = int(duration_sec * 1000) media_info['has_audio'] = audio_stream is not None media_info['width'] = int(video_stream['width']) media_info['height'] = int(video_stream['height']) @@ -2169,7 +2172,8 @@ def _get_media_properties(self, file_path): if den > 0: media_info['fps'] = num / den elif audio_stream: media_info['media_type'] = 'audio' - media_info['duration'] = float(audio_stream.get('duration', probe['format'].get('duration', 0))) + duration_sec = float(audio_stream.get('duration', probe['format'].get('duration', 0))) + media_info['duration_ms'] = int(duration_sec * 1000) media_info['has_audio'] = True else: return None @@ -2215,8 +2219,8 @@ def _update_project_properties_from_clip(self, source_path): def _probe_for_drag(self, file_path): return self._get_media_properties(file_path) - def get_frame_data_at_time(self, time_sec): - clip_at_time = self._get_topmost_video_clip_at(time_sec) + def get_frame_data_at_time(self, time_ms): + clip_at_time = self._get_topmost_video_clip_at(time_ms) if not clip_at_time: return (None, 0, 0) try: @@ -2231,10 +2235,10 @@ def get_frame_data_at_time(self, time_sec): .run(capture_stdout=True, quiet=True) ) else: - clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec + clip_time_sec = (time_ms - clip_at_time.timeline_start_ms + clip_at_time.clip_start_ms) / 1000.0 out, _ = ( ffmpeg - .input(clip_at_time.source_path, ss=clip_time) + .input(clip_at_time.source_path, ss=f"{clip_time_sec:.6f}") .filter('scale', w, h, force_original_aspect_ratio='decrease') .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') @@ -2245,8 +2249,8 @@ def get_frame_data_at_time(self, time_sec): print(f"Error extracting frame data: {e.stderr}") return (None, 0, 0) - def get_frame_at_time(self, time_sec): - clip_at_time = self._get_topmost_video_clip_at(time_sec) + def get_frame_at_time(self, time_ms): + clip_at_time = self._get_topmost_video_clip_at(time_ms) if not clip_at_time: return None try: w, h = self.project_width, self.project_height @@ -2259,8 +2263,8 @@ def get_frame_at_time(self, time_sec): .run(capture_stdout=True, quiet=True) ) else: - clip_time = time_sec - clip_at_time.timeline_start_sec + clip_at_time.clip_start_sec - out, _ = (ffmpeg.input(clip_at_time.source_path, ss=clip_time) + clip_time_sec = (time_ms - clip_at_time.timeline_start_ms + clip_at_time.clip_start_ms) / 1000.0 + out, _ = (ffmpeg.input(clip_at_time.source_path, ss=f"{clip_time_sec:.6f}") .filter('scale', w, h, force_original_aspect_ratio='decrease') .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') @@ -2270,11 +2274,11 @@ def get_frame_at_time(self, time_sec): return QPixmap.fromImage(image) except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return None - def seek_preview(self, time_sec): + def seek_preview(self, time_ms): self._stop_playback_stream() - self.timeline_widget.playhead_pos_sec = time_sec + self.timeline_widget.playhead_pos_ms = int(time_ms) 
self.timeline_widget.update() - self.current_preview_pixmap = self.get_frame_at_time(time_sec) + self.current_preview_pixmap = self.get_frame_at_time(time_ms) self._update_preview_display() def _update_preview_display(self): @@ -2303,26 +2307,26 @@ def toggle_playback(self): if self.playback_timer.isActive(): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play") else: if not self.timeline.clips: return - if self.timeline_widget.playhead_pos_sec >= self.timeline.get_total_duration(): self.timeline_widget.playhead_pos_sec = 0.0 + if self.timeline_widget.playhead_pos_ms >= self.timeline.get_total_duration(): self.timeline_widget.playhead_pos_ms = 0 self.playback_timer.start(int(1000 / self.project_fps)); self.play_pause_button.setText("Pause") - def stop_playback(self): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play"); self.seek_preview(0.0) + def stop_playback(self): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play"); self.seek_preview(0) def step_frame(self, direction): if not self.timeline.clips: return self.playback_timer.stop(); self.play_pause_button.setText("Play"); self._stop_playback_stream() - frame_duration = 1.0 / self.project_fps - new_time = self.timeline_widget.playhead_pos_sec + (direction * frame_duration) - self.seek_preview(max(0, min(new_time, self.timeline.get_total_duration()))) + frame_duration_ms = 1000.0 / self.project_fps + new_time = self.timeline_widget.playhead_pos_ms + (direction * frame_duration_ms) + self.seek_preview(int(max(0, min(new_time, self.timeline.get_total_duration())))) def advance_playback_frame(self): - frame_duration = 1.0 / self.project_fps - new_time = self.timeline_widget.playhead_pos_sec + frame_duration - if new_time > self.timeline.get_total_duration(): self.stop_playback(); return + frame_duration_ms = 1000.0 / self.project_fps + new_time_ms = self.timeline_widget.playhead_pos_ms + frame_duration_ms + if new_time_ms > self.timeline.get_total_duration(): self.stop_playback(); return - self.timeline_widget.playhead_pos_sec = new_time + self.timeline_widget.playhead_pos_ms = round(new_time_ms) self.timeline_widget.update() - clip_at_new_time = self._get_topmost_video_clip_at(new_time) + clip_at_new_time = self._get_topmost_video_clip_at(new_time_ms) if not clip_at_new_time: self._stop_playback_stream() @@ -2334,12 +2338,12 @@ def advance_playback_frame(self): if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: self._stop_playback_stream() self.playback_clip = clip_at_new_time - self.current_preview_pixmap = self.get_frame_at_time(new_time) + self.current_preview_pixmap = self.get_frame_at_time(new_time_ms) self._update_preview_display() return if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: - self._start_playback_stream_at(new_time) + self._start_playback_stream_at(new_time_ms) if self.playback_process: frame_size = self.project_width * self.project_height * 3 @@ -2421,7 +2425,7 @@ def save_project_as(self): if not path: return project_data = { "media_pool": self.media_pool, - "clips": [{"source_path": c.source_path, "timeline_start_sec": c.timeline_start_sec, "clip_start_sec": c.clip_start_sec, "duration_sec": c.duration_sec, "track_index": c.track_index, "track_type": c.track_type, "media_type": c.media_type, "group_id": c.group_id} for c in self.timeline.clips], + "clips": [{"source_path": c.source_path, "timeline_start_ms": c.timeline_start_ms, 
"clip_start_ms": c.clip_start_ms, "duration_ms": c.duration_ms, "track_index": c.track_index, "track_type": c.track_type, "media_type": c.media_type, "group_id": c.group_id} for c in self.timeline.clips], "selection_regions": self.timeline_widget.selection_regions, "last_export_path": self.last_export_path, "settings": { @@ -2463,7 +2467,7 @@ def _load_project_from_path(self, path): for clip_data in project_data["clips"]: if not os.path.exists(clip_data["source_path"]): self.status_label.setText(f"Error: Missing media file {clip_data['source_path']}"); self.new_project(); return - + if 'media_type' not in clip_data: ext = os.path.splitext(clip_data['source_path'])[1].lower() if ext in ['.mp3', '.wav', '.m4a', '.aac']: @@ -2552,19 +2556,19 @@ def add_media_to_timeline(self): if not added_files: return - playhead_pos = self.timeline_widget.playhead_pos_sec + playhead_pos = self.timeline_widget.playhead_pos_ms def add_clips_action(): for file_path in added_files: media_info = self.media_properties.get(file_path) if not media_info: continue - duration = media_info['duration'] + duration_ms = media_info['duration_ms'] media_type = media_info['media_type'] has_audio = media_info['has_audio'] clip_start_time = playhead_pos - clip_end_time = playhead_pos + duration + clip_end_time = playhead_pos + duration_ms video_track_index = None audio_track_index = None @@ -2572,7 +2576,7 @@ def add_clips_action(): if media_type in ['video', 'image']: for i in range(1, self.timeline.num_video_tracks + 2): is_occupied = any( - c.timeline_start_sec < clip_end_time and c.timeline_end_sec > clip_start_time + c.timeline_start_ms < clip_end_time and c.timeline_end_ms > clip_start_time for c in self.timeline.clips if c.track_type == 'video' and c.track_index == i ) if not is_occupied: @@ -2582,7 +2586,7 @@ def add_clips_action(): if has_audio: for i in range(1, self.timeline.num_audio_tracks + 2): is_occupied = any( - c.timeline_start_sec < clip_end_time and c.timeline_end_sec > clip_start_time + c.timeline_start_ms < clip_end_time and c.timeline_end_ms > clip_start_time for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i ) if not is_occupied: @@ -2593,13 +2597,13 @@ def add_clips_action(): if video_track_index is not None: if video_track_index > self.timeline.num_video_tracks: self.timeline.num_video_tracks = video_track_index - video_clip = TimelineClip(file_path, clip_start_time, 0.0, duration, video_track_index, 'video', media_type, group_id) + video_clip = TimelineClip(file_path, clip_start_time, 0, duration_ms, video_track_index, 'video', media_type, group_id) self.timeline.add_clip(video_clip) if audio_track_index is not None: if audio_track_index > self.timeline.num_audio_tracks: self.timeline.num_audio_tracks = audio_track_index - audio_clip = TimelineClip(file_path, clip_start_time, 0.0, duration, audio_track_index, 'audio', media_type, group_id) + audio_clip = TimelineClip(file_path, clip_start_time, 0, duration_ms, audio_track_index, 'audio', media_type, group_id) self.timeline.add_clip(audio_clip) self.status_label.setText(f"Added {len(added_files)} file(s) to timeline.") @@ -2611,7 +2615,7 @@ def add_media_files(self): if file_paths: self._add_media_files_to_project(file_paths) - def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, media_type, clip_start_sec=0.0, video_track_index=None, audio_track_index=None): + def _add_clip_to_timeline(self, source_path, timeline_start_ms, duration_ms, media_type, clip_start_ms=0, video_track_index=None, 
audio_track_index=None): if media_type in ['video', 'image']: self._update_project_properties_from_clip(source_path) @@ -2621,13 +2625,13 @@ def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, m if video_track_index is not None: if video_track_index > self.timeline.num_video_tracks: self.timeline.num_video_tracks = video_track_index - video_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, video_track_index, 'video', media_type, group_id) + video_clip = TimelineClip(source_path, timeline_start_ms, clip_start_ms, duration_ms, video_track_index, 'video', media_type, group_id) self.timeline.add_clip(video_clip) if audio_track_index is not None: if audio_track_index > self.timeline.num_audio_tracks: self.timeline.num_audio_tracks = audio_track_index - audio_clip = TimelineClip(source_path, timeline_start_sec, clip_start_sec, duration_sec, audio_track_index, 'audio', media_type, group_id) + audio_clip = TimelineClip(source_path, timeline_start_ms, clip_start_ms, duration_ms, audio_track_index, 'audio', media_type, group_id) self.timeline.add_clip(audio_clip) new_state = self._get_current_timeline_state() @@ -2635,21 +2639,21 @@ def _add_clip_to_timeline(self, source_path, timeline_start_sec, duration_sec, m command.undo() self.undo_stack.push(command) - def _split_at_time(self, clip_to_split, time_sec, new_group_id=None): - if not (clip_to_split.timeline_start_sec < time_sec < clip_to_split.timeline_end_sec): return False - split_point = time_sec - clip_to_split.timeline_start_sec - orig_dur = clip_to_split.duration_sec + def _split_at_time(self, clip_to_split, time_ms, new_group_id=None): + if not (clip_to_split.timeline_start_ms < time_ms < clip_to_split.timeline_end_ms): return False + split_point = time_ms - clip_to_split.timeline_start_ms + orig_dur = clip_to_split.duration_ms group_id_for_new_clip = new_group_id if new_group_id is not None else clip_to_split.group_id - new_clip = TimelineClip(clip_to_split.source_path, time_sec, clip_to_split.clip_start_sec + split_point, orig_dur - split_point, clip_to_split.track_index, clip_to_split.track_type, clip_to_split.media_type, group_id_for_new_clip) - clip_to_split.duration_sec = split_point + new_clip = TimelineClip(clip_to_split.source_path, time_ms, clip_to_split.clip_start_ms + split_point, orig_dur - split_point, clip_to_split.track_index, clip_to_split.track_type, clip_to_split.media_type, group_id_for_new_clip) + clip_to_split.duration_ms = split_point self.timeline.add_clip(new_clip) return True def split_clip_at_playhead(self, clip_to_split=None): - playhead_time = self.timeline_widget.playhead_pos_sec + playhead_time = self.timeline_widget.playhead_pos_ms if not clip_to_split: - clips_at_playhead = [c for c in self.timeline.clips if c.timeline_start_sec < playhead_time < c.timeline_end_sec] + clips_at_playhead = [c for c in self.timeline.clips if c.timeline_start_ms < playhead_time < c.timeline_end_ms] if not clips_at_playhead: self.status_label.setText("Playhead is not over a clip to split.") return @@ -2723,20 +2727,20 @@ def action(): return target_audio_track = 1 - new_audio_start = video_clip.timeline_start_sec - new_audio_end = video_clip.timeline_end_sec + new_audio_start = video_clip.timeline_start_ms + new_audio_end = video_clip.timeline_end_ms conflicting_clips = [ c for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == target_audio_track and - c.timeline_start_sec < new_audio_end and c.timeline_end_sec > new_audio_start + 
c.timeline_start_ms < new_audio_end and c.timeline_end_ms > new_audio_start ] for conflict_clip in conflicting_clips: found_spot = False for check_track_idx in range(target_audio_track + 1, self.timeline.num_audio_tracks + 2): is_occupied = any( - other.timeline_start_sec < conflict_clip.timeline_end_sec and other.timeline_end_sec > conflict_clip.timeline_start_sec + other.timeline_start_ms < conflict_clip.timeline_end_ms and other.timeline_end_ms > conflict_clip.timeline_start_ms for other in self.timeline.clips if other.id != conflict_clip.id and other.track_type == 'audio' and other.track_index == check_track_idx ) @@ -2751,9 +2755,9 @@ def action(): new_audio_clip = TimelineClip( source_path=video_clip.source_path, - timeline_start_sec=video_clip.timeline_start_sec, - clip_start_sec=video_clip.clip_start_sec, - duration_sec=video_clip.duration_sec, + timeline_start_ms=video_clip.timeline_start_ms, + clip_start_ms=video_clip.clip_start_ms, + duration_ms=video_clip.duration_ms, track_index=target_audio_track, track_type='audio', media_type=video_clip.media_type, @@ -2778,10 +2782,10 @@ def _perform_complex_timeline_change(self, description, change_function): def on_split_region(self, region): def action(): - start_sec, end_sec = region + start_ms, end_ms = region clips = list(self.timeline.clips) - for clip in clips: self._split_at_time(clip, end_sec) - for clip in clips: self._split_at_time(clip, start_sec) + for clip in clips: self._split_at_time(clip, end_ms) + for clip in clips: self._split_at_time(clip, start_ms) self.timeline_widget.clear_region(region) self._perform_complex_timeline_change("Split Region", action) @@ -2793,7 +2797,7 @@ def action(): split_points.add(end) for point in sorted(list(split_points)): - group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} + group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_ms < point < c.timeline_end_ms} new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} for clip in list(self.timeline.clips): if clip.group_id in new_group_ids: @@ -2803,98 +2807,98 @@ def action(): def on_join_region(self, region): def action(): - start_sec, end_sec = region - duration_to_remove = end_sec - start_sec - if duration_to_remove <= 0.01: return + start_ms, end_ms = region + duration_to_remove = end_ms - start_ms + if duration_to_remove <= 10: return - for point in [start_sec, end_sec]: - group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} + for point in [start_ms, end_ms]: + group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_ms < point < c.timeline_end_ms} new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} for clip in list(self.timeline.clips): if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) - clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] + clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_ms >= start_ms and c.timeline_start_ms < end_ms] for clip in clips_to_remove: self.timeline.clips.remove(clip) for clip in self.timeline.clips: - if clip.timeline_start_sec >= end_sec: - clip.timeline_start_sec -= duration_to_remove + if clip.timeline_start_ms >= end_ms: + clip.timeline_start_ms -= duration_to_remove - self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) + 
self.timeline.clips.sort(key=lambda c: c.timeline_start_ms) self.timeline_widget.clear_region(region) self._perform_complex_timeline_change("Join Region", action) def on_join_all_regions(self, regions): def action(): for region in sorted(regions, key=lambda r: r[0], reverse=True): - start_sec, end_sec = region - duration_to_remove = end_sec - start_sec - if duration_to_remove <= 0.01: continue + start_ms, end_ms = region + duration_to_remove = end_ms - start_ms + if duration_to_remove <= 10: continue - for point in [start_sec, end_sec]: - group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} + for point in [start_ms, end_ms]: + group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_ms < point < c.timeline_end_ms} new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} for clip in list(self.timeline.clips): if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) - clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] + clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_ms >= start_ms and c.timeline_start_ms < end_ms] for clip in clips_to_remove: try: self.timeline.clips.remove(clip) except ValueError: pass for clip in self.timeline.clips: - if clip.timeline_start_sec >= end_sec: - clip.timeline_start_sec -= duration_to_remove + if clip.timeline_start_ms >= end_ms: + clip.timeline_start_ms -= duration_to_remove - self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) + self.timeline.clips.sort(key=lambda c: c.timeline_start_ms) self.timeline_widget.clear_all_regions() self._perform_complex_timeline_change("Join All Regions", action) def on_delete_region(self, region): def action(): - start_sec, end_sec = region - duration_to_remove = end_sec - start_sec - if duration_to_remove <= 0.01: return + start_ms, end_ms = region + duration_to_remove = end_ms - start_ms + if duration_to_remove <= 10: return - for point in [start_sec, end_sec]: - group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} + for point in [start_ms, end_ms]: + group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_ms < point < c.timeline_end_ms} new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} for clip in list(self.timeline.clips): if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) - clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] + clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_ms >= start_ms and c.timeline_start_ms < end_ms] for clip in clips_to_remove: self.timeline.clips.remove(clip) for clip in self.timeline.clips: - if clip.timeline_start_sec >= end_sec: - clip.timeline_start_sec -= duration_to_remove + if clip.timeline_start_ms >= end_ms: + clip.timeline_start_ms -= duration_to_remove - self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) + self.timeline.clips.sort(key=lambda c: c.timeline_start_ms) self.timeline_widget.clear_region(region) self._perform_complex_timeline_change("Delete Region", action) def on_delete_all_regions(self, regions): def action(): for region in sorted(regions, key=lambda r: r[0], reverse=True): - start_sec, end_sec = region - duration_to_remove = end_sec - start_sec - if duration_to_remove <= 0.01: 
continue + start_ms, end_ms = region + duration_to_remove = end_ms - start_ms + if duration_to_remove <= 10: continue - for point in [start_sec, end_sec]: - group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_sec < point < c.timeline_end_sec} + for point in [start_ms, end_ms]: + group_ids_at_point = {c.group_id for c in self.timeline.clips if c.timeline_start_ms < point < c.timeline_end_ms} new_group_ids = {gid: str(uuid.uuid4()) for gid in group_ids_at_point} for clip in list(self.timeline.clips): if clip.group_id in new_group_ids: self._split_at_time(clip, point, new_group_ids[clip.group_id]) - clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_sec >= start_sec and c.timeline_start_sec < end_sec] + clips_to_remove = [c for c in self.timeline.clips if c.timeline_start_ms >= start_ms and c.timeline_start_ms < end_ms] for clip in clips_to_remove: try: self.timeline.clips.remove(clip) except ValueError: pass for clip in self.timeline.clips: - if clip.timeline_start_sec >= end_sec: - clip.timeline_start_sec -= duration_to_remove + if clip.timeline_start_ms >= end_ms: + clip.timeline_start_ms -= duration_to_remove - self.timeline.clips.sort(key=lambda c: c.timeline_start_sec) + self.timeline.clips.sort(key=lambda c: c.timeline_start_ms) self.timeline_widget.clear_all_regions() self._perform_complex_timeline_change("Delete All Regions", action) @@ -2936,10 +2940,12 @@ def export_video(self): self.last_export_path = output_path - w, h, fr_str, total_dur = self.project_width, self.project_height, str(self.project_fps), self.timeline.get_total_duration() + total_dur_ms = self.timeline.get_total_duration() + total_dur_sec = total_dur_ms / 1000.0 + w, h, fr_str = self.project_width, self.project_height, str(self.project_fps) sample_rate, channel_layout = '44100', 'stereo' - video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur}', f='lavfi') + video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur_sec}', f='lavfi') all_video_clips = sorted( [c for c in self.timeline.clips if c.track_type == 'video'], @@ -2954,37 +2960,45 @@ def export_video(self): input_nodes[clip.source_path] = ffmpeg.input(clip.source_path) clip_source_node = input_nodes[clip.source_path] + clip_duration_sec = clip.duration_ms / 1000.0 + clip_start_sec = clip.clip_start_ms / 1000.0 + timeline_start_sec = clip.timeline_start_ms / 1000.0 + timeline_end_sec = clip.timeline_end_ms / 1000.0 + if clip.media_type == 'image': - segment_stream = (clip_source_node.video.trim(duration=clip.duration_sec).setpts('PTS-STARTPTS')) + segment_stream = (clip_source_node.video.trim(duration=clip_duration_sec).setpts('PTS-STARTPTS')) else: - segment_stream = (clip_source_node.video.trim(start=clip.clip_start_sec, duration=clip.duration_sec).setpts('PTS-STARTPTS')) + segment_stream = (clip_source_node.video.trim(start=clip_start_sec, duration=clip_duration_sec).setpts('PTS-STARTPTS')) processed_segment = (segment_stream.filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')) - video_stream = ffmpeg.overlay(video_stream, processed_segment, enable=f'between(t,{clip.timeline_start_sec},{clip.timeline_end_sec})') + video_stream = ffmpeg.overlay(video_stream, processed_segment, enable=f'between(t,{timeline_start_sec},{timeline_end_sec})') final_video = video_stream.filter('format', pix_fmts='yuv420p').filter('fps', fps=self.project_fps) track_audio_streams = [] for i in 
range(self.timeline.num_audio_tracks): - track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i + 1], key=lambda c: c.timeline_start_sec) + track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i + 1], key=lambda c: c.timeline_start_ms) if not track_clips: continue track_segments = [] - last_end = track_clips[0].timeline_start_sec + last_end_ms = track_clips[0].timeline_start_ms for clip in track_clips: if clip.source_path not in input_nodes: input_nodes[clip.source_path] = ffmpeg.input(clip.source_path) - gap = clip.timeline_start_sec - last_end - if gap > 0.01: - track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap}', f='lavfi')) - a_seg = input_nodes[clip.source_path].audio.filter('atrim', start=clip.clip_start_sec, duration=clip.duration_sec).filter('asetpts', 'PTS-STARTPTS') + gap_ms = clip.timeline_start_ms - last_end_ms + if gap_ms > 10: + track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap_ms/1000.0}', f='lavfi')) + + clip_start_sec = clip.clip_start_ms / 1000.0 + clip_duration_sec = clip.duration_ms / 1000.0 + a_seg = input_nodes[clip.source_path].audio.filter('atrim', start=clip_start_sec, duration=clip_duration_sec).filter('asetpts', 'PTS-STARTPTS') track_segments.append(a_seg) - last_end = clip.timeline_end_sec - track_audio_streams.append(ffmpeg.concat(*track_segments, v=0, a=1).filter('adelay', f'{int(track_clips[0].timeline_start_sec * 1000)}ms', all=True)) + last_end_ms = clip.timeline_end_ms + track_audio_streams.append(ffmpeg.concat(*track_segments, v=0, a=1).filter('adelay', f'{int(track_clips[0].timeline_start_ms)}ms', all=True)) if track_audio_streams: final_audio = ffmpeg.filter(track_audio_streams, 'amix', inputs=len(track_audio_streams), duration='longest') else: - final_audio = ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={total_dur}', f='lavfi') + final_audio = ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={total_dur_sec}', f='lavfi') output_args = {'vcodec': settings['vcodec'], 'acodec': settings['acodec'], 'pix_fmt': 'yuv420p'} if settings['v_bitrate']: output_args['b:v'] = settings['v_bitrate'] @@ -2994,7 +3008,7 @@ def export_video(self): ffmpeg_cmd = ffmpeg.output(final_video, final_audio, output_path, **output_args).overwrite_output().compile() self.progress_bar.setVisible(True); self.progress_bar.setValue(0); self.status_label.setText("Exporting...") self.export_thread = QThread() - self.export_worker = ExportWorker(ffmpeg_cmd, total_dur) + self.export_worker = ExportWorker(ffmpeg_cmd, total_dur_ms) self.export_worker.moveToThread(self.export_thread) self.export_thread.started.connect(self.export_worker.run_export) self.export_worker.finished.connect(self.on_export_finished) From 75d29f9b48bec16b0e3468af28d37d52676a98be Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 21:35:16 +1100 Subject: [PATCH 142/155] fix timescale shifts --- videoeditor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/videoeditor.py b/videoeditor.py index 5126b79b5..632ca4069 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -468,7 +468,7 @@ def draw_timescale(self, painter): frame_dur_ms = 1000.0 / self.project_fps intervals_ms = [ - int(frame_dur_ms), int(2*frame_dur_ms), int(5*frame_dur_ms), int(10*frame_dur_ms), + frame_dur_ms, 2*frame_dur_ms, 5*frame_dur_ms, 10*frame_dur_ms, 100, 200, 500, 1000, 2000, 5000, 10000, 15000, 30000, 60000, 120000, 
300000, 600000, 900000, 1800000, 3600000, 2*3600000, 5*3600000, 10*3600000 From fde093f128f408b0b3827f0f0a562df8dcd9272c Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 21:50:44 +1100 Subject: [PATCH 143/155] add buttons to snap to nearest clip collision --- videoeditor.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/videoeditor.py b/videoeditor.py index 632ca4069..98d908477 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -1853,11 +1853,15 @@ def _setup_ui(self): self.stop_button = QPushButton("Stop") self.frame_back_button = QPushButton("<") self.frame_forward_button = QPushButton(">") + self.snap_back_button = QPushButton("|<") + self.snap_forward_button = QPushButton(">|") controls_layout.addStretch() + controls_layout.addWidget(self.snap_back_button) controls_layout.addWidget(self.frame_back_button) controls_layout.addWidget(self.play_pause_button) controls_layout.addWidget(self.stop_button) controls_layout.addWidget(self.frame_forward_button) + controls_layout.addWidget(self.snap_forward_button) controls_layout.addStretch() main_layout.addWidget(controls_widget) @@ -1908,6 +1912,8 @@ def _connect_signals(self): self.stop_button.clicked.connect(self.stop_playback) self.frame_back_button.clicked.connect(lambda: self.step_frame(-1)) self.frame_forward_button.clicked.connect(lambda: self.step_frame(1)) + self.snap_back_button.clicked.connect(lambda: self.snap_playhead(-1)) + self.snap_forward_button.clicked.connect(lambda: self.snap_playhead(1)) self.playback_timer.timeout.connect(self.advance_playback_frame) self.project_media_widget.add_media_requested.connect(self.add_media_files) @@ -2318,6 +2324,32 @@ def step_frame(self, direction): new_time = self.timeline_widget.playhead_pos_ms + (direction * frame_duration_ms) self.seek_preview(int(max(0, min(new_time, self.timeline.get_total_duration())))) + def snap_playhead(self, direction): + if not self.timeline.clips: + return + + current_time_ms = self.timeline_widget.playhead_pos_ms + + snap_points = set() + for clip in self.timeline.clips: + snap_points.add(clip.timeline_start_ms) + snap_points.add(clip.timeline_end_ms) + + sorted_points = sorted(list(snap_points)) + + TOLERANCE_MS = 1 + + if direction == 1: # Forward + next_points = [p for p in sorted_points if p > current_time_ms + TOLERANCE_MS] + if next_points: + self.seek_preview(next_points[0]) + elif direction == -1: # Backward + prev_points = [p for p in sorted_points if p < current_time_ms - TOLERANCE_MS] + if prev_points: + self.seek_preview(prev_points[-1]) + elif current_time_ms > 0: + self.seek_preview(0) + def advance_playback_frame(self): frame_duration_ms = 1000.0 / self.project_fps new_time_ms = self.timeline_widget.playhead_pos_ms + frame_duration_ms From 78f2a959ab12d56a24bcd1a51f916ee5057b01bc Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Sun, 12 Oct 2025 22:13:19 +1100 Subject: [PATCH 144/155] hold shift for universal frame by frame snapping on all operations --- videoeditor.py | 91 ++++++++++++++++++++++++++++++++++---------------- 1 file changed, 63 insertions(+), 28 deletions(-) diff --git a/videoeditor.py b/videoeditor.py index 98d908477..53f3baa85 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -636,11 +636,17 @@ def y_to_track_info(self, y): return None + def _snap_to_frame(self, time_ms): + frame_duration_ms = 1000.0 / self.project_fps + if frame_duration_ms <= 0: + return int(time_ms) + frame_number = round(time_ms / frame_duration_ms) + return int(frame_number * frame_duration_ms) + def 
_snap_time_if_needed(self, time_ms): frame_duration_ms = 1000.0 / self.project_fps if frame_duration_ms > 0 and frame_duration_ms * self.pixels_per_ms > 4: - frame_number = round(time_ms / frame_duration_ms) - return int(frame_number * frame_duration_ms) + return self._snap_to_frame(time_ms) return int(time_ms) def get_region_at_pos(self, pos: QPoint): @@ -791,15 +797,24 @@ def mousePressEvent(self, event: QMouseEvent): if is_in_track_area: self.creating_selection_region = True - playhead_x = self.ms_to_x(self.playhead_pos_ms) - if abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS: - self.selection_drag_start_ms = self.playhead_pos_ms + is_shift_pressed = bool(event.modifiers() & Qt.KeyboardModifier.ShiftModifier) + start_ms = self.x_to_ms(event.pos().x()) + + if is_shift_pressed: + self.selection_drag_start_ms = self._snap_to_frame(start_ms) else: - self.selection_drag_start_ms = self.x_to_ms(event.pos().x()) + playhead_x = self.ms_to_x(self.playhead_pos_ms) + if abs(event.pos().x() - playhead_x) < self.SNAP_THRESHOLD_PIXELS: + self.selection_drag_start_ms = self.playhead_pos_ms + else: + self.selection_drag_start_ms = start_ms self.selection_regions.append([self.selection_drag_start_ms, self.selection_drag_start_ms]) elif is_on_timescale: time_ms = max(0, self.x_to_ms(event.pos().x())) - self.playhead_pos_ms = self._snap_time_if_needed(time_ms) + if event.modifiers() & Qt.KeyboardModifier.ShiftModifier: + self.playhead_pos_ms = self._snap_to_frame(time_ms) + else: + self.playhead_pos_ms = self._snap_time_if_needed(time_ms) self.playhead_moved.emit(self.playhead_pos_ms) self.dragging_playhead = True @@ -816,6 +831,8 @@ def mouseMoveEvent(self, event: QMouseEvent): if self.resizing_selection_region: current_ms = max(0, self.x_to_ms(event.pos().x())) + if event.modifiers() & Qt.KeyboardModifier.ShiftModifier: + current_ms = self._snap_to_frame(current_ms) original_start, original_end = self.resize_selection_start_values if self.resize_selection_edge == 'left': @@ -836,6 +853,7 @@ def mouseMoveEvent(self, event: QMouseEvent): self.update() return if self.resizing_clip: + is_shift_pressed = bool(event.modifiers() & Qt.KeyboardModifier.ShiftModifier) linked_clip = next((c for c in self.timeline.clips if c.group_id == self.resizing_clip.group_id and c.id != self.resizing_clip.id), None) delta_x = event.pos().x() - self.resize_start_pos.x() time_delta = delta_x / self.pixels_per_ms @@ -858,11 +876,14 @@ def mouseMoveEvent(self, event: QMouseEvent): original_clip_start = self.drag_start_state[0][[c.id for c in self.drag_start_state[0]].index(self.resizing_clip.id)].clip_start_ms true_new_start_ms = original_start + time_delta - new_start_ms = true_new_start_ms - for snap_point in snap_points: - if abs(true_new_start_ms - snap_point) < snap_time_delta: - new_start_ms = snap_point - break + if is_shift_pressed: + new_start_ms = self._snap_to_frame(true_new_start_ms) + else: + new_start_ms = true_new_start_ms + for snap_point in snap_points: + if abs(true_new_start_ms - snap_point) < snap_time_delta: + new_start_ms = snap_point + break if new_start_ms > original_start + original_duration - min_duration_ms: new_start_ms = original_start + original_duration - min_duration_ms @@ -896,11 +917,14 @@ def mouseMoveEvent(self, event: QMouseEvent): true_new_duration = original_duration + time_delta true_new_end_time = original_start + true_new_duration - new_end_time = true_new_end_time - for snap_point in snap_points: - if abs(true_new_end_time - snap_point) < snap_time_delta: - new_end_time = 
snap_point - break + if is_shift_pressed: + new_end_time = self._snap_to_frame(true_new_end_time) + else: + new_end_time = true_new_end_time + for snap_point in snap_points: + if abs(true_new_end_time - snap_point) < snap_time_delta: + new_end_time = snap_point + break new_duration = new_end_time - original_start @@ -948,6 +972,8 @@ def mouseMoveEvent(self, event: QMouseEvent): if self.creating_selection_region: current_ms = self.x_to_ms(event.pos().x()) + if event.modifiers() & Qt.KeyboardModifier.ShiftModifier: + current_ms = self._snap_to_frame(current_ms) start = min(self.selection_drag_start_ms, current_ms) end = max(self.selection_drag_start_ms, current_ms) self.selection_regions[-1] = [start, end] @@ -962,6 +988,9 @@ def mouseMoveEvent(self, event: QMouseEvent): duration = original_end - original_start new_start = max(0, original_start + time_delta) + if event.modifiers() & Qt.KeyboardModifier.ShiftModifier: + new_start = self._snap_to_frame(new_start) + self.dragging_selection_region[0] = new_start self.dragging_selection_region[1] = new_start + duration @@ -970,7 +999,10 @@ def mouseMoveEvent(self, event: QMouseEvent): if self.dragging_playhead: time_ms = max(0, self.x_to_ms(event.pos().x())) - self.playhead_pos_ms = self._snap_time_if_needed(time_ms) + if event.modifiers() & Qt.KeyboardModifier.ShiftModifier: + self.playhead_pos_ms = self._snap_to_frame(time_ms) + else: + self.playhead_pos_ms = self._snap_time_if_needed(time_ms) self.playhead_moved.emit(self.playhead_pos_ms) self.update() elif self.dragging_clip: @@ -997,17 +1029,20 @@ def mouseMoveEvent(self, event: QMouseEvent): delta_x = event.pos().x() - self.drag_start_pos.x() time_delta = delta_x / self.pixels_per_ms true_new_start_time = original_start_ms + time_delta - - playhead_time = self.playhead_pos_ms - snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_ms - new_start_time = true_new_start_time - true_new_end_time = true_new_start_time + self.dragging_clip.duration_ms - - if abs(true_new_start_time - playhead_time) < snap_time_delta: - new_start_time = playhead_time - elif abs(true_new_end_time - playhead_time) < snap_time_delta: - new_start_time = playhead_time - self.dragging_clip.duration_ms + if event.modifiers() & Qt.KeyboardModifier.ShiftModifier: + new_start_time = self._snap_to_frame(true_new_start_time) + else: + playhead_time = self.playhead_pos_ms + snap_time_delta = self.SNAP_THRESHOLD_PIXELS / self.pixels_per_ms + + new_start_time = true_new_start_time + true_new_end_time = true_new_start_time + self.dragging_clip.duration_ms + + if abs(true_new_start_time - playhead_time) < snap_time_delta: + new_start_time = playhead_time + elif abs(true_new_end_time - playhead_time) < snap_time_delta: + new_start_time = playhead_time - self.dragging_clip.duration_ms for other_clip in self.timeline.clips: if other_clip.id == self.dragging_clip.id: continue From 19b6ae344e4c15c7d08eccbc3745ead094bba080 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Mon, 13 Oct 2025 00:09:11 +1100 Subject: [PATCH 145/155] improve timescale --- videoeditor.py | 64 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 9 deletions(-) diff --git a/videoeditor.py b/videoeditor.py index 53f3baa85..e6176bec9 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -439,23 +439,69 @@ def draw_headers(self, painter): painter.restore() - def _format_timecode(self, total_ms): + + def _format_timecode(self, total_ms, interval_ms): if abs(total_ms) < 1: total_ms = 0 sign = "-" if total_ms < 0 else "" total_ms 
= abs(total_ms) seconds = total_ms / 1000.0 + + # Frame-based formatting for high zoom + is_frame_based = interval_ms < (1000.0 / self.project_fps) * 5 + if is_frame_based: + total_frames = int(round(seconds * self.project_fps)) + fps_int = int(round(self.project_fps)) + if fps_int == 0: fps_int = 25 # Avoid division by zero + s_frames = total_frames % fps_int + total_seconds_from_frames = total_frames // fps_int + h_fr = total_seconds_from_frames // 3600 + m_fr = (total_seconds_from_frames % 3600) // 60 + s_fr = total_seconds_from_frames % 60 + + if h_fr > 0: return f"{sign}{h_fr}:{m_fr:02d}:{s_fr:02d}:{s_frames:02d}" + if m_fr > 0: return f"{sign}{m_fr}:{s_fr:02d}:{s_frames:02d}" + return f"{sign}{s_fr}:{s_frames:02d}" + + # Sub-second formatting + if interval_ms < 1000: + precision = 2 if interval_ms < 100 else 1 + s_float = seconds % 60 + m = int((seconds % 3600) / 60) + h = int(seconds / 3600) + + if h > 0: return f"{sign}{h}:{m:02d}:{s_float:0{4+precision}.{precision}f}" + if m > 0: return f"{sign}{m}:{s_float:0{4+precision}.{precision}f}" + + val = f"{s_float:.{precision}f}" + if '.' in val: val = val.rstrip('0').rstrip('.') + return f"{sign}{val}s" + + # Seconds and M:SS formatting + if interval_ms < 60000: + rounded_seconds = int(round(seconds)) + h = rounded_seconds // 3600 + m = (rounded_seconds % 3600) // 60 + s = rounded_seconds % 60 + + if h > 0: + return f"{sign}{h}:{m:02d}:{s:02d}" + + # Use "Xs" format for times under a minute + if rounded_seconds < 60: + return f"{sign}{rounded_seconds}s" + + # Use "M:SS" format for times over a minute + return f"{sign}{m}:{s:02d}" + + # Minute and Hh:MMm formatting for low zoom h = int(seconds / 3600) m = int((seconds % 3600) / 60) - s = seconds % 60 - + if h > 0: return f"{sign}{h}h:{m:02d}m" - if m > 0 or seconds >= 59.99: return f"{sign}{m}m:{int(round(s)):02d}s" - precision = 2 if s < 1 else 1 if s < 10 else 0 - val = f"{s:.{precision}f}" - if '.' 
in val: val = val.rstrip('0').rstrip('.') - return f"{sign}{val}s" + total_minutes = int(seconds / 60) + return f"{sign}{total_minutes}m" def draw_timescale(self, painter): painter.save() @@ -510,7 +556,7 @@ def draw_ticks(interval_ms, height): if x > self.width() + 50: break if x >= self.HEADER_WIDTH - 50: painter.drawLine(x, self.TIMESCALE_HEIGHT - 12, x, self.TIMESCALE_HEIGHT) - label = self._format_timecode(t_ms) + label = self._format_timecode(t_ms, major_interval) label_width = font_metrics.horizontalAdvance(label) label_x = x - label_width // 2 if label_x < self.HEADER_WIDTH: From 4452582d6eb8d22b3b21ca91a5bda4d7ff01ec3e Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Mon, 13 Oct 2025 01:23:09 +1100 Subject: [PATCH 146/155] clamp timescale conversions --- videoeditor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/videoeditor.py b/videoeditor.py index e6176bec9..d6e744ebe 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -330,7 +330,7 @@ def set_project_fps(self, fps): self.pixels_per_ms = min(self.pixels_per_ms, self.max_pixels_per_ms) self.update() - def ms_to_x(self, ms): return self.HEADER_WIDTH + int((ms - self.view_start_ms) * self.pixels_per_ms) + def ms_to_x(self, ms): return self.HEADER_WIDTH + int(max(-500_000_000, min((ms - self.view_start_ms) * self.pixels_per_ms, 500_000_000))) def x_to_ms(self, x): return self.view_start_ms + int(float(x - self.HEADER_WIDTH) / self.pixels_per_ms) if x > self.HEADER_WIDTH and self.pixels_per_ms > 0 else self.view_start_ms def paintEvent(self, event): @@ -1861,7 +1861,7 @@ def __init__(self, project_to_load=None): self.plugin_manager.discover_and_load_plugins() # Project settings with defaults - self.project_fps = 25.0 + self.project_fps = 50.0 self.project_width = 1280 self.project_height = 720 From baad78fd629f4bb5cce39625c5b5d1f1051fb0bc Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Tue, 14 Oct 2025 19:00:24 +1100 Subject: [PATCH 147/155] added audio for playback, major overhaul of encoder and playback system to use multithreading and better layering, move wgp import out of video player and into plugin --- encoding.py | 210 +++++++++++++++ playback.py | 531 +++++++++++++++++++++++++++++++++++++ plugins.py | 53 ++-- plugins/wan2gp/main.py | 179 ++++++++----- plugins/wan2gp/plugin.json | 5 + videoeditor.py | 380 +++++++++----------------- 6 files changed, 1022 insertions(+), 336 deletions(-) create mode 100644 encoding.py create mode 100644 playback.py create mode 100644 plugins/wan2gp/plugin.json diff --git a/encoding.py b/encoding.py new file mode 100644 index 000000000..22f67489c --- /dev/null +++ b/encoding.py @@ -0,0 +1,210 @@ +import ffmpeg +import subprocess +import re +from PyQt6.QtCore import QObject, pyqtSignal, QThread + +class _ExportRunner(QObject): + progress = pyqtSignal(int) + finished = pyqtSignal(bool, str) + + def __init__(self, ffmpeg_cmd, total_duration_ms, parent=None): + super().__init__(parent) + self.ffmpeg_cmd = ffmpeg_cmd + self.total_duration_ms = total_duration_ms + self.process = None + + def run(self): + try: + startupinfo = None + if hasattr(subprocess, 'STARTUPINFO'): + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + + self.process = subprocess.Popen( + self.ffmpeg_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + universal_newlines=True, + encoding="utf-8", + errors='ignore', + startupinfo=startupinfo + ) + + full_output = [] + time_pattern = re.compile(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})") + + for line in 
iter(self.process.stdout.readline, ""): + full_output.append(line) + match = time_pattern.search(line) + if match: + h, m, s, cs = [int(g) for g in match.groups()] + processed_ms = (h * 3600 + m * 60 + s) * 1000 + cs * 10 + if self.total_duration_ms > 0: + percentage = int((processed_ms / self.total_duration_ms) * 100) + self.progress.emit(min(100, percentage)) + + self.process.stdout.close() + return_code = self.process.wait() + + if return_code == 0: + self.progress.emit(100) + self.finished.emit(True, "Export completed successfully!") + else: + print("--- FFmpeg Export FAILED ---") + print("Command: " + " ".join(self.ffmpeg_cmd)) + print("".join(full_output)) + self.finished.emit(False, f"Export failed with code {return_code}. Check console.") + + except FileNotFoundError: + self.finished.emit(False, "Export failed: ffmpeg.exe not found in your system's PATH.") + except Exception as e: + self.finished.emit(False, f"An exception occurred during export: {e}") + + def get_process(self): + return self.process + +class Encoder(QObject): + progress = pyqtSignal(int) + finished = pyqtSignal(bool, str) + + def __init__(self, parent=None): + super().__init__(parent) + self.worker_thread = None + self.worker = None + self._is_running = False + + def start_export(self, timeline, project_settings, export_settings): + if self._is_running: + self.finished.emit(False, "An export is already in progress.") + return + + self._is_running = True + + try: + total_dur_ms = timeline.get_total_duration() + total_dur_sec = total_dur_ms / 1000.0 + w, h, fps = project_settings['width'], project_settings['height'], project_settings['fps'] + sample_rate, channel_layout = '44100', 'stereo' + + # --- VIDEO GRAPH CONSTRUCTION (Definitive Solution) --- + + all_video_clips = sorted( + [c for c in timeline.clips if c.track_type == 'video'], + key=lambda c: c.track_index + ) + + # 1. Start with a base black canvas that defines the project's duration. + # This is our master clock and bottom layer. + final_video = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fps}:d={total_dur_sec}', f='lavfi') + + # 2. Process each clip individually and overlay it. + for clip in all_video_clips: + # A new, separate input for every single clip guarantees no conflicts. + if clip.media_type == 'image': + clip_input = ffmpeg.input(clip.source_path, loop=1, framerate=fps) + else: + clip_input = ffmpeg.input(clip.source_path) + + # a) Calculate the time shift to align the clip's content with the timeline. + # This ensures the correct frame of the source is shown at the start of the clip. + timeline_start_sec = clip.timeline_start_ms / 1000.0 + clip_start_sec = clip.clip_start_ms / 1000.0 + time_shift_sec = timeline_start_sec - clip_start_sec + + # b) Prepare the layer: apply the time shift, then scale and pad. + timed_layer = ( + clip_input.video + .setpts(f'PTS+{time_shift_sec}/TB') + .filter('scale', w, h, force_original_aspect_ratio='decrease') + .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') + ) + + # c) Define the visibility window for the overlay on the master timeline. + timeline_end_sec = (clip.timeline_start_ms + clip.duration_ms) / 1000.0 + enable_expression = f'between(t,{timeline_start_sec:.6f},{timeline_end_sec:.6f})' + + # d) Overlay the prepared layer onto the composition, enabling it only during its time window. + # eof_action='pass' handles finite streams gracefully. + final_video = ffmpeg.overlay(final_video, timed_layer, enable=enable_expression, eof_action='pass') + + # 3. Set final output format and framerate. 
+ final_video = final_video.filter('format', pix_fmts='yuv420p').filter('fps', fps=fps) + + + # --- AUDIO GRAPH CONSTRUCTION (UNCHANGED and CORRECT) --- + track_audio_streams = [] + for i in range(1, timeline.num_audio_tracks + 1): + track_clips = sorted([c for c in timeline.clips if c.track_type == 'audio' and c.track_index == i], key=lambda c: c.timeline_start_ms) + if not track_clips: + continue + + track_segments = [] + last_end_ms = 0 + for clip in track_clips: + gap_ms = clip.timeline_start_ms - last_end_ms + if gap_ms > 10: + track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap_ms/1000.0}', f='lavfi')) + + clip_start_sec = clip.clip_start_ms / 1000.0 + clip_duration_sec = clip.duration_ms / 1000.0 + audio_source_node = ffmpeg.input(clip.source_path) + a_seg = audio_source_node.audio.filter('atrim', start=clip_start_sec, duration=clip_duration_sec).filter('asetpts', 'PTS-STARTPTS') + track_segments.append(a_seg) + last_end_ms = clip.timeline_start_ms + clip.duration_ms + + if track_segments: + track_audio_streams.append(ffmpeg.concat(*track_segments, v=0, a=1)) + + # --- FINAL OUTPUT ASSEMBLY --- + output_args = {} + stream_args = [] + has_audio = bool(track_audio_streams) and export_settings.get('acodec') + + if export_settings.get('vcodec'): + stream_args.append(final_video) + output_args['vcodec'] = export_settings['vcodec'] + if export_settings.get('v_bitrate'): output_args['b:v'] = export_settings['v_bitrate'] + + if has_audio: + final_audio = ffmpeg.filter(track_audio_streams, 'amix', inputs=len(track_audio_streams), duration='longest') + stream_args.append(final_audio) + output_args['acodec'] = export_settings['acodec'] + if export_settings.get('a_bitrate'): output_args['b:a'] = export_settings['a_bitrate'] + + if not has_audio: + output_args['an'] = None + + if not stream_args: + raise ValueError("No streams to output. 
Check export settings.") + + ffmpeg_cmd = ffmpeg.output(*stream_args, export_settings['output_path'], **output_args).overwrite_output().compile() + + except Exception as e: + self.finished.emit(False, f"Error building FFmpeg command: {e}") + self._is_running = False + return + + self.worker_thread = QThread() + self.worker = _ExportRunner(ffmpeg_cmd, total_dur_ms) + self.worker.moveToThread(self.worker_thread) + + self.worker.progress.connect(self.progress.emit) + self.worker.finished.connect(self._on_export_runner_finished) + + self.worker_thread.started.connect(self.worker.run) + self.worker_thread.start() + + def _on_export_runner_finished(self, success, message): + self._is_running = False + self.finished.emit(success, message) + + if self.worker_thread: + self.worker_thread.quit() + self.worker_thread.wait() + self.worker_thread = None + self.worker = None + + def cancel_export(self): + if self.worker and self.worker.get_process() and self.worker.get_process().poll() is None: + self.worker.get_process().terminate() + print("Export cancelled by user.") \ No newline at end of file diff --git a/playback.py b/playback.py new file mode 100644 index 000000000..d8da682e8 --- /dev/null +++ b/playback.py @@ -0,0 +1,531 @@ +import ffmpeg +import numpy as np +import sounddevice as sd +import threading +import time +from queue import Queue, Empty +import subprocess +from PyQt6.QtCore import QObject, pyqtSignal, QTimer +from PyQt6.QtGui import QImage, QPixmap, QColor + +AUDIO_BUFFER_SECONDS = 1.0 +VIDEO_BUFFER_SECONDS = 1.0 +AUDIO_CHUNK_SAMPLES = 1024 +DEFAULT_SAMPLE_RATE = 44100 +DEFAULT_CHANNELS = 2 + +class PlaybackManager(QObject): + new_frame = pyqtSignal(QPixmap) + playback_pos_changed = pyqtSignal(int) + stopped = pyqtSignal() + started = pyqtSignal() + paused = pyqtSignal() + stats_updated = pyqtSignal(str) + + def __init__(self, get_timeline_data_func, parent=None): + super().__init__(parent) + self.get_timeline_data = get_timeline_data_func + + self.is_playing = False + self.is_muted = False + self.volume = 1.0 + self.playback_start_time_ms = 0 + self.pause_time_ms = 0 + self.is_paused = False + self.seeking = False + self.stop_flag = threading.Event() + self._seek_thread = None + self._seek_request_ms = -1 + self._seek_lock = threading.Lock() + + self.video_process = None + self.audio_process = None + self.video_reader_thread = None + self.audio_reader_thread = None + self.audio_stream = None + + self.video_queue = None + self.audio_queue = None + + self.stream_start_time_monotonic = 0 + self.last_emitted_pos = -1 + + self.total_samples_played = 0 + self.audio_underruns = 0 + self.last_video_pts_ms = 0 + self.debug = False + + self.sync_lock = threading.Lock() + self.audio_clock_sec = 0.0 + self.audio_clock_update_time = 0.0 + + self.update_timer = QTimer(self) + self.update_timer.timeout.connect(self._update_loop) + self.update_timer.setInterval(16) + + def _emit_playhead_pos(self, time_ms, source): + if time_ms != self.last_emitted_pos: + if self.debug: + print(f"[PLAYHEAD DEBUG @ {time.monotonic():.4f}] pos={time_ms}ms, source='{source}'") + self.playback_pos_changed.emit(time_ms) + self.last_emitted_pos = time_ms + + def _cleanup_resources(self): + self.stop_flag.set() + + if self.update_timer.isActive(): + self.update_timer.stop() + + if self.audio_stream: + try: + self.audio_stream.stop() + self.audio_stream.close(ignore_errors=True) + except Exception as e: + print(f"Error closing audio stream: {e}") + self.audio_stream = None + + for p in [self.video_process, 
self.audio_process]: + if p and p.poll() is None: + try: + p.terminate() + p.wait(timeout=1.0) # Wait a bit for graceful termination + except Exception as e: + print(f"Error terminating process: {e}") + try: + p.kill() # Force kill if terminate fails + except Exception as ke: + print(f"Error killing process: {ke}") + + + self.video_reader_thread = None + self.audio_reader_thread = None + self.video_process = None + self.audio_process = None + self.video_queue = None + self.audio_queue = None + + def _video_reader_thread(self, process, width, height, fps): + frame_size = width * height * 3 + frame_duration_ms = 1000.0 / fps + frame_pts_ms = self.playback_start_time_ms + + while not self.stop_flag.is_set(): + try: + if self.video_queue and self.video_queue.full(): + time.sleep(0.01) + continue + + chunk = process.stdout.read(frame_size) + if len(chunk) < frame_size: + break + + if self.video_queue: self.video_queue.put((chunk, frame_pts_ms)) + frame_pts_ms += frame_duration_ms + + except (IOError, ValueError): + break + except Exception as e: + print(f"Video reader error: {e}") + break + if self.debug: print("Video reader thread finished.") + if self.video_queue: self.video_queue.put(None) + + def _audio_reader_thread(self, process): + chunk_size = AUDIO_CHUNK_SAMPLES * DEFAULT_CHANNELS * 4 + while not self.stop_flag.is_set(): + try: + if self.audio_queue and self.audio_queue.full(): + time.sleep(0.01) + continue + + chunk = process.stdout.read(chunk_size) + if not chunk: + break + + np_chunk = np.frombuffer(chunk, dtype=np.float32) + + if self.audio_queue: + self.audio_queue.put(np_chunk) + + except (IOError, ValueError): + break + except Exception as e: + print(f"Audio reader error: {e}") + break + if self.debug: print("Audio reader thread finished.") + if self.audio_queue: self.audio_queue.put(None) + + def _audio_callback(self, outdata, frames, time_info, status): + if status: + self.audio_underruns += 1 + try: + if self.audio_queue is None: raise Empty + chunk = self.audio_queue.get_nowait() + if chunk is None: + raise sd.CallbackStop + + chunk_len = len(chunk) + outdata_len = outdata.shape[0] * outdata.shape[1] + + if chunk_len < outdata_len: + outdata.fill(0) + outdata.flat[:chunk_len] = chunk + else: + outdata[:] = chunk[:outdata.size].reshape(outdata.shape) + + if self.is_muted: + outdata.fill(0) + else: + outdata *= self.volume + + with self.sync_lock: + self.audio_clock_sec = self.total_samples_played / DEFAULT_SAMPLE_RATE + self.audio_clock_update_time = time_info.outputBufferDacTime + + self.total_samples_played += frames + except Empty: + outdata.fill(0) + + def _seek_worker(self): + while True: + with self._seek_lock: + time_ms = self._seek_request_ms + self._seek_request_ms = -1 + + timeline, clips, proj_settings = self.get_timeline_data() + w, h, fps = proj_settings['width'], proj_settings['height'], proj_settings['fps'] + + top_clip = next((c for c in sorted(clips, key=lambda x: x.track_index, reverse=True) + if c.track_type == 'video' and c.timeline_start_ms <= time_ms < (c.timeline_start_ms + c.duration_ms)), None) + + pixmap = QPixmap(w, h) + pixmap.fill(QColor("black")) + + if top_clip: + try: + if top_clip.media_type == 'image': + out, _ = (ffmpeg.input(top_clip.source_path) + .filter('scale', w, h, force_original_aspect_ratio='decrease') + .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') + .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') + .run(capture_stdout=True, quiet=True)) + else: + clip_time_sec = (time_ms - top_clip.timeline_start_ms + 
top_clip.clip_start_ms) / 1000.0 + out, _ = (ffmpeg.input(top_clip.source_path, ss=f"{clip_time_sec:.6f}") + .filter('scale', w, h, force_original_aspect_ratio='decrease') + .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') + .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') + .run(capture_stdout=True, quiet=True)) + + if out: + image = QImage(out, w, h, QImage.Format.Format_RGB888) + pixmap = QPixmap.fromImage(image) + + except ffmpeg.Error as e: + print(f"Error seeking to frame: {e.stderr.decode('utf-8') if e.stderr else str(e)}") + + self.new_frame.emit(pixmap) + + with self._seek_lock: + if self._seek_request_ms == -1: + self.seeking = False + break + + def seek_to_frame(self, time_ms): + self.stop() + self._emit_playhead_pos(time_ms, "seek_to_frame") + + with self._seek_lock: + self._seek_request_ms = time_ms + if self.seeking: + return + self.seeking = True + self._seek_thread = threading.Thread(target=self._seek_worker, daemon=True) + self._seek_thread.start() + + def _build_video_graph(self, start_ms, timeline, clips, proj_settings): + w, h, fps = proj_settings['width'], proj_settings['height'], proj_settings['fps'] + + video_clips = [c for c in clips if c.track_type == 'video' and c.timeline_end_ms > start_ms] + if not video_clips: + return None + video_clips = sorted( + [c for c in clips if c.track_type == 'video' and c.timeline_end_ms > start_ms], + key=lambda c: c.track_index, + reverse=True + ) + if not video_clips: + return None + + class VSegment: + def __init__(self, c, t_start, t_end): + self.source_path = c.source_path + self.duration_ms = t_end - t_start + self.timeline_start_ms = t_start + self.media_type = c.media_type + self.clip_start_ms = c.clip_start_ms + (t_start - c.timeline_start_ms) + + @property + def timeline_end_ms(self): + return self.timeline_start_ms + self.duration_ms + + event_points = {start_ms} + for c in video_clips: + event_points.update({c.timeline_start_ms, c.timeline_end_ms}) + + total_duration = timeline.get_total_duration() + if total_duration > start_ms: + event_points.add(total_duration) + + sorted_points = sorted([p for p in list(event_points) if p >= start_ms]) + + visible_segments = [] + for t_start, t_end in zip(sorted_points[:-1], sorted_points[1:]): + if t_end <= t_start: continue + + midpoint = t_start + 1 + top_clip = next((c for c in video_clips if c.timeline_start_ms <= midpoint < c.timeline_end_ms), None) + + if top_clip: + visible_segments.append(VSegment(top_clip, t_start, t_end)) + + top_track_clips = visible_segments + + concat_inputs = [] + last_end_time_ms = start_ms + for clip in top_track_clips: + clip_read_start_ms = clip.clip_start_ms + max(0, start_ms - clip.timeline_start_ms) + clip_play_start_ms = max(start_ms, clip.timeline_start_ms) + clip_remaining_duration_ms = clip.timeline_end_ms - clip_play_start_ms + if clip_remaining_duration_ms <= 0: + last_end_time_ms = max(last_end_time_ms, clip.timeline_end_ms) + continue + input_stream = ffmpeg.input(clip.source_path, ss=clip_read_start_ms / 1000.0, t=clip_remaining_duration_ms / 1000.0, re=None) + if clip.media_type == 'image': + segment = input_stream.video.filter('loop', loop=-1, size=1, start=0).filter('setpts', 'N/(FRAME_RATE*TB)').filter('trim', duration=clip_remaining_duration_ms / 1000.0) + else: + segment = input_stream.video + gap_start_ms = max(start_ms, last_end_time_ms) + if clip.timeline_start_ms > gap_start_ms: + gap_duration_sec = (clip.timeline_start_ms - gap_start_ms) / 1000.0 + segment = segment.filter('tpad', 
start_duration=f'{gap_duration_sec:.6f}', color='black') + concat_inputs.append(segment) + last_end_time_ms = clip.timeline_end_ms + if not concat_inputs: + return None + return ffmpeg.concat(*concat_inputs, v=1, a=0) + + def _build_audio_graph(self, start_ms, timeline, clips, proj_settings): + active_clips = [c for c in clips if c.track_type == 'audio' and c.timeline_end_ms > start_ms] + if not active_clips: + return None + + tracks = {} + for clip in active_clips: + tracks.setdefault(clip.track_index, []).append(clip) + + track_streams = [] + for track_index, track_clips in sorted(tracks.items()): + track_clips.sort(key=lambda c: c.timeline_start_ms) + + concat_inputs = [] + last_end_time_ms = start_ms + + for clip in track_clips: + gap_start_ms = max(start_ms, last_end_time_ms) + if clip.timeline_start_ms > gap_start_ms: + gap_duration_sec = (clip.timeline_start_ms - gap_start_ms) / 1000.0 + concat_inputs.append(ffmpeg.input(f'anullsrc=r={DEFAULT_SAMPLE_RATE}:cl={DEFAULT_CHANNELS}:d={gap_duration_sec}', f='lavfi')) + + clip_read_start_ms = clip.clip_start_ms + max(0, start_ms - clip.timeline_start_ms) + clip_play_start_ms = max(start_ms, clip.timeline_start_ms) + clip_remaining_duration_ms = clip.timeline_end_ms - clip_play_start_ms + + if clip_remaining_duration_ms > 0: + segment = ffmpeg.input(clip.source_path, ss=clip_read_start_ms/1000.0, t=clip_remaining_duration_ms/1000.0, re=None).audio + concat_inputs.append(segment) + + last_end_time_ms = clip.timeline_end_ms + + if concat_inputs: + track_streams.append(ffmpeg.concat(*concat_inputs, v=0, a=1)) + + if not track_streams: + return None + + return ffmpeg.filter(track_streams, 'amix', inputs=len(track_streams), duration='longest') + + def play(self, time_ms): + if self.is_playing: + self.stop() + + if self.debug: print(f"Play requested from {time_ms}ms.") + + self.stop_flag.clear() + self.playback_start_time_ms = time_ms + self.pause_time_ms = time_ms + self.is_paused = False + + self.total_samples_played = 0 + self.audio_underruns = 0 + self.last_video_pts_ms = time_ms + with self.sync_lock: + self.audio_clock_sec = 0.0 + self.audio_clock_update_time = 0.0 + + timeline, clips, proj_settings = self.get_timeline_data() + w, h, fps = proj_settings['width'], proj_settings['height'], proj_settings['fps'] + + video_buffer_size = int(fps * VIDEO_BUFFER_SECONDS) + audio_buffer_size = int((DEFAULT_SAMPLE_RATE / AUDIO_CHUNK_SAMPLES) * AUDIO_BUFFER_SECONDS) + self.video_queue = Queue(maxsize=video_buffer_size) + self.audio_queue = Queue(maxsize=audio_buffer_size) + + video_graph = self._build_video_graph(time_ms, timeline, clips, proj_settings) + if video_graph: + try: + args = (video_graph + .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=fps).compile()) + self.video_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) + except Exception as e: + print(f"Failed to start video process: {e}") + self.video_process = None + + audio_graph = self._build_audio_graph(time_ms, timeline, clips, proj_settings) + if audio_graph: + try: + args = ffmpeg.output(audio_graph, 'pipe:', format='f32le', ac=DEFAULT_CHANNELS, ar=DEFAULT_SAMPLE_RATE).compile() + self.audio_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) + except Exception as e: + print(f"Failed to start audio process: {e}") + self.audio_process = None + + if self.video_process: + self.video_reader_thread = threading.Thread(target=self._video_reader_thread, args=(self.video_process, w, h, fps), daemon=True) + 
self.video_reader_thread.start() + if self.audio_process: + self.audio_reader_thread = threading.Thread(target=self._audio_reader_thread, args=(self.audio_process,), daemon=True) + self.audio_reader_thread.start() + self.audio_stream = sd.OutputStream(samplerate=DEFAULT_SAMPLE_RATE, channels=DEFAULT_CHANNELS, callback=self._audio_callback, blocksize=AUDIO_CHUNK_SAMPLES) + self.audio_stream.start() + + self.stream_start_time_monotonic = time.monotonic() + self.is_playing = True + self.update_timer.start() + self.started.emit() + + def pause(self): + if not self.is_playing or self.is_paused: + return + if self.debug: print("Playback paused.") + self.is_paused = True + if self.audio_stream: + self.audio_stream.stop() + self.pause_time_ms = self._get_current_time_ms() + self.update_timer.stop() + self.paused.emit() + + def resume(self): + if not self.is_playing or not self.is_paused: + return + if self.debug: print("Playback resumed.") + self.is_paused = False + + time_since_start_sec = (self.pause_time_ms - self.playback_start_time_ms) / 1000.0 + self.stream_start_time_monotonic = time.monotonic() - time_since_start_sec + + if self.audio_stream: + self.audio_stream.start() + + self.update_timer.start() + self.started.emit() + + def stop(self): + was_playing = self.is_playing + if was_playing and self.debug: print("Playback stopped.") + self._cleanup_resources() + self.is_playing = False + self.is_paused = False + if was_playing: + self.stopped.emit() + + def set_volume(self, value): + self.volume = max(0.0, min(1.0, value)) + + def set_muted(self, muted): + self.is_muted = bool(muted) + + def _get_current_time_ms(self): + with self.sync_lock: + audio_clk_sec = self.audio_clock_sec + audio_clk_update_time = self.audio_clock_update_time + + if self.audio_stream and self.audio_stream.active and audio_clk_update_time > 0: + time_since_dac_start = max(0, time.monotonic() - audio_clk_update_time) + current_audio_time_sec = audio_clk_sec + time_since_dac_start + current_pos_ms = self.playback_start_time_ms + int(current_audio_time_sec * 1000.0) + else: + elapsed_sec = time.monotonic() - self.stream_start_time_monotonic + current_pos_ms = self.playback_start_time_ms + int(elapsed_sec * 1000) + + return current_pos_ms + + def _update_loop(self): + if not self.is_playing or self.is_paused: + return + + current_pos_ms = self._get_current_time_ms() + + clock_source = "SYSTEM" + with self.sync_lock: + if self.audio_stream and self.audio_stream.active and self.audio_clock_update_time > 0: + clock_source = "AUDIO" + + if abs(current_pos_ms - self.last_emitted_pos) >= 20: + source_str = f"_update_loop (clock:{clock_source})" + self._emit_playhead_pos(current_pos_ms, source_str) + + if self.video_queue: + while not self.video_queue.empty(): + try: + # Peek at the next frame + if self.video_queue.queue[0] is None: + self.stop() # End of stream + break + _, frame_pts = self.video_queue.queue[0] + + if frame_pts <= current_pos_ms: + frame_bytes, frame_pts = self.video_queue.get_nowait() + if frame_bytes is None: # Should be caught by peek, but for safety + self.stop() + break + self.last_video_pts_ms = frame_pts + _, _, proj_settings = self.get_timeline_data() + w, h = proj_settings['width'], proj_settings['height'] + img = QImage(frame_bytes, w, h, QImage.Format.Format_RGB888) + self.new_frame.emit(QPixmap.fromImage(img)) + else: + # Frame is in the future, wait for next loop + break + except Empty: + break + except IndexError: + break # Queue might be empty between check and access + + # --- Update Stats --- + 
vq_size = self.video_queue.qsize() if self.video_queue else 0 + vq_max = self.video_queue.maxsize if self.video_queue else 0 + aq_size = self.audio_queue.qsize() if self.audio_queue else 0 + aq_max = self.audio_queue.maxsize if self.audio_queue else 0 + + video_audio_sync_ms = int(self.last_video_pts_ms - current_pos_ms) + + stats_str = (f"AQ: {aq_size}/{aq_max} | VQ: {vq_size}/{vq_max} | " + f"V-A Δ: {video_audio_sync_ms}ms | Clock: {clock_source} | " + f"Underruns: {self.audio_underruns}") + self.stats_updated.emit(stats_str) + + total_duration = self.get_timeline_data()[0].get_total_duration() + if total_duration > 0 and current_pos_ms >= total_duration: + self.stop() + self._emit_playhead_pos(int(total_duration), "_update_loop.end_of_timeline") \ No newline at end of file diff --git a/plugins.py b/plugins.py index f40f6adb1..a4e0699dd 100644 --- a/plugins.py +++ b/plugins.py @@ -5,47 +5,59 @@ import shutil import git import json +import importlib from PyQt6.QtWidgets import (QDialog, QVBoxLayout, QHBoxLayout, QListWidget, QListWidgetItem, QPushButton, QLabel, QLineEdit, QMessageBox, QProgressBar, QDialogButtonBox, QWidget, QCheckBox) from PyQt6.QtCore import Qt, QObject, pyqtSignal, QThread class VideoEditorPlugin: - """ - Base class for all plugins. - Plugins should inherit from this class and be located in 'plugins/plugin_name/main.py'. - The main class in main.py must be named 'Plugin'. - """ - def __init__(self, app_instance, wgp_module=None): + def __init__(self, app_instance): self.app = app_instance - self.wgp = wgp_module self.name = "Unnamed Plugin" self.description = "No description provided." def initialize(self): - """Called once when the plugin is loaded by the PluginManager.""" pass def enable(self): - """Called when the plugin is enabled by the user (e.g., checking the box in the menu).""" pass def disable(self): - """Called when the plugin is disabled by the user.""" pass class PluginManager: - """Manages the discovery, loading, and lifecycle of plugins.""" - def __init__(self, main_app, wgp_module=None): + def __init__(self, main_app): self.app = main_app - self.wgp = wgp_module self.plugins_dir = "plugins" - self.plugins = {} # { 'plugin_name': {'instance': plugin_instance, 'enabled': False, 'module_path': path} } + self.plugins = {} if not os.path.exists(self.plugins_dir): os.makedirs(self.plugins_dir) + def run_preloads(plugins_dir="plugins"): + print("Running plugin pre-load scan...") + if not os.path.isdir(plugins_dir): + return + + for plugin_name in os.listdir(plugins_dir): + plugin_path = os.path.join(plugins_dir, plugin_name) + manifest_path = os.path.join(plugin_path, 'plugin.json') + if os.path.isfile(manifest_path): + try: + with open(manifest_path, 'r') as f: + manifest = json.load(f) + + preload_modules = manifest.get("preload_modules", []) + if preload_modules: + for module_name in preload_modules: + try: + importlib.import_module(module_name) + except ImportError as e: + print(f" - WARNING: Could not pre-load '{module_name}'. 
Error: {e}") + except Exception as e: + print(f" - WARNING: Could not read or parse manifest for plugin '{plugin_name}': {e}") + def discover_and_load_plugins(self): - """Scans the plugins directory, loads valid plugins, and calls their initialize method.""" for plugin_name in os.listdir(self.plugins_dir): plugin_path = os.path.join(self.plugins_dir, plugin_name) main_py_path = os.path.join(plugin_path, 'main.py') @@ -60,7 +72,7 @@ def discover_and_load_plugins(self): if hasattr(module, 'Plugin'): plugin_class = getattr(module, 'Plugin') - instance = plugin_class(self.app, self.wgp) + instance = plugin_class(self.app) instance.initialize() self.plugins[instance.name] = { 'instance': instance, @@ -76,17 +88,14 @@ def discover_and_load_plugins(self): traceback.print_exc() def load_enabled_plugins_from_settings(self, enabled_plugins_list): - """Enables plugins based on the loaded settings.""" for name in enabled_plugins_list: if name in self.plugins: self.enable_plugin(name) def get_enabled_plugin_names(self): - """Returns a list of names of all enabled plugins.""" return [name for name, data in self.plugins.items() if data['enabled']] def enable_plugin(self, name): - """Enables a specific plugin by name and updates the UI.""" if name in self.plugins and not self.plugins[name]['enabled']: self.plugins[name]['instance'].enable() self.plugins[name]['enabled'] = True @@ -95,7 +104,6 @@ def enable_plugin(self, name): self.app.update_plugin_ui_visibility(name, True) def disable_plugin(self, name): - """Disables a specific plugin by name and updates the UI.""" if name in self.plugins and self.plugins[name]['enabled']: self.plugins[name]['instance'].disable() self.plugins[name]['enabled'] = False @@ -104,7 +112,6 @@ def disable_plugin(self, name): self.app.update_plugin_ui_visibility(name, False) def uninstall_plugin(self, name): - """Uninstalls (deletes) a plugin by name.""" if name in self.plugins: path = self.plugins[name]['module_path'] if self.plugins[name]['enabled']: @@ -174,12 +181,10 @@ def __init__(self, plugin_manager, parent=None): self.status_label = QLabel("Ready.") layout.addWidget(self.status_label) - - # Dialog buttons + self.button_box = QDialogButtonBox(QDialogButtonBox.StandardButton.Save | QDialogButtonBox.StandardButton.Cancel) layout.addWidget(self.button_box) - # Connections self.install_btn.clicked.connect(self.install_plugin) self.button_box.accepted.connect(self.save_changes) self.button_box.rejected.connect(self.reject) diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py index 473baf2e4..2a38230dd 100644 --- a/plugins/wan2gp/main.py +++ b/plugins/wan2gp/main.py @@ -50,10 +50,11 @@ def __getattr__(self, name): sys.modules['shared.gradio.gallery'] = MockGradioModule() # --- End of Gradio Hijacking --- +wgp = None import ffmpeg from PyQt6.QtWidgets import ( - QWidget, QVBoxLayout, QHBoxLayout, QTabWidget, + QApplication, QWidget, QVBoxLayout, QHBoxLayout, QTabWidget, QPushButton, QLabel, QLineEdit, QTextEdit, QSlider, QCheckBox, QComboBox, QFileDialog, QGroupBox, QFormLayout, QTableWidget, QTableWidgetItem, QHeaderView, QProgressBar, QScrollArea, QListWidget, QListWidgetItem, @@ -213,7 +214,6 @@ class Worker(QObject): def __init__(self, plugin, state): super().__init__() self.plugin = plugin - self.wgp = plugin.wgp self.state = state self._is_running = True self._last_progress_phase = None @@ -222,7 +222,7 @@ def __init__(self, plugin, state): def run(self): def generation_target(): try: - for _ in self.wgp.process_tasks(self.state): + for _ in 
wgp.process_tasks(self.state): if self._is_running: self.output.emit() else: break except Exception as e: @@ -243,7 +243,7 @@ def generation_target(): phase_name, step = current_phase total_steps = gen.get("num_inference_steps", 1) high_level_status = gen.get("progress_status", "") - status_msg = self.wgp.merge_status_context(high_level_status, phase_name) + status_msg = wgp.merge_status_context(high_level_status, phase_name) progress_args = [(step, total_steps), status_msg] self.progress.emit(progress_args) preview_img = gen.get('preview') @@ -260,7 +260,6 @@ class WgpDesktopPluginWidget(QWidget): def __init__(self, plugin): super().__init__() self.plugin = plugin - self.wgp = plugin.wgp self.widgets = {} self.state = {} self.worker = None @@ -739,7 +738,7 @@ def _setup_adv_tab_misc(self, tabs): layout.addRow("Force FPS:", fps_combo) profile_combo = self.create_widget(QComboBox, 'override_profile') profile_combo.addItem("Default Profile", -1) - for text, val in self.wgp.memory_profile_choices: profile_combo.addItem(text.split(':')[0], val) + for text, val in wgp.memory_profile_choices: profile_combo.addItem(text.split(':')[0], val) layout.addRow("Override Memory Profile:", profile_combo) combo = self.create_widget(QComboBox, 'multi_prompts_gen_type') combo.addItem("Generate new Video per line", 0) @@ -777,7 +776,7 @@ def _create_scrollable_form_tab(self): def _create_config_combo(self, form_layout, label, key, choices, default_value): combo = QComboBox() for text, data in choices: combo.addItem(text, data) - index = combo.findData(self.wgp.server_config.get(key, default_value)) + index = combo.findData(wgp.server_config.get(key, default_value)) if index != -1: combo.setCurrentIndex(index) self.widgets[f'config_{key}'] = combo form_layout.addRow(label, combo) @@ -789,7 +788,7 @@ def _create_config_slider(self, form_layout, label, key, min_val, max_val, defau slider = QSlider(Qt.Orientation.Horizontal) slider.setRange(min_val, max_val) slider.setSingleStep(step) - slider.setValue(self.wgp.server_config.get(key, default_value)) + slider.setValue(wgp.server_config.get(key, default_value)) value_label = QLabel(str(slider.value())) value_label.setMinimumWidth(40) slider.valueChanged.connect(lambda v, lbl=value_label: lbl.setText(str(v))) @@ -801,7 +800,7 @@ def _create_config_slider(self, form_layout, label, key, min_val, max_val, defau def _create_config_checklist(self, form_layout, label, key, choices, default_value): list_widget = QListWidget() list_widget.setMinimumHeight(100) - current_values = self.wgp.server_config.get(key, default_value) + current_values = wgp.server_config.get(key, default_value) for text, data in choices: item = QListWidgetItem(text) item.setData(Qt.ItemDataRole.UserRole, data) @@ -822,8 +821,8 @@ def _create_config_textbox(self, form_layout, label, key, default_value, multi_l def _create_general_config_tab(self): tab, form = self._create_scrollable_form_tab() - _, _, dropdown_choices = self.wgp.get_sorted_dropdown(self.wgp.displayed_model_types, None, None, False) - self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, self.wgp.transformer_types) + _, _, dropdown_choices = wgp.get_sorted_dropdown(wgp.displayed_model_types, None, None, False) + self._create_config_checklist(form, "Selectable Models:", "transformer_types", dropdown_choices, wgp.transformer_types) self._create_config_combo(form, "Model Hierarchy:", "model_hierarchy_type", [("Two Levels (Family > Model)", 0), ("Three Levels (Family > Base > Finetune)", 1)], 1) 
self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1), ("Dimensions are Output Width/Height (Cropped)", 2)], 0) self._create_config_combo(form, "Attention Type:", "attention_mode", [("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"), ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto") @@ -832,7 +831,7 @@ def _create_general_config_tab(self): self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10), ("Keep last 20", 20), ("Keep last 30", 30)], 5) self._create_config_combo(form, "Display RAM/VRAM Stats:", "display_stats", [("Disabled", 0), ("Enabled", 1)], 0) self._create_config_combo(form, "Max Frames Multiplier:", "max_frames_multiplier", [(f"x{i}", i) for i in range(1, 8)], 1) - checkpoints_paths_text = "\n".join(self.wgp.server_config.get("checkpoints_paths", self.wgp.fl.default_checkpoints_paths)) + checkpoints_paths_text = "\n".join(wgp.server_config.get("checkpoints_paths", wgp.fl.default_checkpoints_paths)) checkpoints_textbox = QTextEdit() checkpoints_textbox.setPlainText(checkpoints_paths_text) checkpoints_textbox.setAcceptRichText(False) @@ -853,7 +852,7 @@ def _create_performance_config_tab(self): self._create_config_combo(form, "DepthAnything v2 Variant:", "depth_anything_v2_variant", [("Large (more precise)", "vitl"), ("Big (faster)", "vitb")], "vitl") self._create_config_combo(form, "VAE Tiling:", "vae_config", [("Auto", 0), ("Disabled", 1), ("256x256 (~8GB VRAM)", 2), ("128x128 (~6GB VRAM)", 3)], 0) self._create_config_combo(form, "Boost:", "boost", [("On", 1), ("Off", 2)], 1) - self._create_config_combo(form, "Memory Profile:", "profile", self.wgp.memory_profile_choices, self.wgp.profile_type.LowRAM_LowVRAM) + self._create_config_combo(form, "Memory Profile:", "profile", wgp.memory_profile_choices, wgp.profile_type.LowRAM_LowVRAM) self._create_config_slider(form, "Preload in VRAM (MB):", "preload_in_VRAM", 0, 40000, 0, 100) release_ram_btn = QPushButton("Force Release Models from RAM") release_ram_btn.clicked.connect(self._on_release_ram) @@ -882,7 +881,6 @@ def _create_notifications_config_tab(self): return tab def init_wgp_state(self): - wgp = self.wgp initial_model = wgp.server_config.get("last_model_type", wgp.transformer_type) dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False) @@ -903,14 +901,16 @@ def init_wgp_state(self): self._update_input_visibility() def update_model_dropdowns(self, current_model_type): - wgp = self.wgp family_mock, base_type_mock, choice_mock = wgp.generate_dropdown_model_list(current_model_type) for combo_name, mock in [('model_family', family_mock), ('model_base_type_choice', base_type_mock), ('model_choice', choice_mock)]: combo = self.widgets[combo_name] combo.blockSignals(True) combo.clear() if mock.choices: - for display_name, internal_key in mock.choices: combo.addItem(display_name, internal_key) + for item in mock.choices: + if isinstance(item, (list, tuple)) and len(item) >= 2: + display_name, internal_key = item[0], item[1] + combo.addItem(display_name, internal_key) index = combo.findData(mock.value) if index != -1: combo.setCurrentIndex(index) @@ -925,7 +925,6 @@ def update_model_dropdowns(self, current_model_type): def refresh_ui_from_model_change(self, model_type): """Update UI 
controls with default settings when the model is changed.""" - wgp = self.wgp self.header_info.setText(wgp.generate_header(model_type, wgp.compile, wgp.attention_mode)) ui_defaults = wgp.get_default_settings(model_type) wgp.set_model_settings(self.state, model_type, ui_defaults) @@ -1286,7 +1285,7 @@ def _on_preview_toggled(self, checked): def _on_family_changed(self): family = self.widgets['model_family'].currentData() if not family or not self.state: return - base_type_mock, choice_mock = self.wgp.change_model_family(self.state, family) + base_type_mock, choice_mock = wgp.change_model_family(self.state, family) if hasattr(base_type_mock, 'kwargs') and isinstance(base_type_mock.kwargs, dict): is_visible_base = base_type_mock.kwargs.get('visible', True) @@ -1324,7 +1323,7 @@ def _on_base_type_changed(self): family = self.widgets['model_family'].currentData() base_type = self.widgets['model_base_type_choice'].currentData() if not family or not base_type or not self.state: return - base_type_mock, choice_mock = self.wgp.change_model_base_types(self.state, family, base_type) + base_type_mock, choice_mock = wgp.change_model_base_types(self.state, family, base_type) if hasattr(choice_mock, 'kwargs') and isinstance(choice_mock.kwargs, dict): is_visible_choice = choice_mock.kwargs.get('visible', True) @@ -1345,18 +1344,18 @@ def _on_base_type_changed(self): def _on_model_changed(self): model_type = self.widgets['model_choice'].currentData() if not model_type or model_type == self.state.get('model_type'): return - self.wgp.change_model(self.state, model_type) + wgp.change_model(self.state, model_type) self.refresh_ui_from_model_change(model_type) def _on_resolution_group_changed(self): selected_group = self.widgets['resolution_group'].currentText() if not selected_group or not hasattr(self, 'full_resolution_choices'): return model_type = self.state['model_type'] - model_def = self.wgp.get_model_def(model_type) + model_def = wgp.get_model_def(model_type) model_resolutions = model_def.get("resolutions", None) group_resolution_choices = [] if model_resolutions is None: - group_resolution_choices = [res for res in self.full_resolution_choices if self.wgp.categorize_resolution(res[1]) == selected_group] + group_resolution_choices = [res for res in self.full_resolution_choices if wgp.categorize_resolution(res[1]) == selected_group] else: return last_resolution = self.state.get("last_resolution_per_group", {}).get(selected_group, "") if not any(last_resolution == res[1] for res in group_resolution_choices) and group_resolution_choices: @@ -1398,7 +1397,7 @@ def set_resolution_from_target(self, target_w, target_h): best_res_value = res_value if best_res_value: - best_group = self.wgp.categorize_resolution(best_res_value) + best_group = wgp.categorize_resolution(best_res_value) group_combo = self.widgets['resolution_group'] group_combo.blockSignals(True) @@ -1417,7 +1416,7 @@ def set_resolution_from_target(self, target_w, target_h): print(f"Warning: Could not find resolution '{best_res_value}' in dropdown after group change.") def collect_inputs(self): - full_inputs = self.wgp.get_current_model_settings(self.state).copy() + full_inputs = wgp.get_current_model_settings(self.state).copy() full_inputs['lset_name'] = "" full_inputs['image_mode'] = 0 expected_keys = { "audio_guide": None, "audio_guide2": None, "image_guide": None, "image_mask": None, "speakers_locations": "", "frames_positions": "", "keep_frames_video_guide": "", "keep_frames_video_source": "", "video_guide_outpainting": "", 
"switch_threshold2": 0, "model_switch_phase": 1, "batch_size": 1, "control_net_weight_alt": 1.0, "image_refs_relative_size": 50, } @@ -1512,9 +1511,9 @@ def _add_task_to_queue(self): all_inputs = self.collect_inputs() for key in ['type', 'settings_version', 'is_image', 'video_quality', 'image_quality', 'base_model_type']: all_inputs.pop(key, None) all_inputs['state'] = self.state - self.wgp.set_model_settings(self.state, self.state['model_type'], all_inputs) + wgp.set_model_settings(self.state, self.state['model_type'], all_inputs) self.state["validate_success"] = 1 - self.wgp.process_prompt_and_add_tasks(self.state, self.state['model_type']) + wgp.process_prompt_and_add_tasks(self.state, self.state['model_type']) self.update_queue_table() def start_generation(self): @@ -1582,11 +1581,11 @@ def add_result_item(self, video_path): self.results_list.setItemWidget(list_item, item_widget) def update_queue_table(self): - with self.wgp.lock: + with wgp.lock: queue = self.state.get('gen', {}).get('queue', []) is_running = self.thread and self.thread.isRunning() queue_to_display = queue if is_running else [None] + queue - table_data = self.wgp.get_queue_table(queue_to_display) + table_data = wgp.get_queue_table(queue_to_display) self.queue_table.setRowCount(0) self.queue_table.setRowCount(len(table_data)) self.queue_table.setColumnCount(4) @@ -1601,7 +1600,7 @@ def update_queue_table(self): def _on_remove_selected_from_queue(self): selected_row = self.queue_table.currentRow() if selected_row < 0: return - with self.wgp.lock: + with wgp.lock: is_running = self.thread and self.thread.isRunning() offset = 1 if is_running else 0 queue = self.state.get('gen', {}).get('queue', []) @@ -1609,7 +1608,7 @@ def _on_remove_selected_from_queue(self): self.update_queue_table() def _on_queue_rows_moved(self, source_row, dest_row): - with self.wgp.lock: + with wgp.lock: queue = self.state.get('gen', {}).get('queue', []) is_running = self.thread and self.thread.isRunning() offset = 1 if is_running else 0 @@ -1620,17 +1619,17 @@ def _on_queue_rows_moved(self, source_row, dest_row): self.update_queue_table() def _on_clear_queue(self): - self.wgp.clear_queue_action(self.state) + wgp.clear_queue_action(self.state) self.update_queue_table() def _on_abort(self): if self.worker: - self.wgp.abort_generation(self.state) + wgp.abort_generation(self.state) self.status_label.setText("Aborting...") self.worker._is_running = False def _on_release_ram(self): - self.wgp.release_RAM() + wgp.release_RAM() QMessageBox.information(self, "RAM Released", "Models stored in RAM have been released.") def _on_apply_config_changes(self): @@ -1655,12 +1654,12 @@ def _on_apply_config_changes(self): changes['notification_sound_volume_choice'] = self.widgets['config_notification_sound_volume'].value() changes['last_resolution_choice'] = self.widgets['resolution'].currentData() try: - msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = self.wgp.apply_changes(self.state, **changes) + msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = wgp.apply_changes(self.state, **changes) self.config_status_label.setText("Changes applied successfully. 
Some settings may require a restart.") - self.header_info.setText(self.wgp.generate_header(self.state['model_type'], self.wgp.compile, self.wgp.attention_mode)) + self.header_info.setText(wgp.generate_header(self.state['model_type'], wgp.compile, wgp.attention_mode)) if family_mock.choices is not None or choice_mock.choices is not None: - self.update_model_dropdowns(self.wgp.transformer_type) - self.refresh_ui_from_model_change(self.wgp.transformer_type) + self.update_model_dropdowns(wgp.transformer_type) + self.refresh_ui_from_model_change(wgp.transformer_type) except Exception as e: self.config_status_label.setText(f"Error applying changes: {e}") import traceback; traceback.print_exc() @@ -1669,23 +1668,82 @@ class Plugin(VideoEditorPlugin): def initialize(self): self.name = "AI Generator" self.description = "Uses the integrated Wan2GP library to generate video clips." - self.client_widget = WgpDesktopPluginWidget(self) + self.client_widget = None self.dock_widget = None - self.active_region = None; self.temp_dir = None - self.insert_on_new_track = False; self.start_frame_path = None; self.end_frame_path = None + self._heavy_content_loaded = False + + self.active_region = None + self.temp_dir = None + self.insert_on_new_track = False + self.start_frame_path = None + self.end_frame_path = None def enable(self): - if not self.dock_widget: self.dock_widget = self.app.add_dock_widget(self, self.client_widget, self.name) + if not self.dock_widget: + placeholder = QLabel("Loading AI Generator...") + placeholder.setAlignment(Qt.AlignmentFlag.AlignCenter) + self.dock_widget = self.app.add_dock_widget(self, placeholder, self.name) + self.dock_widget.visibilityChanged.connect(self._on_visibility_changed) + self.app.timeline_widget.context_menu_requested.connect(self.on_timeline_context_menu) self.app.status_label.setText(f"{self.name}: Enabled.") + def _on_visibility_changed(self, visible): + if visible and not self._heavy_content_loaded: + self._load_heavy_ui() + + def _load_heavy_ui(self): + if self._heavy_content_loaded: + return True + + self.app.status_label.setText("Loading AI Generator...") + QApplication.setOverrideCursor(Qt.CursorShape.WaitCursor) + QApplication.processEvents() + try: + global wgp + if wgp is None: + import wgp as wgp_module + wgp = wgp_module + + self.client_widget = WgpDesktopPluginWidget(self) + self.dock_widget.setWidget(self.client_widget) + self._heavy_content_loaded = True + self.app.status_label.setText("AI Generator loaded.") + return True + except Exception as e: + print(f"Failed to load AI Generator plugin backend: {e}") + import traceback + traceback.print_exc() + error_label = QLabel(f"Failed to load AI Generator:\n\n{e}\n\nPlease see console for details.") + error_label.setAlignment(Qt.AlignmentFlag.AlignCenter) + error_label.setWordWrap(True) + if self.dock_widget: + self.dock_widget.setWidget(error_label) + QMessageBox.critical(self.app, "Plugin Load Error", f"Failed to load the AI Generator backend.\n\n{e}") + return False + finally: + QApplication.restoreOverrideCursor() + def disable(self): try: self.app.timeline_widget.context_menu_requested.disconnect(self.on_timeline_context_menu) except TypeError: pass + + if self.dock_widget: + try: self.dock_widget.visibilityChanged.disconnect(self._on_visibility_changed) + except TypeError: pass + self._cleanup_temp_dir() - if self.client_widget.worker: self.client_widget._on_abort() + if self.client_widget and self.client_widget.worker: + self.client_widget._on_abort() + 
self.app.status_label.setText(f"{self.name}: Disabled.") + def _ensure_ui_loaded(self): + if not self._heavy_content_loaded: + if not self._load_heavy_ui(): + return False + return True + def _cleanup_temp_dir(self): if self.temp_dir and os.path.exists(self.temp_dir): shutil.rmtree(self.temp_dir) @@ -1694,11 +1752,12 @@ def _cleanup_temp_dir(self): def _reset_state(self): self.active_region = None; self.insert_on_new_track = False self.start_frame_path = None; self.end_frame_path = None - self.client_widget.processed_files.clear() - self.client_widget.results_list.clear() - self.client_widget.widgets['image_start'].clear() - self.client_widget.widgets['image_end'].clear() - self.client_widget.widgets['video_source'].clear() + if self._heavy_content_loaded: + self.client_widget.processed_files.clear() + self.client_widget.results_list.clear() + self.client_widget.widgets['image_start'].clear() + self.client_widget.widgets['image_end'].clear() + self.client_widget.widgets['video_source'].clear() self._cleanup_temp_dir() self.app.status_label.setText(f"{self.name}: Ready.") @@ -1732,13 +1791,14 @@ def on_timeline_context_menu(self, menu, event): create_action_new_track.triggered.connect(lambda: self.setup_creator_for_region(region, on_new_track=True)) def setup_generator_for_region(self, region, on_new_track=False): + if not self._ensure_ui_loaded(): return self._reset_state() self.active_region = region self.insert_on_new_track = on_new_track model_to_set = 'i2v_2_2' - dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types - _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False) + dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types + _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False) if any(model_to_set == m[1] for m in all_models): if self.client_widget.state.get('model_type') != model_to_set: self.client_widget.update_model_dropdowns(model_to_set) @@ -1763,7 +1823,6 @@ def setup_generator_for_region(self, region, on_new_track=False): QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path) duration_ms = end_ms - start_ms - wgp = self.client_widget.wgp model_type = self.client_widget.state['model_type'] fps = wgp.get_model_fps(model_type) video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16) @@ -1792,13 +1851,14 @@ def setup_generator_for_region(self, region, on_new_track=False): self.dock_widget.raise_() def setup_generator_from_start(self, region, on_new_track=False): + if not self._ensure_ui_loaded(): return self._reset_state() self.active_region = region self.insert_on_new_track = on_new_track model_to_set = 'i2v_2_2' - dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types - _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False) + dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types + _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False) if any(model_to_set == m[1] for m in all_models): if self.client_widget.state.get('model_type') != model_to_set: self.client_widget.update_model_dropdowns(model_to_set) @@ -1820,7 +1880,6 @@ def setup_generator_from_start(self, region, on_new_track=False): QImage(start_data, w, h, QImage.Format.Format_RGB888).save(self.start_frame_path) duration_ms = 
end_ms - start_ms - wgp = self.client_widget.wgp model_type = self.client_widget.state['model_type'] fps = wgp.get_model_fps(model_type) video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16) @@ -1851,13 +1910,14 @@ def setup_generator_from_start(self, region, on_new_track=False): self.dock_widget.raise_() def setup_generator_to_end(self, region, on_new_track=False): + if not self._ensure_ui_loaded(): return self._reset_state() self.active_region = region self.insert_on_new_track = on_new_track model_to_set = 'i2v_2_2' - dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types - _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False) + dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types + _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False) if any(model_to_set == m[1] for m in all_models): if self.client_widget.state.get('model_type') != model_to_set: self.client_widget.update_model_dropdowns(model_to_set) @@ -1879,7 +1939,6 @@ def setup_generator_to_end(self, region, on_new_track=False): QImage(end_data, w, h, QImage.Format.Format_RGB888).save(self.end_frame_path) duration_ms = end_ms - start_ms - wgp = self.client_widget.wgp model_type = self.client_widget.state['model_type'] fps = wgp.get_model_fps(model_type) video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16) @@ -1887,7 +1946,7 @@ def setup_generator_to_end(self, region, on_new_track=False): widgets['video_length'].setValue(video_length_frames) - model_def = self.client_widget.wgp.get_model_def(self.client_widget.state['model_type']) + model_def = wgp.get_model_def(self.client_widget.state['model_type']) allowed_modes = model_def.get("image_prompt_types_allowed", "") if "E" not in allowed_modes: @@ -1921,13 +1980,14 @@ def setup_generator_to_end(self, region, on_new_track=False): self.dock_widget.raise_() def setup_creator_for_region(self, region, on_new_track=False): + if not self._ensure_ui_loaded(): return self._reset_state() self.active_region = region self.insert_on_new_track = on_new_track model_to_set = 't2v_2_2' - dropdown_types = self.wgp.transformer_types if len(self.wgp.transformer_types) > 0 else self.wgp.displayed_model_types - _, _, all_models = self.wgp.get_sorted_dropdown(dropdown_types, None, None, False) + dropdown_types = wgp.transformer_types if len(wgp.transformer_types) > 0 else wgp.displayed_model_types + _, _, all_models = wgp.get_sorted_dropdown(dropdown_types, None, None, False) if any(model_to_set == m[1] for m in all_models): if self.client_widget.state.get('model_type') != model_to_set: self.client_widget.update_model_dropdowns(model_to_set) @@ -1941,7 +2001,6 @@ def setup_creator_for_region(self, region, on_new_track=False): start_ms, end_ms = region duration_ms = end_ms - start_ms - wgp = self.client_widget.wgp model_type = self.client_widget.state['model_type'] fps = wgp.get_model_fps(model_type) video_length_frames = int((duration_ms / 1000.0) * fps) if fps > 0 else int((duration_ms / 1000.0) * 16) diff --git a/plugins/wan2gp/plugin.json b/plugins/wan2gp/plugin.json new file mode 100644 index 000000000..47e592c50 --- /dev/null +++ b/plugins/wan2gp/plugin.json @@ -0,0 +1,5 @@ +{ + "preload_modules": [ + "onnxruntime" + ] +} \ No newline at end of file diff --git a/videoeditor.py b/videoeditor.py index d6e744ebe..21df9e63d 100644 --- 
a/videoeditor.py +++ b/videoeditor.py @@ -1,4 +1,4 @@ -import wgp +import onnxruntime import sys import os import uuid @@ -7,19 +7,21 @@ import json import ffmpeg import copy +from plugins import PluginManager, ManagePluginsDialog from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, QPushButton, QFileDialog, QLabel, QScrollArea, QFrame, QProgressBar, QDialog, QCheckBox, QDialogButtonBox, QMenu, QSplitter, QDockWidget, QListWidget, QListWidgetItem, QMessageBox, QComboBox, - QFormLayout, QGroupBox, QLineEdit) + QFormLayout, QGroupBox, QLineEdit, QSlider) from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction, QPixmap, QImage, QDrag, QCursor, QKeyEvent) from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread, pyqtSignal, QTimer, QByteArray, QMimeData) -from plugins import PluginManager, ManagePluginsDialog from undo import UndoStack, TimelineStateChangeCommand, MoveClipsCommand +from playback import PlaybackManager +from encoding import Encoder CONTAINER_PRESETS = { 'mp4': { @@ -210,36 +212,6 @@ def get_total_duration(self): if not self.clips: return 0 return max(c.timeline_end_ms for c in self.clips) -class ExportWorker(QObject): - progress = pyqtSignal(int) - finished = pyqtSignal(str) - def __init__(self, ffmpeg_cmd, total_duration_ms): - super().__init__() - self.ffmpeg_cmd = ffmpeg_cmd - self.total_duration_ms = total_duration_ms - - def run_export(self): - try: - process = subprocess.Popen(self.ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True, encoding="utf-8") - time_pattern = re.compile(r"time=(\d{2}):(\d{2}):(\d{2})\.(\d{2})") - for line in iter(process.stdout.readline, ""): - match = time_pattern.search(line) - if match: - h, m, s, cs = [int(g) for g in match.groups()] - processed_ms = (h * 3600 + m * 60 + s) * 1000 + cs * 10 - if self.total_duration_ms > 0: - percentage = int((processed_ms / self.total_duration_ms) * 100) - self.progress.emit(percentage) - process.stdout.close() - return_code = process.wait() - if return_code == 0: - self.progress.emit(100) - self.finished.emit("Export completed successfully!") - else: - self.finished.emit(f"Export failed! 
Check console for FFmpeg errors.") - except Exception as e: - self.finished.emit(f"An exception occurred during export: {e}") - class TimelineWidget(QWidget): TIMESCALE_HEIGHT = 30 HEADER_WIDTH = 120 @@ -333,6 +305,10 @@ def set_project_fps(self, fps): def ms_to_x(self, ms): return self.HEADER_WIDTH + int(max(-500_000_000, min((ms - self.view_start_ms) * self.pixels_per_ms, 500_000_000))) def x_to_ms(self, x): return self.view_start_ms + int(float(x - self.HEADER_WIDTH) / self.pixels_per_ms) if x > self.HEADER_WIDTH and self.pixels_per_ms > 0 else self.view_start_ms + def set_playhead_pos(self, time_ms): + self.playhead_pos_ms = time_ms + self.update() + def paintEvent(self, event): painter = QPainter(self) painter.setRenderHint(QPainter.RenderHint.Antialiasing) @@ -1848,8 +1824,6 @@ def __init__(self, project_to_load=None): self.undo_stack = UndoStack() self.media_pool = [] self.media_properties = {} - self.export_thread = None - self.export_worker = None self.current_project_path = None self.last_export_path = None self.settings = {} @@ -1857,7 +1831,10 @@ def __init__(self, project_to_load=None): self.is_shutting_down = False self._load_settings() - self.plugin_manager = PluginManager(self, wgp) + self.playback_manager = PlaybackManager(self._get_playback_data) + self.encoder = Encoder() + + self.plugin_manager = PluginManager(self) self.plugin_manager.discover_and_load_plugins() # Project settings with defaults @@ -1869,16 +1846,12 @@ def __init__(self, project_to_load=None): self.scale_to_fit = True self.current_preview_pixmap = None - self.playback_timer = QTimer(self) - self.playback_process = None - self.playback_clip = None - self._setup_ui() self._connect_signals() self.plugin_manager.load_enabled_plugins_from_settings(self.settings.get("enabled_plugins", [])) self._apply_loaded_settings() - self.seek_preview(0) + self.playback_manager.seek_to_frame(0) if not self.settings_file_was_loaded: self._save_settings() if project_to_load: QTimer.singleShot(100, lambda: self._load_project_from_path(project_to_load)) @@ -1893,6 +1866,17 @@ def _get_current_timeline_state(self): self.timeline.num_video_tracks, self.timeline.num_audio_tracks ) + + def _get_playback_data(self): + return ( + self.timeline, + self.timeline.clips, + { + 'width': self.project_width, + 'height': self.project_height, + 'fps': self.project_fps + } + ) def _setup_ui(self): self.media_dock = QDockWidget("Project Media", self) @@ -1946,15 +1930,33 @@ def _setup_ui(self): controls_layout.addStretch() main_layout.addWidget(controls_widget) - status_layout = QHBoxLayout() + status_bar_widget = QWidget() + status_layout = QHBoxLayout(status_bar_widget) + status_layout.setContentsMargins(5,2,5,2) + self.status_label = QLabel("Ready. 
Create or open a project from the File menu.") + self.stats_label = QLabel("AQ: 0/0 | VQ: 0/0") + self.stats_label.setMinimumWidth(120) self.progress_bar = QProgressBar() self.progress_bar.setVisible(False) self.progress_bar.setTextVisible(True) self.progress_bar.setRange(0, 100) + + self.mute_button = QPushButton("Mute") + self.mute_button.setCheckable(True) + self.volume_slider = QSlider(Qt.Orientation.Horizontal) + self.volume_slider.setRange(0, 100) + self.volume_slider.setValue(100) + self.volume_slider.setMaximumWidth(100) + status_layout.addWidget(self.status_label, 1) + status_layout.addWidget(self.stats_label) status_layout.addWidget(self.progress_bar, 1) - main_layout.addLayout(status_layout) + status_layout.addStretch() + status_layout.addWidget(self.mute_button) + status_layout.addWidget(self.volume_slider) + + main_layout.addWidget(status_bar_widget) self.setCentralWidget(container_widget) @@ -1978,7 +1980,7 @@ def _connect_signals(self): self.timeline_widget.split_requested.connect(self.split_clip_at_playhead) self.timeline_widget.delete_clip_requested.connect(self.delete_clip) self.timeline_widget.delete_clips_requested.connect(self.delete_clips) - self.timeline_widget.playhead_moved.connect(self.seek_preview) + self.timeline_widget.playhead_moved.connect(self.playback_manager.seek_to_frame) self.timeline_widget.split_region_requested.connect(self.on_split_region) self.timeline_widget.split_all_regions_requested.connect(self.on_split_all_regions) self.timeline_widget.join_region_requested.connect(self.on_join_region) @@ -1995,7 +1997,6 @@ def _connect_signals(self): self.frame_forward_button.clicked.connect(lambda: self.step_frame(1)) self.snap_back_button.clicked.connect(lambda: self.snap_playhead(-1)) self.snap_forward_button.clicked.connect(lambda: self.snap_playhead(1)) - self.playback_timer.timeout.connect(self.advance_playback_frame) self.project_media_widget.add_media_requested.connect(self.add_media_files) self.project_media_widget.media_removed.connect(self.on_media_removed_from_pool) @@ -2004,6 +2005,19 @@ def _connect_signals(self): self.undo_stack.history_changed.connect(self.update_undo_redo_actions) self.undo_stack.timeline_changed.connect(self.on_timeline_changed_by_undo) + self.playback_manager.new_frame.connect(self._on_new_frame) + self.playback_manager.playback_pos_changed.connect(self._on_playback_pos_changed) + self.playback_manager.stopped.connect(self._on_playback_stopped) + self.playback_manager.started.connect(self._on_playback_started) + self.playback_manager.paused.connect(self._on_playback_paused) + self.playback_manager.stats_updated.connect(self.stats_label.setText) + + self.encoder.progress.connect(self.progress_bar.setValue) + self.encoder.finished.connect(self.on_export_finished) + + self.mute_button.toggled.connect(self._on_mute_toggled) + self.volume_slider.valueChanged.connect(self._on_volume_changed) + def _show_preview_context_menu(self, pos): menu = QMenu(self) scale_action = QAction("Scale to Fit", self, checkable=True) @@ -2019,7 +2033,7 @@ def _toggle_scale_to_fit(self, checked): def on_timeline_changed_by_undo(self): self.prune_empty_tracks() self.timeline_widget.update() - self.seek_preview(self.timeline_widget.playhead_pos_ms) + self.playback_manager.seek_to_frame(self.timeline_widget.playhead_pos_ms) self.status_label.setText("Operation undone/redone.") def update_undo_redo_actions(self): @@ -2191,41 +2205,6 @@ def _create_menu_bar(self): data['action'] = action self.windows_menu.addAction(action) - def 
_get_topmost_video_clip_at(self, time_ms): - """Finds the video clip on the highest track at a specific time.""" - top_clip = None - for c in self.timeline.clips: - if c.track_type == 'video' and c.timeline_start_ms <= time_ms < c.timeline_end_ms: - if top_clip is None or c.track_index > top_clip.track_index: - top_clip = c - return top_clip - - def _start_playback_stream_at(self, time_ms): - self._stop_playback_stream() - clip = self._get_topmost_video_clip_at(time_ms) - if not clip: return - - self.playback_clip = clip - clip_time_sec = (time_ms - clip.timeline_start_ms + clip.clip_start_ms) / 1000.0 - w, h = self.project_width, self.project_height - try: - args = (ffmpeg.input(self.playback_clip.source_path, ss=f"{clip_time_sec:.6f}") - .filter('scale', w, h, force_original_aspect_ratio='decrease') - .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') - .output('pipe:', format='rawvideo', pix_fmt='rgb24', r=self.project_fps).compile()) - self.playback_process = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL) - except Exception as e: - print(f"Failed to start playback stream: {e}"); self._stop_playback_stream() - - def _stop_playback_stream(self): - if self.playback_process: - if self.playback_process.poll() is None: - self.playback_process.terminate() - try: self.playback_process.wait(timeout=0.5) - except subprocess.TimeoutExpired: self.playback_process.kill(); self.playback_process.wait() - self.playback_process = None - self.playback_clip = None - def _get_media_properties(self, file_path): """Probes a file to get its media properties. Returns a dict or None.""" if file_path in self.media_properties: @@ -2307,11 +2286,16 @@ def _probe_for_drag(self, file_path): return self._get_media_properties(file_path) def get_frame_data_at_time(self, time_ms): - clip_at_time = self._get_topmost_video_clip_at(time_ms) + """Blocking frame grab for plugin compatibility.""" + _, clips, proj_settings = self._get_playback_data() + w, h = proj_settings['width'], proj_settings['height'] + + clip_at_time = next((c for c in sorted(clips, key=lambda x: x.track_index, reverse=True) + if c.track_type == 'video' and c.timeline_start_ms <= time_ms < c.timeline_end_ms), None) + if not clip_at_time: return (None, 0, 0) try: - w, h = self.project_width, self.project_height if clip_at_time.media_type == 'image': out, _ = ( ffmpeg @@ -2333,41 +2317,16 @@ def get_frame_data_at_time(self, time_ms): ) return (out, self.project_width, self.project_height) except ffmpeg.Error as e: - print(f"Error extracting frame data: {e.stderr}") + print(f"Error extracting frame data for plugin: {e.stderr}") return (None, 0, 0) - def get_frame_at_time(self, time_ms): - clip_at_time = self._get_topmost_video_clip_at(time_ms) - if not clip_at_time: return None - try: - w, h = self.project_width, self.project_height - if clip_at_time.media_type == 'image': - out, _ = ( - ffmpeg.input(clip_at_time.source_path) - .filter('scale', w, h, force_original_aspect_ratio='decrease') - .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') - .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') - .run(capture_stdout=True, quiet=True) - ) - else: - clip_time_sec = (time_ms - clip_at_time.timeline_start_ms + clip_at_time.clip_start_ms) / 1000.0 - out, _ = (ffmpeg.input(clip_at_time.source_path, ss=f"{clip_time_sec:.6f}") - .filter('scale', w, h, force_original_aspect_ratio='decrease') - .filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black') - .output('pipe:', vframes=1, format='rawvideo', pix_fmt='rgb24') - 
.run(capture_stdout=True, quiet=True)) - - image = QImage(out, self.project_width, self.project_height, QImage.Format.Format_RGB888) - return QPixmap.fromImage(image) - except ffmpeg.Error as e: print(f"Error extracting frame: {e.stderr}"); return None - - def seek_preview(self, time_ms): - self._stop_playback_stream() - self.timeline_widget.playhead_pos_ms = int(time_ms) - self.timeline_widget.update() - self.current_preview_pixmap = self.get_frame_at_time(time_ms) + def _on_new_frame(self, pixmap): + self.current_preview_pixmap = pixmap self._update_preview_display() + def _on_playback_pos_changed(self, time_ms): + self.timeline_widget.set_playhead_pos(time_ms) + def _update_preview_display(self): pixmap_to_show = self.current_preview_pixmap if not pixmap_to_show: @@ -2388,22 +2347,42 @@ def _update_preview_display(self): self.preview_scroll_area.setWidgetResizable(False) self.preview_widget.setPixmap(pixmap_to_show) self.preview_widget.adjustSize() - def toggle_playback(self): - if self.playback_timer.isActive(): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play") + if not self.timeline.clips: + return + + if self.playback_manager.is_playing: + if self.playback_manager.is_paused: + self.playback_manager.resume() + else: + self.playback_manager.pause() else: - if not self.timeline.clips: return - if self.timeline_widget.playhead_pos_ms >= self.timeline.get_total_duration(): self.timeline_widget.playhead_pos_ms = 0 - self.playback_timer.start(int(1000 / self.project_fps)); self.play_pause_button.setText("Pause") + current_pos = self.timeline_widget.playhead_pos_ms + if current_pos >= self.timeline.get_total_duration(): + current_pos = 0 + self.playback_manager.play(current_pos) + + def _on_playback_started(self): + self.play_pause_button.setText("Pause") + + def _on_playback_paused(self): + self.play_pause_button.setText("Play") + + def _on_playback_stopped(self): + self.play_pause_button.setText("Play") + + def stop_playback(self): + self.playback_manager.stop() + self.playback_manager.seek_to_frame(0) - def stop_playback(self): self.playback_timer.stop(); self._stop_playback_stream(); self.play_pause_button.setText("Play"); self.seek_preview(0) def step_frame(self, direction): if not self.timeline.clips: return - self.playback_timer.stop(); self.play_pause_button.setText("Play"); self._stop_playback_stream() + self.playback_manager.pause() frame_duration_ms = 1000.0 / self.project_fps new_time = self.timeline_widget.playhead_pos_ms + (direction * frame_duration_ms) - self.seek_preview(int(max(0, min(new_time, self.timeline.get_total_duration())))) + final_time = int(max(0, min(new_time, self.timeline.get_total_duration()))) + self.playback_manager.seek_to_frame(final_time) def snap_playhead(self, direction): if not self.timeline.clips: @@ -2423,50 +2402,20 @@ def snap_playhead(self, direction): if direction == 1: # Forward next_points = [p for p in sorted_points if p > current_time_ms + TOLERANCE_MS] if next_points: - self.seek_preview(next_points[0]) + self.playback_manager.seek_to_frame(next_points[0]) elif direction == -1: # Backward prev_points = [p for p in sorted_points if p < current_time_ms - TOLERANCE_MS] if prev_points: - self.seek_preview(prev_points[-1]) + self.playback_manager.seek_to_frame(prev_points[-1]) elif current_time_ms > 0: - self.seek_preview(0) - - def advance_playback_frame(self): - frame_duration_ms = 1000.0 / self.project_fps - new_time_ms = self.timeline_widget.playhead_pos_ms + frame_duration_ms - if new_time_ms > 
self.timeline.get_total_duration(): self.stop_playback(); return - - self.timeline_widget.playhead_pos_ms = round(new_time_ms) - self.timeline_widget.update() - - clip_at_new_time = self._get_topmost_video_clip_at(new_time_ms) - - if not clip_at_new_time: - self._stop_playback_stream() - self.current_preview_pixmap = None - self._update_preview_display() - return + self.playback_manager.seek_to_frame(0) - if clip_at_new_time.media_type == 'image': - if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: - self._stop_playback_stream() - self.playback_clip = clip_at_new_time - self.current_preview_pixmap = self.get_frame_at_time(new_time_ms) - self._update_preview_display() - return + def _on_volume_changed(self, value): + self.playback_manager.set_volume(value / 100.0) - if self.playback_clip is None or self.playback_clip.id != clip_at_new_time.id: - self._start_playback_stream_at(new_time_ms) - - if self.playback_process: - frame_size = self.project_width * self.project_height * 3 - frame_bytes = self.playback_process.stdout.read(frame_size) - if len(frame_bytes) == frame_size: - image = QImage(frame_bytes, self.project_width, self.project_height, QImage.Format.Format_RGB888) - self.current_preview_pixmap = QPixmap.fromImage(image) - self._update_preview_display() - else: - self._stop_playback_stream() + def _on_mute_toggled(self, checked): + self.playback_manager.set_muted(checked) + self.mute_button.setText("Unmute" if checked else "Mute") def _load_settings(self): self.settings_file_was_loaded = False @@ -2518,9 +2467,10 @@ def open_settings_dialog(self): self.settings.update(dialog.get_settings()); self._save_settings(); self.status_label.setText("Settings updated.") def new_project(self): + self.playback_manager.stop() self.timeline.clips.clear(); self.timeline.num_video_tracks = 1; self.timeline.num_audio_tracks = 1 self.media_pool.clear(); self.media_properties.clear(); self.project_media_widget.clear_list() - self.current_project_path = None; self.stop_playback() + self.current_project_path = None self.last_export_path = None self.project_fps = 25.0 self.project_width = 1280 @@ -2532,6 +2482,7 @@ def new_project(self): self.undo_stack.history_changed.connect(self.update_undo_redo_actions) self.update_undo_redo_actions() self.status_label.setText("New project created. 
Add media to begin.") + self.playback_manager.seek_to_frame(0) def save_project_as(self): path, _ = QFileDialog.getSaveFileName(self, "Save Project", "", "JSON Project Files (*.json)") @@ -2592,7 +2543,8 @@ def _load_project_from_path(self, path): self.current_project_path = path self.prune_empty_tracks() - self.timeline_widget.update(); self.stop_playback() + self.timeline_widget.update() + self.playback_manager.seek_to_frame(0) self.status_label.setText(f"Project '{os.path.basename(path)}' loaded.") self._add_to_recent_files(path) except Exception as e: self.status_label.setText(f"Error opening project: {e}") @@ -3045,103 +2997,29 @@ def export_video(self): self.status_label.setText("Export canceled.") return - settings = dialog.get_export_settings() - output_path = settings["output_path"] + export_settings = dialog.get_export_settings() + output_path = export_settings["output_path"] if not output_path: self.status_label.setText("Export failed: No output path specified.") return self.last_export_path = output_path - total_dur_ms = self.timeline.get_total_duration() - total_dur_sec = total_dur_ms / 1000.0 - w, h, fr_str = self.project_width, self.project_height, str(self.project_fps) - sample_rate, channel_layout = '44100', 'stereo' - - video_stream = ffmpeg.input(f'color=c=black:s={w}x{h}:r={fr_str}:d={total_dur_sec}', f='lavfi') - - all_video_clips = sorted( - [c for c in self.timeline.clips if c.track_type == 'video'], - key=lambda c: c.track_index - ) - input_nodes = {} - for clip in all_video_clips: - if clip.source_path not in input_nodes: - if clip.media_type == 'image': - input_nodes[clip.source_path] = ffmpeg.input(clip.source_path, loop=1, framerate=self.project_fps) - else: - input_nodes[clip.source_path] = ffmpeg.input(clip.source_path) - - clip_source_node = input_nodes[clip.source_path] - clip_duration_sec = clip.duration_ms / 1000.0 - clip_start_sec = clip.clip_start_ms / 1000.0 - timeline_start_sec = clip.timeline_start_ms / 1000.0 - timeline_end_sec = clip.timeline_end_ms / 1000.0 - - if clip.media_type == 'image': - segment_stream = (clip_source_node.video.trim(duration=clip_duration_sec).setpts('PTS-STARTPTS')) - else: - segment_stream = (clip_source_node.video.trim(start=clip_start_sec, duration=clip_duration_sec).setpts('PTS-STARTPTS')) - - processed_segment = (segment_stream.filter('scale', w, h, force_original_aspect_ratio='decrease').filter('pad', w, h, '(ow-iw)/2', '(oh-ih)/2', 'black')) - video_stream = ffmpeg.overlay(video_stream, processed_segment, enable=f'between(t,{timeline_start_sec},{timeline_end_sec})') - - final_video = video_stream.filter('format', pix_fmts='yuv420p').filter('fps', fps=self.project_fps) + project_settings = { + 'width': self.project_width, + 'height': self.project_height, + 'fps': self.project_fps + } - track_audio_streams = [] - for i in range(self.timeline.num_audio_tracks): - track_clips = sorted([c for c in self.timeline.clips if c.track_type == 'audio' and c.track_index == i + 1], key=lambda c: c.timeline_start_ms) - if not track_clips: continue - track_segments = [] - last_end_ms = track_clips[0].timeline_start_ms - for clip in track_clips: - if clip.source_path not in input_nodes: - input_nodes[clip.source_path] = ffmpeg.input(clip.source_path) - gap_ms = clip.timeline_start_ms - last_end_ms - if gap_ms > 10: - track_segments.append(ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={gap_ms/1000.0}', f='lavfi')) - - clip_start_sec = clip.clip_start_ms / 1000.0 - clip_duration_sec = clip.duration_ms / 1000.0 - a_seg = 
input_nodes[clip.source_path].audio.filter('atrim', start=clip_start_sec, duration=clip_duration_sec).filter('asetpts', 'PTS-STARTPTS') - track_segments.append(a_seg) - last_end_ms = clip.timeline_end_ms - track_audio_streams.append(ffmpeg.concat(*track_segments, v=0, a=1).filter('adelay', f'{int(track_clips[0].timeline_start_ms)}ms', all=True)) - - if track_audio_streams: - final_audio = ffmpeg.filter(track_audio_streams, 'amix', inputs=len(track_audio_streams), duration='longest') - else: - final_audio = ffmpeg.input(f'anullsrc=r={sample_rate}:cl={channel_layout}:d={total_dur_sec}', f='lavfi') + self.progress_bar.setVisible(True) + self.progress_bar.setValue(0) + self.status_label.setText("Exporting...") - output_args = {'vcodec': settings['vcodec'], 'acodec': settings['acodec'], 'pix_fmt': 'yuv420p'} - if settings['v_bitrate']: output_args['b:v'] = settings['v_bitrate'] - if settings['a_bitrate']: output_args['b:a'] = settings['a_bitrate'] - - try: - ffmpeg_cmd = ffmpeg.output(final_video, final_audio, output_path, **output_args).overwrite_output().compile() - self.progress_bar.setVisible(True); self.progress_bar.setValue(0); self.status_label.setText("Exporting...") - self.export_thread = QThread() - self.export_worker = ExportWorker(ffmpeg_cmd, total_dur_ms) - self.export_worker.moveToThread(self.export_thread) - self.export_thread.started.connect(self.export_worker.run_export) - self.export_worker.finished.connect(self.on_export_finished) - self.export_worker.progress.connect(self.progress_bar.setValue) - self.export_worker.finished.connect(self.export_thread.quit) - self.export_worker.finished.connect(self.export_worker.deleteLater) - self.export_thread.finished.connect(self.export_thread.deleteLater) - self.export_thread.finished.connect(self.on_thread_finished_cleanup) - self.export_thread.start() - except ffmpeg.Error as e: - self.status_label.setText(f"FFmpeg error: {e.stderr}") - print(e.stderr) + self.encoder.start_export(self.timeline, project_settings, export_settings) - def on_export_finished(self, message): + def on_export_finished(self, success, message): self.status_label.setText(message) self.progress_bar.setVisible(False) - - def on_thread_finished_cleanup(self): - self.export_thread = None - self.export_worker = None def add_dock_widget(self, plugin_instance, widget, title, area=Qt.DockWidgetArea.RightDockWidgetArea, show_on_creation=True): widget_key = f"plugin_{plugin_instance.name}_{title}".replace(' ', '_').lower() @@ -3205,14 +3083,12 @@ def closeEvent(self, event): return self.is_shutting_down = True + self.playback_manager.stop() self._save_settings() - self._stop_playback_stream() - if self.export_thread and self.export_thread.isRunning(): - self.export_thread.quit() - self.export_thread.wait() event.accept() if __name__ == '__main__': + PluginManager.run_preloads() app = QApplication(sys.argv) project_to_load_on_startup = None if len(sys.argv) > 1: From 47a64b5bd11f73ca3fdb51eb8d7256ffbb929e25 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Tue, 14 Oct 2025 19:18:06 +1100 Subject: [PATCH 148/155] remove failed attempt to make onnxruntime import plugin-based --- plugins.py | 24 ------------------------ plugins/wan2gp/plugin.json | 5 ----- videoeditor.py | 1 - 3 files changed, 30 deletions(-) delete mode 100644 plugins/wan2gp/plugin.json diff --git a/plugins.py b/plugins.py index a4e0699dd..cde026750 100644 --- a/plugins.py +++ b/plugins.py @@ -5,7 +5,6 @@ import shutil import git import json -import importlib from PyQt6.QtWidgets import (QDialog, 
QVBoxLayout, QHBoxLayout, QListWidget, QListWidgetItem, QPushButton, QLabel, QLineEdit, QMessageBox, QProgressBar, QDialogButtonBox, QWidget, QCheckBox) @@ -34,29 +33,6 @@ def __init__(self, main_app): if not os.path.exists(self.plugins_dir): os.makedirs(self.plugins_dir) - def run_preloads(plugins_dir="plugins"): - print("Running plugin pre-load scan...") - if not os.path.isdir(plugins_dir): - return - - for plugin_name in os.listdir(plugins_dir): - plugin_path = os.path.join(plugins_dir, plugin_name) - manifest_path = os.path.join(plugin_path, 'plugin.json') - if os.path.isfile(manifest_path): - try: - with open(manifest_path, 'r') as f: - manifest = json.load(f) - - preload_modules = manifest.get("preload_modules", []) - if preload_modules: - for module_name in preload_modules: - try: - importlib.import_module(module_name) - except ImportError as e: - print(f" - WARNING: Could not pre-load '{module_name}'. Error: {e}") - except Exception as e: - print(f" - WARNING: Could not read or parse manifest for plugin '{plugin_name}': {e}") - def discover_and_load_plugins(self): for plugin_name in os.listdir(self.plugins_dir): plugin_path = os.path.join(self.plugins_dir, plugin_name) diff --git a/plugins/wan2gp/plugin.json b/plugins/wan2gp/plugin.json deleted file mode 100644 index 47e592c50..000000000 --- a/plugins/wan2gp/plugin.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "preload_modules": [ - "onnxruntime" - ] -} \ No newline at end of file diff --git a/videoeditor.py b/videoeditor.py index 21df9e63d..101b2059a 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -3088,7 +3088,6 @@ def closeEvent(self, event): event.accept() if __name__ == '__main__': - PluginManager.run_preloads() app = QApplication(sys.argv) project_to_load_on_startup = None if len(sys.argv) > 1: From f151887f7f7ebe46fe4b1ad1e0a98bb4525a37e5 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Thu, 16 Oct 2025 07:19:30 +1100 Subject: [PATCH 149/155] Update README.md --- README.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 81c93be78..ca7c0161e 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Inline AI Video Editor +image# Inline AI Video Editor A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. It provides a multi-track timeline, a media preview window, and basic clip manipulation capabilities, all wrapped in a dockable user interface. The editor is designed to be extensible through a plugin system. @@ -58,9 +58,11 @@ python videoeditor.py ``` ## Screenshots -image -image -image +![sc1](https://github.com/user-attachments/assets/d624100d-5b4d-48f0-85e2-706482d4311a) +![sc2](https://github.com/user-attachments/assets/a9509acd-b423-4313-a0fc-e51479373089) +![sc3](https://github.com/user-attachments/assets/e639d3b4-c2cd-445f-9f00-72597ff70248) + + From 7029b660f743765effd8ced523219f80f636fd3e Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Thu, 16 Oct 2025 07:25:05 +1100 Subject: [PATCH 150/155] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index ca7c0161e..a37096feb 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -image# Inline AI Video Editor +# Inline AI Video Editor A simple, non-linear video editor built with Python, PyQt6, and FFmpeg. It provides a multi-track timeline, a media preview window, and basic clip manipulation capabilities, all wrapped in a dockable user interface. The editor is designed to be extensible through a plugin system. 
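The README hunks in patches 149–150 above describe the editor as "extensible through a plugin system." As orientation only — the plugin base class, file layout, constructor signature, and the `setup()` hook below are assumptions, since they are not shown in these patches — a minimal third-party plugin might register a dock panel through the `add_dock_widget(plugin_instance, widget, title, area=..., show_on_creation=True)` helper and the `status_label` attribute that do appear in the videoeditor.py diffs earlier in this series:

```python
# Hypothetical plugins/hello_plugin/main.py — a sketch, not part of this patch series.
# Only add_dock_widget(...) and app.status_label are taken from the diffs above;
# the Plugin class name, __init__(app) signature, and setup() hook are assumptions.
from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import QLabel


class Plugin:
    name = "Hello Panel"  # the host keys plugin dock widgets by this name

    def __init__(self, app):
        self.app = app  # the main editor window

    def setup(self):
        # Register a simple dock widget with the host application.
        label = QLabel("Hello from a plugin")
        self.app.add_dock_widget(self, label, "Hello",
                                 area=Qt.DockWidgetArea.RightDockWidgetArea,
                                 show_on_creation=True)
        self.app.status_label.setText(f"{self.name}: Ready.")
```

The wan2gp plugin in this series goes further: its `_ensure_ui_loaded()` guard defers building the heavy generator UI until a timeline action actually needs it.
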
From ad19ab47f5ed3ac15f076435c6592119a13e01e8 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 17 Oct 2025 12:58:57 +1100 Subject: [PATCH 151/155] use icons for multimedia controls --- icons/pause.svg | 8 +++++ icons/play.svg | 7 +++++ icons/previous_frame.svg | 7 +++++ icons/snap_to_start.svg | 8 +++++ icons/stop.svg | 5 +++ videoeditor.py | 66 ++++++++++++++++++++++++++++++++-------- 6 files changed, 89 insertions(+), 12 deletions(-) create mode 100644 icons/pause.svg create mode 100644 icons/play.svg create mode 100644 icons/previous_frame.svg create mode 100644 icons/snap_to_start.svg create mode 100644 icons/stop.svg diff --git a/icons/pause.svg b/icons/pause.svg new file mode 100644 index 000000000..df6b56823 --- /dev/null +++ b/icons/pause.svg @@ -0,0 +1,8 @@ + + Pause + + + + + + \ No newline at end of file diff --git a/icons/play.svg b/icons/play.svg new file mode 100644 index 000000000..bb3a02a8f --- /dev/null +++ b/icons/play.svg @@ -0,0 +1,7 @@ + + Play + + + + + \ No newline at end of file diff --git a/icons/previous_frame.svg b/icons/previous_frame.svg new file mode 100644 index 000000000..63bae024e --- /dev/null +++ b/icons/previous_frame.svg @@ -0,0 +1,7 @@ + + Previous frame + + + + + \ No newline at end of file diff --git a/icons/snap_to_start.svg b/icons/snap_to_start.svg new file mode 100644 index 000000000..3624c85cc --- /dev/null +++ b/icons/snap_to_start.svg @@ -0,0 +1,8 @@ + + Snap to first frame + + + + + + \ No newline at end of file diff --git a/icons/stop.svg b/icons/stop.svg new file mode 100644 index 000000000..7f51b9433 --- /dev/null +++ b/icons/stop.svg @@ -0,0 +1,5 @@ + + Stop + + + \ No newline at end of file diff --git a/videoeditor.py b/videoeditor.py index 101b2059a..4473d9cc0 100644 --- a/videoeditor.py +++ b/videoeditor.py @@ -15,7 +15,7 @@ QListWidget, QListWidgetItem, QMessageBox, QComboBox, QFormLayout, QGroupBox, QLineEdit, QSlider) from PyQt6.QtGui import (QPainter, QColor, QPen, QFont, QFontMetrics, QMouseEvent, QAction, - QPixmap, QImage, QDrag, QCursor, QKeyEvent) + QPixmap, QImage, QDrag, QCursor, QKeyEvent, QIcon, QTransform) from PyQt6.QtCore import (Qt, QPoint, QRect, QRectF, QSize, QPointF, QObject, QThread, pyqtSignal, QTimer, QByteArray, QMimeData) @@ -1886,9 +1886,8 @@ def _setup_ui(self): self.splitter = QSplitter(Qt.Orientation.Vertical) - # --- PREVIEW WIDGET SETUP (CHANGED) --- self.preview_scroll_area = QScrollArea() - self.preview_scroll_area.setWidgetResizable(False) # Important for 1:1 scaling + self.preview_scroll_area.setWidgetResizable(False) self.preview_scroll_area.setStyleSheet("background-color: black; border: 0px;") self.preview_widget = QLabel() @@ -1914,12 +1913,55 @@ def _setup_ui(self): controls_widget = QWidget() controls_layout = QHBoxLayout(controls_widget) controls_layout.setContentsMargins(0, 5, 0, 5) - self.play_pause_button = QPushButton("Play") - self.stop_button = QPushButton("Stop") - self.frame_back_button = QPushButton("<") - self.frame_forward_button = QPushButton(">") - self.snap_back_button = QPushButton("|<") - self.snap_forward_button = QPushButton(">|") + + icon_dir = "icons" + icon_size = QSize(32, 32) + button_size = QSize(40, 40) + + self.play_icon = QIcon(os.path.join(icon_dir, "play.svg")) + self.pause_icon = QIcon(os.path.join(icon_dir, "pause.svg")) + stop_icon = QIcon(os.path.join(icon_dir, "stop.svg")) + back_pixmap = QPixmap(os.path.join(icon_dir, "previous_frame.svg")) + snap_back_pixmap = QPixmap(os.path.join(icon_dir, "snap_to_start.svg")) + + transform = QTransform().rotate(180) 
+ + frame_forward_icon = QIcon(back_pixmap.transformed(transform)) + snap_forward_icon = QIcon(snap_back_pixmap.transformed(transform)) + frame_back_icon = QIcon(back_pixmap) + snap_back_icon = QIcon(snap_back_pixmap) + + self.play_pause_button = QPushButton() + self.play_pause_button.setIcon(self.play_icon) + self.play_pause_button.setToolTip("Play/Pause") + + self.stop_button = QPushButton() + self.stop_button.setIcon(stop_icon) + self.stop_button.setToolTip("Stop") + + self.frame_back_button = QPushButton() + self.frame_back_button.setIcon(frame_back_icon) + self.frame_back_button.setToolTip("Previous Frame (Left Arrow)") + + self.frame_forward_button = QPushButton() + self.frame_forward_button.setIcon(frame_forward_icon) + self.frame_forward_button.setToolTip("Next Frame (Right Arrow)") + + self.snap_back_button = QPushButton() + self.snap_back_button.setIcon(snap_back_icon) + self.snap_back_button.setToolTip("Snap to Previous Clip Edge") + + self.snap_forward_button = QPushButton() + self.snap_forward_button.setIcon(snap_forward_icon) + self.snap_forward_button.setToolTip("Snap to Next Clip Edge") + + button_list = [self.snap_back_button, self.frame_back_button, self.play_pause_button, + self.stop_button, self.frame_forward_button, self.snap_forward_button] + for btn in button_list: + btn.setIconSize(icon_size) + btn.setFixedSize(button_size) + btn.setStyleSheet("QPushButton { border: none; background-color: transparent; }") + controls_layout.addStretch() controls_layout.addWidget(self.snap_back_button) controls_layout.addWidget(self.frame_back_button) @@ -2364,13 +2406,13 @@ def toggle_playback(self): self.playback_manager.play(current_pos) def _on_playback_started(self): - self.play_pause_button.setText("Pause") + self.play_pause_button.setIcon(self.pause_icon) def _on_playback_paused(self): - self.play_pause_button.setText("Play") + self.play_pause_button.setIcon(self.play_icon) def _on_playback_stopped(self): - self.play_pause_button.setText("Play") + self.play_pause_button.setIcon(self.play_icon) def stop_playback(self): self.playback_manager.stop() From 08a7e00ce3a2336205ec03f835e25f594f09564d Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 17 Oct 2025 13:00:12 +1100 Subject: [PATCH 152/155] Update README.md --- README.md | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index a37096feb..a97c6b402 100644 --- a/README.md +++ b/README.md @@ -58,11 +58,9 @@ python videoeditor.py ``` ## Screenshots -![sc1](https://github.com/user-attachments/assets/d624100d-5b4d-48f0-85e2-706482d4311a) -![sc2](https://github.com/user-attachments/assets/a9509acd-b423-4313-a0fc-e51479373089) -![sc3](https://github.com/user-attachments/assets/e639d3b4-c2cd-445f-9f00-72597ff70248) - - +![sc1_](https://github.com/user-attachments/assets/98247de8-613d-418a-b71e-fdf2d6b547f4) +![sc2_](https://github.com/user-attachments/assets/41c3f885-2fa9-4a81-911c-68c13da1e97b) +![sc3_](https://github.com/user-attachments/assets/ec129087-03f5-43e5-a102-934efc62b001) From b5c7a740b71c344383e1003b0446150aa86671ed Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 17 Oct 2025 20:01:23 +1100 Subject: [PATCH 153/155] fix configuration --- plugins/wan2gp/main.py | 120 +++++++++++++++++++++++++++++++---------- 1 file changed, 92 insertions(+), 28 deletions(-) diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py index 2a38230dd..95d0beeb3 100644 --- a/plugins/wan2gp/main.py +++ b/plugins/wan2gp/main.py @@ -827,6 +827,10 @@ def _create_general_config_tab(self): 
self._create_config_combo(form, "Video Dimensions:", "fit_canvas", [("Dimensions are Pixels Budget", 0), ("Dimensions are Max Width/Height", 1), ("Dimensions are Output Width/Height (Cropped)", 2)], 0) self._create_config_combo(form, "Attention Type:", "attention_mode", [("Auto (Recommended)", "auto"), ("SDPA", "sdpa"), ("Flash", "flash"), ("Xformers", "xformers"), ("Sage", "sage"), ("Sage2/2++", "sage2")], "auto") self._create_config_combo(form, "Metadata Handling:", "metadata_type", [("Embed in file (Exif/Comment)", "metadata"), ("Export separate JSON", "json"), ("None", "none")], "metadata") + checkbox = QCheckBox() + checkbox.setChecked(wgp.server_config.get("embed_source_images", False)) + self.widgets['config_embed_source_images'] = checkbox + form.addRow("Embed Source Images in MP4:", checkbox) self._create_config_checklist(form, "RAM Loading Policy:", "preload_model_policy", [("Preload on App Launch", "P"), ("Preload on Model Switch", "S"), ("Unload when Queue is Done", "U")], []) self._create_config_combo(form, "Keep Previous Videos:", "clear_file_list", [("None", 0), ("Keep last video", 1), ("Keep last 5", 5), ("Keep last 10", 10), ("Keep last 20", 20), ("Keep last 30", 30)], 5) self._create_config_combo(form, "Display RAM/VRAM Stats:", "display_stats", [("Disabled", 0), ("Enabled", 1)], 0) @@ -839,6 +843,7 @@ def _create_general_config_tab(self): self.widgets['config_checkpoints_paths'] = checkpoints_textbox form.addRow("Checkpoints Paths:", checkpoints_textbox) self._create_config_combo(form, "UI Theme (requires restart):", "UI_theme", [("Blue Sky", "default"), ("Classic Gradio", "gradio")], "default") + self._create_config_combo(form, "Queue Color Scheme:", "queue_color_scheme", [("Pastel (Unique color per item)", "pastel"), ("Alternating Grey Shades", "alternating_grey")], "pastel") return tab def _create_performance_config_tab(self): @@ -870,8 +875,9 @@ def _create_outputs_config_tab(self): tab, form = self._create_scrollable_form_tab() self._create_config_combo(form, "Video Codec:", "video_output_codec", [("x265 Balanced", 'libx265_28'), ("x264 Balanced", 'libx264_8'), ("x265 High Quality", 'libx265_8'), ("x264 High Quality", 'libx264_10'), ("x264 Lossless", 'libx264_lossless')], 'libx264_8') self._create_config_combo(form, "Image Codec:", "image_output_codec", [("JPEG Q85", 'jpeg_85'), ("WEBP Q85", 'webp_85'), ("JPEG Q95", 'jpeg_95'), ("WEBP Q95", 'webp_95'), ("WEBP Lossless", 'webp_lossless'), ("PNG Lossless", 'png')], 'jpeg_95') - self._create_config_textbox(form, "Video Output Folder:", "save_path", "outputs") - self._create_config_textbox(form, "Image Output Folder:", "image_save_path", "outputs") + self._create_config_combo(form, "Audio Codec:", "audio_output_codec", [("AAC 128 kbit", 'aac_128')], 'aac_128') + self._create_config_textbox(form, "Video Output Folder:", "save_path", wgp.server_config.get("save_path", "outputs")) + self._create_config_textbox(form, "Image Output Folder:", "image_save_path", wgp.server_config.get("image_save_path", "outputs")) return tab def _create_notifications_config_tab(self): @@ -1419,7 +1425,8 @@ def collect_inputs(self): full_inputs = wgp.get_current_model_settings(self.state).copy() full_inputs['lset_name'] = "" full_inputs['image_mode'] = 0 - expected_keys = { "audio_guide": None, "audio_guide2": None, "image_guide": None, "image_mask": None, "speakers_locations": "", "frames_positions": "", "keep_frames_video_guide": "", "keep_frames_video_source": "", "video_guide_outpainting": "", "switch_threshold2": 0, "model_switch_phase": 
1, "batch_size": 1, "control_net_weight_alt": 1.0, "image_refs_relative_size": 50, } + full_inputs['mode'] = "" + expected_keys = { "audio_guide": None, "audio_guide2": None, "image_guide": None, "image_mask": None, "speakers_locations": "", "frames_positions": "", "keep_frames_video_guide": "", "keep_frames_video_source": "", "video_guide_outpainting": "", "switch_threshold2": 0, "model_switch_phase": 1, "batch_size": 1, "control_net_weight_alt": 1.0, "image_refs_relative_size": 50, "embedded_guidance_scale": None, "model_mode": None, "control_net_weight": 1.0, "control_net_weight2": 1.0, "mask_expand": 0, "remove_background_images_ref": 0, "prompt_enhancer": ""} for key, default_value in expected_keys.items(): if key not in full_inputs: full_inputs[key] = default_value full_inputs['prompt'] = self.widgets['prompt'].toPlainText() @@ -1633,33 +1640,90 @@ def _on_release_ram(self): QMessageBox.information(self, "RAM Released", "Models stored in RAM have been released.") def _on_apply_config_changes(self): - changes = {} - list_widget = self.widgets['config_transformer_types'] - changes['transformer_types_choices'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] - list_widget = self.widgets['config_preload_model_policy'] - changes['preload_model_policy_choice'] = [item.data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] - changes['model_hierarchy_type_choice'] = self.widgets['config_model_hierarchy_type'].currentData() - changes['checkpoints_paths'] = self.widgets['config_checkpoints_paths'].toPlainText() - for key in ["fit_canvas", "attention_mode", "metadata_type", "clear_file_list", "display_stats", "max_frames_multiplier", "UI_theme"]: - changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() - for key in ["transformer_quantization", "transformer_dtype_policy", "mixed_precision", "text_encoder_quantization", "vae_precision", "compile", "depth_anything_v2_variant", "vae_config", "boost", "profile"]: - changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() - changes['preload_in_VRAM_choice'] = self.widgets['config_preload_in_VRAM'].value() - for key in ["enhancer_enabled", "enhancer_mode", "mmaudio_enabled"]: - changes[f'{key}_choice'] = self.widgets[f'config_{key}'].currentData() - for key in ["video_output_codec", "image_output_codec", "save_path", "image_save_path"]: - widget = self.widgets[f'config_{key}'] - changes[f'{key}_choice'] = widget.currentData() if isinstance(widget, QComboBox) else widget.text() - changes['notification_sound_enabled_choice'] = self.widgets['config_notification_sound_enabled'].currentData() - changes['notification_sound_volume_choice'] = self.widgets['config_notification_sound_volume'].value() - changes['last_resolution_choice'] = self.widgets['resolution'].currentData() + if wgp.args.lock_config: + self.config_status_label.setText("Configuration is locked by command-line arguments.") + return + if self.thread and self.thread.isRunning(): + self.config_status_label.setText("Cannot change config while a generation is in progress.") + return + try: - msg, header_mock, family_mock, base_type_mock, choice_mock, refresh_trigger = wgp.apply_changes(self.state, **changes) - self.config_status_label.setText("Changes applied successfully. 
Some settings may require a restart.") + ui_settings = {} + list_widget = self.widgets['config_transformer_types'] + ui_settings['transformer_types'] = [list_widget.item(i).data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] + list_widget = self.widgets['config_preload_model_policy'] + ui_settings['preload_model_policy'] = [list_widget.item(i).data(Qt.ItemDataRole.UserRole) for i in range(list_widget.count()) if list_widget.item(i).checkState() == Qt.CheckState.Checked] + + ui_settings['model_hierarchy_type'] = self.widgets['config_model_hierarchy_type'].currentData() + ui_settings['fit_canvas'] = self.widgets['config_fit_canvas'].currentData() + ui_settings['attention_mode'] = self.widgets['config_attention_mode'].currentData() + ui_settings['metadata_type'] = self.widgets['config_metadata_type'].currentData() + ui_settings['clear_file_list'] = self.widgets['config_clear_file_list'].currentData() + ui_settings['display_stats'] = self.widgets['config_display_stats'].currentData() + ui_settings['max_frames_multiplier'] = self.widgets['config_max_frames_multiplier'].currentData() + ui_settings['checkpoints_paths'] = [p.strip() for p in self.widgets['config_checkpoints_paths'].toPlainText().replace("\r", "").split("\n") if p.strip()] + ui_settings['UI_theme'] = self.widgets['config_UI_theme'].currentData() + ui_settings['queue_color_scheme'] = self.widgets['config_queue_color_scheme'].currentData() + + ui_settings['transformer_quantization'] = self.widgets['config_transformer_quantization'].currentData() + ui_settings['transformer_dtype_policy'] = self.widgets['config_transformer_dtype_policy'].currentData() + ui_settings['mixed_precision'] = self.widgets['config_mixed_precision'].currentData() + ui_settings['text_encoder_quantization'] = self.widgets['config_text_encoder_quantization'].currentData() + ui_settings['vae_precision'] = self.widgets['config_vae_precision'].currentData() + ui_settings['compile'] = self.widgets['config_compile'].currentData() + ui_settings['depth_anything_v2_variant'] = self.widgets['config_depth_anything_v2_variant'].currentData() + ui_settings['vae_config'] = self.widgets['config_vae_config'].currentData() + ui_settings['boost'] = self.widgets['config_boost'].currentData() + ui_settings['profile'] = self.widgets['config_profile'].currentData() + ui_settings['preload_in_VRAM'] = self.widgets['config_preload_in_VRAM'].value() + + ui_settings['enhancer_enabled'] = self.widgets['config_enhancer_enabled'].currentData() + ui_settings['enhancer_mode'] = self.widgets['config_enhancer_mode'].currentData() + ui_settings['mmaudio_enabled'] = self.widgets['config_mmaudio_enabled'].currentData() + + ui_settings['video_output_codec'] = self.widgets['config_video_output_codec'].currentData() + ui_settings['image_output_codec'] = self.widgets['config_image_output_codec'].currentData() + ui_settings['audio_output_codec'] = self.widgets['config_audio_output_codec'].currentData() + ui_settings['embed_source_images'] = self.widgets['config_embed_source_images'].isChecked() + ui_settings['save_path'] = self.widgets['config_save_path'].text() + ui_settings['image_save_path'] = self.widgets['config_image_save_path'].text() + + ui_settings['notification_sound_enabled'] = self.widgets['config_notification_sound_enabled'].currentData() + ui_settings['notification_sound_volume'] = self.widgets['config_notification_sound_volume'].value() + + ui_settings['last_model_type'] = self.state["model_type"] + 
ui_settings['last_model_per_family'] = self.state["last_model_per_family"] + ui_settings['last_model_per_type'] = self.state["last_model_per_type"] + ui_settings['last_advanced_choice'] = self.state["advanced"] + ui_settings['last_resolution_choice'] = self.widgets['resolution'].currentData() + ui_settings['last_resolution_per_group'] = self.state["last_resolution_per_group"] + + wgp.server_config.update(ui_settings) + + wgp.fl.set_checkpoints_paths(ui_settings['checkpoints_paths']) + wgp.three_levels_hierarchy = ui_settings["model_hierarchy_type"] == 1 + wgp.attention_mode = ui_settings["attention_mode"] + wgp.default_profile = ui_settings["profile"] + wgp.compile = ui_settings["compile"] + wgp.text_encoder_quantization = ui_settings["text_encoder_quantization"] + wgp.vae_config = ui_settings["vae_config"] + wgp.boost = ui_settings["boost"] + wgp.save_path = ui_settings["save_path"] + wgp.image_save_path = ui_settings["image_save_path"] + wgp.preload_model_policy = ui_settings["preload_model_policy"] + wgp.transformer_quantization = ui_settings["transformer_quantization"] + wgp.transformer_dtype_policy = ui_settings["transformer_dtype_policy"] + wgp.transformer_types = ui_settings["transformer_types"] + wgp.reload_needed = True + + with open(wgp.server_config_filename, "w", encoding="utf-8") as writer: + json.dump(wgp.server_config, writer, indent=4) + + self.config_status_label.setText("Settings saved successfully. Restart may be required for some changes.") self.header_info.setText(wgp.generate_header(self.state['model_type'], wgp.compile, wgp.attention_mode)) - if family_mock.choices is not None or choice_mock.choices is not None: - self.update_model_dropdowns(wgp.transformer_type) - self.refresh_ui_from_model_change(wgp.transformer_type) + self.update_model_dropdowns(wgp.transformer_type) + self.refresh_ui_from_model_change(wgp.transformer_type) + except Exception as e: self.config_status_label.setText(f"Error applying changes: {e}") import traceback; traceback.print_exc() From 73eac71c0e3963b4151dba1018ac2778bbcc6e43 Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 17 Oct 2025 23:28:50 +1100 Subject: [PATCH 154/155] fix queue index problems in plugin --- plugins/wan2gp/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py index 95d0beeb3..0fe492479 100644 --- a/plugins/wan2gp/main.py +++ b/plugins/wan2gp/main.py @@ -179,7 +179,6 @@ def dropEvent(self, event: QDropEvent): source_row = self.currentRow() target_item = self.itemAt(event.position().toPoint()) dest_row = target_item.row() if target_item else self.rowCount() - if source_row < dest_row: dest_row -=1 if source_row != dest_row: self.rowsMoved.emit(source_row, dest_row) event.acceptProposedAction() else: From f9f4bd1342fe9bd3a7a902a0b0dac1d63ee2492e Mon Sep 17 00:00:00 2001 From: Chris Malone Date: Fri, 17 Oct 2025 23:50:04 +1100 Subject: [PATCH 155/155] fix queue width --- plugins/wan2gp/main.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/plugins/wan2gp/main.py b/plugins/wan2gp/main.py index 0fe492479..f91236d0c 100644 --- a/plugins/wan2gp/main.py +++ b/plugins/wan2gp/main.py @@ -563,9 +563,18 @@ def setup_generator_tab(self): results_layout.addWidget(self.results_list) right_layout.addWidget(results_group) - right_layout.addWidget(QLabel("Queue:")) + right_layout.addWidget(QLabel("Queue")) self.queue_table = self.create_widget(QueueTableWidget, 'queue_table') + self.queue_table.verticalHeader().setVisible(False) + 
self.queue_table.setColumnCount(4) + self.queue_table.setHorizontalHeaderLabels(["Qty", "Prompt", "Length", "Steps"]) + header = self.queue_table.horizontalHeader() + header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) + header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + header.setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) right_layout.addWidget(self.queue_table) + queue_btn_layout = QHBoxLayout() self.remove_queue_btn = self.create_widget(QPushButton, 'remove_queue_btn', "Remove Selected") self.clear_queue_btn = self.create_widget(QPushButton, 'clear_queue_btn', "Clear Queue") @@ -1594,14 +1603,10 @@ def update_queue_table(self): table_data = wgp.get_queue_table(queue_to_display) self.queue_table.setRowCount(0) self.queue_table.setRowCount(len(table_data)) - self.queue_table.setColumnCount(4) - self.queue_table.setHorizontalHeaderLabels(["Qty", "Prompt", "Length", "Steps"]) for row_idx, row_data in enumerate(table_data): prompt_text = str(row_data[1]).split('>')[1].split('<')[0] if '>' in str(row_data[1]) else str(row_data[1]) for col_idx, cell_data in enumerate([row_data[0], prompt_text, row_data[2], row_data[3]]): self.queue_table.setItem(row_idx, col_idx, QTableWidgetItem(str(cell_data))) - self.queue_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) - self.queue_table.resizeColumnsToContents() def _on_remove_selected_from_queue(self): selected_row = self.queue_table.currentRow()
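
Patch 155 fixes the queue width by configuring the header once at setup time: three columns are pinned to their contents and the Prompt column stretches to absorb the remaining width. As a standalone illustration of that `QHeaderView` pattern (not the plugin's actual code — the rows below are made-up sample data), a minimal PyQt6 script with the same header configuration might look like this:

```python
# Minimal sketch of the queue-table header sizing applied in patch 155.
import sys
from PyQt6.QtWidgets import QApplication, QTableWidget, QTableWidgetItem, QHeaderView

app = QApplication(sys.argv)

table = QTableWidget(2, 4)
table.setHorizontalHeaderLabels(["Qty", "Prompt", "Length", "Steps"])
table.verticalHeader().setVisible(False)

header = table.horizontalHeader()
header.setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)  # Qty
header.setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch)           # Prompt fills the rest
header.setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents)  # Length
header.setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents)  # Steps

# Dummy rows so the sizing behaviour is visible.
for row, data in enumerate([("1", "a cat surfing at sunset", "81", "30"),
                            ("2", "timelapse of a city at night", "49", "20")]):
    for col, value in enumerate(data):
        table.setItem(row, col, QTableWidgetItem(value))

table.resize(640, 160)
table.show()
sys.exit(app.exec())
```

Because the stretch mode is set once on the header, per-refresh calls such as `resizeColumnsToContents()` are no longer needed, which is why the patch also drops them from `update_queue_table`.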