diff --git a/.gitignore b/.gitignore index a84122c61..8049651d2 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,7 @@ gradio_outputs/ ckpts/ loras/ loras_i2v/ + +settings/ + +wgp_config.json diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..927c579fd --- /dev/null +++ b/Dockerfile @@ -0,0 +1,92 @@ +FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 + +# Build arg for GPU architectures - specify which CUDA compute capabilities to compile for +# Common values: +# 7.0 - Tesla V100 +# 7.5 - RTX 2060, 2070, 2080, Titan RTX +# 8.0 - A100, A800 (Ampere data center) +# 8.6 - RTX 3060, 3070, 3080, 3090 (Ampere consumer) +# 8.9 - RTX 4070, 4080, 4090 (Ada Lovelace) +# 9.0 - H100, H800 (Hopper data center) +# 12.0 - RTX 5070, 5080, 5090 (Blackwell) - Note: sm_120 architecture +# +# Examples: +# RTX 3060: --build-arg CUDA_ARCHITECTURES="8.6" +# RTX 4090: --build-arg CUDA_ARCHITECTURES="8.9" +# Multiple: --build-arg CUDA_ARCHITECTURES="8.0;8.6;8.9" +# +# Note: Including 8.9 or 9.0 may cause compilation issues on some setups +# Default includes 8.0 and 8.6 for broad Ampere compatibility +ARG CUDA_ARCHITECTURES="8.0;8.6" + +ENV DEBIAN_FRONTEND=noninteractive + +# Install system dependencies +RUN apt update && \ + apt install -y \ + python3 python3-pip git wget curl cmake ninja-build \ + libgl1 libglib2.0-0 ffmpeg && \ + apt clean + +WORKDIR /workspace + +COPY requirements.txt . + +# Upgrade pip first +RUN pip install --upgrade pip setuptools wheel + +# Install requirements if exists +RUN pip install -r requirements.txt + +# Install PyTorch with CUDA support +RUN pip install --extra-index-url https://download.pytorch.org/whl/cu124 \ + torch==2.6.0+cu124 torchvision==0.21.0+cu124 + +# Install SageAttention from git (patch GPU detection) +ENV TORCH_CUDA_ARCH_LIST="${CUDA_ARCHITECTURES}" +ENV FORCE_CUDA="1" +ENV MAX_JOBS="1" + +COPY < Control Image Process = Perform Inpainting & Area Processed = Masked Area > Upload a Control Image, then draw your mask directly on top of the image & enter a text Prompt that describes the expected change > Generate > Below the Video Gallery click 'To Control Image' > Keep on doing more changes*. -*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM during a whole gen.* +Doing more sophisticated thing Vace Image Editor works very well too: try Image Outpainting, Pose transfer, ... +For the best quality I recommend to set in *Quality Tab* the option: "*Generate a 9 Frames Long video...*" -### August 4 2025: WanGP v7.6 - Remuxed +**update 8.55**: Flux Festival +- **Inpainting Mode** also added for *Flux Kontext* +- **Flux SRPO** : new finetune with x3 better quality vs Flux Dev according to its authors. I have also created a *Flux SRPO USO* finetune which is certainly the best open source *Style Transfer* tool available +- **Flux UMO**: model specialized in combining multiple reference objects / people together. Works quite well at 768x768 -With this new version you won't have any excuse if there is no sound in your video. +Good luck with finding your way through all the Flux models names ! -*Continue Video* now works with any video that has already some sound (hint: Multitalk ). +### September 5 2025: WanGP v8.4 - Take me to Outer Space +You have probably seen these short AI generated movies created using *Nano Banana* and the *First Frame - Last Frame* feature of *Kling 2.0*. 
The idea is to generate an image, modify a part of it with Nano Banana and give these two images to Kling, which will generate the Video between these two images; now use the previous Last Frame as the new First Frame, rinse and repeat, and you get a full movie. -Also, on top of MMaudio and the various sound driven models I have added the ability to use your own soundtrack. +I have made it easier to do just that with *Qwen Edit* and *Wan*: +- **End Frames can now be combined with Continue a Video** (and not just a Start Frame) +- **Multiple End Frames can now be provided**, each End Frame will be used for a different Sliding Window -As a result you can apply a different sound source on each new video segment when doing a *Continue Video*. +You can plan in advance all your shots (one shot = one Sliding Window): I recommend using Wan 2.2 Image to Image with multiple End Frames (one for each shot / Sliding Window), and a different Text Prompt for each shot / Sliding Window (remember to enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*) -For instance: -- first video part: use Multitalk with two people speaking -- second video part: you apply your own soundtrack which will gently follow the multitalk conversation -- third video part: you use Vace effect and its corresponding control audio will be concatenated to the rest of the audio +The results can be quite impressive. However, Wan 2.1 & 2.2 Image 2 Image are restricted to a single overlap frame when using Sliding Windows, which means only one frame is reused for the motion. This may be insufficient if you are trying to connect two shots with fast movement. -To multiply the combinations I have also implemented *Continue Video* with the various image2video models. +This is where *InfiniteTalk* comes into play. Besides being one of the best models for generating animated audio driven avatars, InfiniteTalk internally uses more than one motion frame, so it is quite good at maintaining the motion between two shots. I have tweaked InfiniteTalk so that **its motion engine can be used even if no audio is provided**. +So here is how to use InfiniteTalk: enable *Sliding Windows/Text Prompts Will be used for a new Sliding Window of the same Video Generation*, and if you continue an existing Video, *Misc/Override Frames per Second* should be set to *Source Video*. Each Reference Frame provided will play the same role as the End Frame, except it won't be exactly an End Frame (it will correspond more to a middle frame; the actual End Frame will differ but will be close) -Also: -- End Frame support added for LTX Video models -- Loras can now be targetted specifically at the High noise or Low noise models with Wan 2.2, check the Loras and Finetune guides -- Flux Krea Dev support -### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2 -Here is now Wan 2.2 image2video a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go ... +You will find below a 33s movie I have created using these two methods. Quality could be much better as I haven't tuned the settings at all (I couldn't be bothered; I used 10 step generations without Loras Accelerators for most of the gens). -Please note that although it is an image2video model it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). 
Given that Wan 2.1 image2video loras don't work too well (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder. +### September 2 2025: WanGP v8.31 - At last the pain stops -I have also optimized RAM management with Wan 2.2 so that loras and modules will be loaded only once in RAM and Reserved RAM, this saves up to 5 GB of RAM which can make a difference... +- This single new feature should give you the strength to face all the potential bugs of this new release: +**Images Management (multiple additions or deletions, reordering) for Start Images / End Images / Images References.** -And this time I really removed Vace Cocktail Light which gave a blurry vision. +- Unofficial **Video to Video (Non Sparse this time) for InfinitTalk**. Use the Strength Noise slider to decide how much motion of the original window you want to keep. I have also *greatly reduced the VRAM requirements for Multitalk / Infinitalk* (especially the multispeakers version & when generating at 1080p). -### July 29 2025: WanGP v7.4: Just another release ... Wan 2.2 Preview -Wan 2.2 is here. The good news is that WanGP wont require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to leverage entirely this new model since it has twice has many parameters. +- **Experimental Sage 3 Attention support**: you will need to deserve this one, first you need a Blackwell GPU (RTX50xx) and request an access to Sage 3 Github repo, then you will have to compile Sage 3, install it and cross your fingers ... -So here is a preview version of Wan 2.2 that is without the 5B model and Wan 2.2 image to video for the moment. -However as I felt bad to deliver only half of the wares, I gave you instead .....** Wan 2.2 Vace Experimental Cocktail** ! +*update 8.31: one shouldnt talk about bugs if one doesn't want to attract bugs* -Very good surprise indeed, the loras and Vace partially work with Wan 2.2. We will need to wait for the official Vace 2.2 release since some Vace features are broken like identity preservation - -Bonus zone: Flux multi images conditions has been added, or maybe not if I broke everything as I have been distracted by Wan... - -7.4 update: I forgot to update the version number. I also removed Vace Cocktail light which didnt work well. - -### July 27 2025: WanGP v7.3 : Interlude -While waiting for Wan 2.2, you will appreciate the model selection hierarchy which is very useful to collect even more models. You will also appreciate that WanGP remembers which model you used last in each model family. - -### July 26 2025: WanGP v7.2 : Ode to Vace -I am really convinced that Vace can do everything the other models can do and in a better way especially as Vace can be combined with Multitalk. - -Here are some new Vace improvements: -- I have provided a default finetune named *Vace Cocktail* which is a model created on the fly using the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition in *defaults/vace_14B_cocktail.json* in the *finetunes/* folder to change the Cocktail composition. Cocktail contains already some Loras acccelerators so no need to add on top a Lora Accvid, Causvid or Fusionix, ... . 
The whole point of Cocktail is to be able to build you own FusioniX (which originally is a combination of 4 loras) but without the inconvenient of FusioniX. -- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video which is shame for our Vace photoshop. But there is a solution : I have added an Advanced Quality option, that tells WanGP to generate a little more than a frame (it will still keep only the first frame). It will be a little slower but you will be amazed how Vace Cocktail combined with this option will preserve identities (bye bye *Phantom*). -- As in practise I have observed one switches frequently between *Vace text2video* and *Vace text2image* I have put them in the same place they are now just one tab away, no need to reload the model. Likewise *Wan text2video* and *Wan tex2image* have been merged. -- Color fixing when using Sliding Windows. A new postprocessing *Color Correction* applied automatically by default (you can disable it in the *Advanced tab Sliding Window*) will try to match the colors of the new window with that of the previous window. It doesnt fix all the unwanted artifacts of the new window but at least this makes the transition smoother. Thanks to the multitalk team for the original code. - -Also you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ... ). Many thanks to **Redtash1** for providing the framework for this new feature ! You need to go in the Config tab to enable real time stats. - - -### July 21 2025: WanGP v7.12 -- Flux Family Reunion : *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added. - -- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video !) in one go without a sliding window. With the distilled model it will take only 5 minutes with a RTX 4090 (you will need 22 GB of VRAM though). I have added options to select higher humber frames if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App) - -- LTX Video ControlNet : it is a Control Net that allows you for instance to transfer a Human motion or Depth from a control video. It is not as powerful as Vace but can produce interesting things especially as now you can generate quickly a 1 min video. Under the scene IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them. - -- LTX IC-Lora support: these are special Loras that consumes a conditional image or video -Beside the pose, depth and canny IC-Loras transparently loaded there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as control net choice to use it. - -- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update) - -- Easier way to select video resolution - - -### July 15 2025: WanGP v7.0 is an AI Powered Photoshop -This release turns the Wan models into Image Generators. 
This goes way more than allowing to generate a video made of single frame : -- Multiple Images generated at the same time so that you can choose the one you like best.It is Highly VRAM optimized so that you can generate for instance 4 720p Images at the same time with less than 10 GB -- With the *image2image* the original text2video WanGP becomes an image upsampler / restorer -- *Vace image2image* comes out of the box with image outpainting, person / object replacement, ... -- You can use in one click a newly Image generated as Start Image or Reference Image for a Video generation - -And to complete the full suite of AI Image Generators, Ladies and Gentlemen please welcome for the first time in WanGP : **Flux Kontext**.\ -As a reminder Flux Kontext is an image editor : give it an image and a prompt and it will do the change for you.\ -This highly optimized version of Flux Kontext will make you feel that you have been cheated all this time as WanGP Flux Kontext requires only 8 GB of VRAM to generate 4 images at the same time with no need for quantization. - -WanGP v7 comes with *Image2image* vanilla and *Vace FusinoniX*. However you can build your own finetune where you will combine a text2video or Vace model with any combination of Loras. - -Also in the news: -- You can now enter the *Bbox* for each speaker in *Multitalk* to precisely locate who is speaking. And to save some headaches the *Image Mask generator* will give you the *Bbox* coordinates of an area you have selected. -- *Film Grain* post processing to add a vintage look at your video -- *First Last Frame to Video* model should work much better now as I have discovered rencently its implementation was not complete -- More power for the finetuners, you can now embed Loras directly in the finetune definition. You can also override the default models (titles, visibility, ...) with your own finetunes. Check the doc that has been updated. - - -### July 10 2025: WanGP v6.7, is NAG a game changer ? you tell me -Maybe you knew that already but most *Loras accelerators* we use today (Causvid, FusioniX) don't use *Guidance* at all (that it is *CFG* is set to 1). This helps to get much faster generations but the downside is that *Negative Prompts* are completely ignored (including the default ones set by the models). **NAG** (https://github.com/ChenDarYen/Normalized-Attention-Guidance) aims to solve that by injecting the *Negative Prompt* during the *attention* processing phase. - -So WanGP 6.7 gives you NAG, but not any NAG, a *Low VRAM* implementation, the default one ends being VRAM greedy. You will find NAG in the *General* advanced tab for most Wan models. - -Use NAG especially when Guidance is set to 1. To turn it on set the **NAG scale** to something around 10. There are other NAG parameters **NAG tau** and **NAG alpha** which I recommend to change only if you don't get good results by just playing with the NAG scale. Don't hesitate to share on this discord server the best combinations for these 3 parameters. - -The authors of NAG claim that NAG can also be used when using a Guidance (CFG > 1) and to improve the prompt adherence. - -### July 8 2025: WanGP v6.6, WanGP offers you **Vace Multitalk Dual Voices Fusionix Infinite** : -**Vace** our beloved super Control Net has been combined with **Multitalk** the new king in town that can animate up to two people speaking (**Dual Voices**). 
It is accelerated by the **Fusionix** model and thanks to *Sliding Windows* support and *Adaptive Projected Guidance* (much slower but should reduce the reddish effect with long videos) your two people will be able to talk for very a long time (which is an **Infinite** amount of time in the field of video generation). - -Of course you will get as well *Multitalk* vanilla and also *Multitalk 720p* as a bonus. - -And since I am mister nice guy I have enclosed as an exclusivity an *Audio Separator* that will save you time to isolate each voice when using Multitalk with two people. - -As I feel like resting a bit I haven't produced yet a nice sample Video to illustrate all these new capabilities. But here is the thing, I ams sure you will publish in the *Share Your Best Video* channel your *Master Pieces*. The best ones will be added to the *Announcements Channel* and will bring eternal fame to its authors. - -But wait, there is more: -- Sliding Windows support has been added anywhere with Wan models, so imagine with text2video recently upgraded in 6.5 into a video2video, you can now upsample very long videos regardless of your VRAM. The good old image2video model can now reuse the last image to produce new videos (as requested by many of you) -- I have added also the capability to transfer the audio of the original control video (Misc. advanced tab) and an option to preserve the fps into the generated video, so from now on you will be to upsample / restore your old families video and keep the audio at their original pace. Be aware that the duration will be limited to 1000 frames as I still need to add streaming support for unlimited video sizes. - -Also, of interest too: -- Extract video info from Videos that have not been generated by WanGP, even better you can also apply post processing (Upsampling / MMAudio) on non WanGP videos -- Force the generated video fps to your liking, works wery well with Vace when using a Control Video -- Ability to chain URLs of Finetune models (for instance put the URLs of a model in your main finetune and reference this finetune in other finetune models to save time) - -### July 2 2025: WanGP v6.5.1, WanGP takes care of you: lots of quality of life features: -- View directly inside WanGP the properties (seed, resolutions, length, most settings...) of the past generations -- In one click use the newly generated video as a Control Video or Source Video to be continued -- Manage multiple settings for the same model and switch between them using a dropdown box -- WanGP will keep the last generated videos in the Gallery and will remember the last model you used if you restart the app but kept the Web page open -- Custom resolutions : add a file in the WanGP folder with the list of resolutions you want to see in WanGP (look at the instruction readme in this folder) - -Taking care of your life is not enough, you want new stuff to play with ? -- MMAudio directly inside WanGP : add an audio soundtrack that matches the content of your video. By the way it is a low VRAM MMAudio and 6 GB of VRAM should be sufficient. You will need to go in the *Extensions* tab of the WanGP *Configuration* to enable MMAudio -- Forgot to upsample your video during the generation ? want to try another MMAudio variation ? Fear not you can also apply upsampling or add an MMAudio track once the video generation is done. 
Even better you can ask WangGP for multiple variations of MMAudio to pick the one you like best -- MagCache support: a new step skipping approach, supposed to be better than TeaCache. Makes a difference if you usually generate with a high number of steps -- SageAttention2++ support : not just the compatibility but also a slightly reduced VRAM usage -- Video2Video in Wan Text2Video : this is the paradox, a text2video can become a video2video if you start the denoising process later on an existing video -- FusioniX upsampler: this is an illustration of Video2Video in Text2Video. Use the FusioniX text2video model with an output resolution of 1080p and a denoising strength of 0.25 and you will get one of the best upsamplers (in only 2/3 steps, you will need lots of VRAM though). Increase the denoising strength and you will get one of the best Video Restorer -- Choice of Wan Samplers / Schedulers -- More Lora formats support - -**If you had upgraded to v6.5 please upgrade again to 6.5.1 as this will fix a bug that ignored Loras beyond the first one** See full changelog: **[Changelog](docs/CHANGELOG.md)** @@ -248,7 +135,9 @@ See full changelog: **[Changelog](docs/CHANGELOG.md)** ## 🚀 Quick Start -**One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/) +**One-click installation:** +- Get started instantly with [Pinokio App](https://pinokio.computer/) +- Use Redtash1 [One Click Install with Sage](https://github.com/Redtash1/Wan2GP-Windows-One-Click-Install-With-Sage) **Manual installation:** ```bash @@ -262,8 +151,7 @@ pip install -r requirements.txt **Run the application:** ```bash -python wgp.py # Text-to-video (default) -python wgp.py --i2v # Image-to-video +python wgp.py ``` **Update the application:** @@ -274,6 +162,32 @@ git pull pip install -r requirements.txt ``` +## 🐳 Docker: + +**For Debian-based systems (Ubuntu, Debian, etc.):** + +```bash +./run-docker-cuda-deb.sh +``` + +This automated script will: + +- Detect your GPU model and VRAM automatically +- Select optimal CUDA architecture for your GPU +- Install NVIDIA Docker runtime if needed +- Build a Docker image with all dependencies +- Run WanGP with optimal settings for your hardware + +**Docker environment includes:** + +- NVIDIA CUDA 12.4.1 with cuDNN support +- PyTorch 2.6.0 with CUDA 12.4 support +- SageAttention compiled for your specific GPU architecture +- Optimized environment variables for performance (TF32, threading, etc.) +- Automatic cache directory mounting for faster subsequent runs +- Current directory mounted in container - all downloaded models, loras, generated videos and files are saved locally + +**Supported GPUs:** RTX 40XX, RTX 30XX, RTX 20XX, GTX 16XX, GTX 10XX, Tesla V100, A100, H100, and more. ## 📦 Installation @@ -317,4 +231,4 @@ https://www.youtube.com/watch?v=T5jNiEhf9xk

Made with ❤️ by DeepBeepMeep -

+

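If you prefer to drive Docker by hand instead of the `run-docker-cuda-deb.sh` helper, a minimal sketch is shown below. It assumes the Dockerfile above, picks a single compute capability (8.9 for an RTX 4090) via the documented `CUDA_ARCHITECTURES` build arg, and mounts the current checkout into the container's `/workspace`. The image tag, the published port (7860, Gradio's usual default) and the reliance on the image's default entrypoint to launch `wgp.py` are assumptions and may differ from what the helper script actually does.

```bash
# Build for one GPU architecture only (see the table in the Dockerfile comments)
docker build --build-arg CUDA_ARCHITECTURES="8.9" -t wan2gp .

# Run with GPU access; mount the repo so models, loras and generated files persist on the host
docker run --gpus all -it --rm \
  -v "$(pwd)":/workspace \
  -p 7860:7860 \
  wan2gp
```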
diff --git a/configs/animate.json b/configs/animate.json new file mode 100644 index 000000000..7e98ca95e --- /dev/null +++ b/configs/animate.json @@ -0,0 +1,15 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 36, + "model_type": "i2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "motion_encoder_dim": 512 +} \ No newline at end of file diff --git a/configs/lucy_edit.json b/configs/lucy_edit.json new file mode 100644 index 000000000..4983cedd1 --- /dev/null +++ b/configs/lucy_edit.json @@ -0,0 +1,14 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.33.0", + "dim": 3072, + "eps": 1e-06, + "ffn_dim": 14336, + "freq_dim": 256, + "in_dim": 96, + "model_type": "ti2v2_2", + "num_heads": 24, + "num_layers": 30, + "out_dim": 48, + "text_len": 512 +} diff --git a/configs/lynx.json b/configs/lynx.json new file mode 100644 index 000000000..a973d5e68 --- /dev/null +++ b/configs/lynx.json @@ -0,0 +1,15 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 16, + "model_type": "t2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "lynx": "full" +} diff --git a/configs/vace_lynx_14B.json b/configs/vace_lynx_14B.json new file mode 100644 index 000000000..770cbeef6 --- /dev/null +++ b/configs/vace_lynx_14B.json @@ -0,0 +1,17 @@ +{ + "_class_name": "WanModel", + "_diffusers_version": "0.30.0", + "dim": 5120, + "eps": 1e-06, + "ffn_dim": 13824, + "freq_dim": 256, + "in_dim": 16, + "model_type": "t2v", + "num_heads": 40, + "num_layers": 40, + "out_dim": 16, + "text_len": 512, + "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35], + "vace_in_dim": 96, + "lynx": "full" +} diff --git a/defaults/animate.json b/defaults/animate.json new file mode 100644 index 000000000..bf45f4a92 --- /dev/null +++ b/defaults/animate.json @@ -0,0 +1,17 @@ +{ + "model": { + "name": "Wan2.2 Animate", + "architecture": "animate", + "description": "Wan-Animate takes a video and a character image as input, and generates a video in either 'Animation' or 'Replacement' mode. 
Sliding Window of 81 frames at least are recommeded to obtain the best Style continuity.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_fp16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_14B_quanto_bf16_int8.safetensors" + ], + "preload_URLs" : + [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_animate_relighting_lora.safetensors" + ], + "group": "wan2_2" + } +} \ No newline at end of file diff --git a/defaults/flux_dev_kontext.json b/defaults/flux_dev_kontext.json index 894591880..20b6bc4d3 100644 --- a/defaults/flux_dev_kontext.json +++ b/defaults/flux_dev_kontext.json @@ -7,8 +7,6 @@ "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1_kontext_dev_quanto_bf16_int8.safetensors" ], - "image_outputs": true, - "reference_image": true, "flux-model": "flux-dev-kontext" }, "prompt": "add a hat", diff --git a/defaults/flux_dev_umo.json b/defaults/flux_dev_umo.json new file mode 100644 index 000000000..57164bb4c --- /dev/null +++ b/defaults/flux_dev_umo.json @@ -0,0 +1,24 @@ +{ + "model": { + "name": "Flux 1 Dev UMO 12B", + "architecture": "flux", + "description": "FLUX.1 Dev UMO is a model that can Edit Images with a specialization in combining multiple image references (resized internally at 512x512 max) to produce an Image output. Best Image preservation at 768x768 Resolution Output.", + "URLs": "flux", + "flux-model": "flux-dev-umo", + "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-UMO_dit_lora_bf16.safetensors"], + "resolutions": [ ["1024x1024 (1:1)", "1024x1024"], + ["768x1024 (3:4)", "768x1024"], + ["1024x768 (4:3)", "1024x768"], + ["512x1024 (1:2)", "512x1024"], + ["1024x512 (2:1)", "1024x512"], + ["768x768 (1:1)", "768x768"], + ["768x512 (3:2)", "768x512"], + ["512x768 (2:3)", "512x768"]] + }, + "prompt": "the man is wearing a hat", + "embedded_guidance_scale": 4, + "resolution": "768x768", + "batch_size": 1 +} + + \ No newline at end of file diff --git a/defaults/flux_dev_uso.json b/defaults/flux_dev_uso.json index ab5ac5459..806dd7e02 100644 --- a/defaults/flux_dev_uso.json +++ b/defaults/flux_dev_uso.json @@ -2,15 +2,13 @@ "model": { "name": "Flux 1 Dev USO 12B", "architecture": "flux", - "description": "FLUX.1 Dev USO is a model specialized to Edit Images with a specialization in Style Transfers (up to two).", + "description": "FLUX.1 Dev USO is a model that can Edit Images with a specialization in Style Transfers (up to two).", "modules": [ ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_projector_bf16.safetensors"]], "URLs": "flux", "loras": ["https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-dev-USO_dit_lora_bf16.safetensors"], - "image_outputs": true, - "reference_image": true, "flux-model": "flux-dev-uso" }, - "prompt": "add a hat", + "prompt": "the man is wearing a hat", "embedded_guidance_scale": 4, "resolution": "1024x1024", "batch_size": 1 diff --git a/defaults/flux_srpo.json b/defaults/flux_srpo.json new file mode 100644 index 000000000..59f07c617 --- /dev/null +++ b/defaults/flux_srpo.json @@ -0,0 +1,15 @@ +{ + "model": { + "name": "Flux 1 SRPO Dev 12B", + "architecture": "flux", + "description": "By fine-tuning the FLUX.1.dev model with optimized denoising and online reward adjustment, SRPO improves its 
human-evaluated realism and aesthetic quality by over 3x.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Flux/resolve/main/flux1-srpo-dev_quanto_bf16_int8.safetensors" + ], + "flux-model": "flux-dev" + }, + "prompt": "draw a hat", + "resolution": "1024x1024", + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/flux_srpo_uso.json b/defaults/flux_srpo_uso.json new file mode 100644 index 000000000..ddfe50d63 --- /dev/null +++ b/defaults/flux_srpo_uso.json @@ -0,0 +1,17 @@ +{ + "model": { + "name": "Flux 1 SRPO USO 12B", + "architecture": "flux", + "description": "FLUX.1 SRPO USO is a model that can Edit Images with a specialization in Style Transfers (up to two). It leverages the improved Image quality brought by the SRPO process", + "modules": [ "flux_dev_uso"], + "URLs": "flux_srpo", + "loras": "flux_dev_uso", + "flux-model": "flux-dev-uso" + }, + "prompt": "the man is wearing a hat", + "embedded_guidance_scale": 4, + "resolution": "1024x1024", + "batch_size": 1 +} + + \ No newline at end of file diff --git a/defaults/lucy_edit.json b/defaults/lucy_edit.json new file mode 100644 index 000000000..a8f67ad64 --- /dev/null +++ b/defaults/lucy_edit.json @@ -0,0 +1,19 @@ +{ + "model": { + "name": "Wan2.2 Lucy Edit 5B", + "architecture": "lucy_edit", + "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/wan2.2_lucy_edit_quanto_mfp16_int8.safetensors" + ], + "group": "wan2_2" + }, + "prompt": "change the clothes to red", + "video_length": 81, + "guidance_scale": 5, + "flow_shift": 5, + "num_inference_steps": 30, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/defaults/lucy_edit_fastwan.json b/defaults/lucy_edit_fastwan.json new file mode 100644 index 000000000..d5d47c814 --- /dev/null +++ b/defaults/lucy_edit_fastwan.json @@ -0,0 +1,16 @@ +{ + "model": { + "name": "Wan2.2 FastWan Lucy Edit 5B", + "architecture": "lucy_edit", + "description": "Lucy Edit is a video editing model that performs instruction-guided edits on videos using free-text prompts. It supports a variety of edits, such as clothing & accessory changes, character changes, object insertions, and scene replacements while preserving the motion and composition perfectly. 
This is the FastWan version for faster generation.", + "URLs": "lucy_edit", + "group": "wan2_2", + "loras": "ti2v_2_2_fastwan" + }, + "prompt": "change the clothes to red", + "video_length": 81, + "guidance_scale": 1, + "flow_shift": 3, + "num_inference_steps": 5, + "resolution": "1280x720" +} \ No newline at end of file diff --git a/defaults/lynx.json b/defaults/lynx.json new file mode 100644 index 000000000..528f5ef68 --- /dev/null +++ b/defaults/lynx.json @@ -0,0 +1,18 @@ +{ + "model": { + "name": "Wan2.1 Lynx 14B", + "modules": [ + [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_lynx_full_module_14B_bf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_lynx_full_module_14B_quanto_bf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_lynx_full_module_14B_quanto_fp16_int8.safetensors" + ] + ], + "architecture": "lynx", + "description": "The Lynx ControlNet offers State of the Art Identity Preservation. You need to provide a Reference Image which is a close up of a person face to transfer this person in the Video.", + "URLs": "t2v", + "preload_URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/wan2.1_lynx_full_arc_resampler.safetensors" + ] + } +} \ No newline at end of file diff --git a/defaults/qwen_image_edit_20B.json b/defaults/qwen_image_edit_20B.json index 2b24c72da..04fc573fb 100644 --- a/defaults/qwen_image_edit_20B.json +++ b/defaults/qwen_image_edit_20B.json @@ -7,11 +7,10 @@ "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_bf16.safetensors", "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_20B_quanto_bf16_int8.safetensors" ], + "preload_URLs": ["https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_inpainting.safetensors"], "attention": { "<89": "sdpa" - }, - "reference_image": true, - "image_outputs": true + } }, "prompt": "add a hat", "resolution": "1280x720", diff --git a/defaults/qwen_image_edit_plus_20B.json b/defaults/qwen_image_edit_plus_20B.json new file mode 100644 index 000000000..e10deb24b --- /dev/null +++ b/defaults/qwen_image_edit_plus_20B.json @@ -0,0 +1,17 @@ +{ + "model": { + "name": "Qwen Image Edit Plus 20B", + "architecture": "qwen_image_edit_plus_20B", + "description": "Qwen Image Edit Plus is a generative model that can generate very high quality images with long texts in it. Best results will be at 720p. This model is optimized to combine multiple Subjects & Objects.", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Qwen_image/resolve/main/qwen_image_edit_plus_20B_quanto_bf16_int8.safetensors" + ], + "preload_URLs": "qwen_image_edit_20B", + "attention": { + "<89": "sdpa" + } + }, + "prompt": "add a hat", + "resolution": "1024x1024", + "batch_size": 1 +} \ No newline at end of file diff --git a/defaults/standin.json b/defaults/standin.json index 1b5e32479..09298e97a 100644 --- a/defaults/standin.json +++ b/defaults/standin.json @@ -4,7 +4,7 @@ "name": "Wan2.1 Standin 14B", "modules": [ ["https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/Stand-In_wan2.1_T2V_14B_ver1.0_bf16.safetensors"]], "architecture" : "standin", - "description": "The original Wan Text 2 Video model combined with the StandIn module to improve Identity Preservation. 
You need to provide a Reference Image with white background which is a close up of person face to transfer this person in the Video.", + "description": "The original Wan Text 2 Video model combined with the StandIn module to improve Identity Preservation. You need to provide a Reference Image with white background which is a close up of a person face to transfer this person in the Video.", "URLs": "t2v" } } \ No newline at end of file diff --git a/defaults/ti2v_2_2_fastwan.json b/defaults/ti2v_2_2_fastwan.json index 064c2b4ba..fa69f82ec 100644 --- a/defaults/ti2v_2_2_fastwan.json +++ b/defaults/ti2v_2_2_fastwan.json @@ -7,6 +7,7 @@ "loras": ["https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/loras_accelerators/Wan2_2_5B_FastWanFullAttn_lora_rank_128_bf16.safetensors"], "group": "wan2_2" }, + "prompt" : "Put the person into a clown outfit.", "video_length": 121, "guidance_scale": 1, "flow_shift": 3, diff --git a/defaults/vace_fun_14B_2_2.json b/defaults/vace_fun_14B_2_2.json new file mode 100644 index 000000000..8f22d3456 --- /dev/null +++ b/defaults/vace_fun_14B_2_2.json @@ -0,0 +1,24 @@ +{ + "model": { + "name": "Wan2.2 Vace Fun 14B", + "architecture": "vace_14B", + "description": "This is the Fun Vace 2.2 version, that is not the official Vace 2.2", + "URLs": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_HIGH_quanto_mfp16_int8.safetensors" + ], + "URLs2": [ + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_mbf16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_quanto_mbf16_int8.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.2/resolve/main/Wan2_2_Fun_VACE_A14B_LOW_quanto_mfp16_int8.safetensors" + ], + "group": "wan2_2" + }, + "guidance_phases": 2, + "num_inference_steps": 30, + "guidance_scale": 1, + "guidance2_scale": 1, + "flow_shift": 2, + "switch_threshold": 875 +} \ No newline at end of file diff --git a/defaults/vace_fun_14B_cocktail_2_2.json b/defaults/vace_fun_14B_cocktail_2_2.json new file mode 100644 index 000000000..c587abdd4 --- /dev/null +++ b/defaults/vace_fun_14B_cocktail_2_2.json @@ -0,0 +1,28 @@ +{ + "model": { + "name": "Wan2.2 Vace Fun Cocktail 14B", + "architecture": "vace_14B", + "description": "This model has been created on the fly using the Wan text 2.2 video model and the Loras of FusioniX. The weight of the Detail Enhancer Lora has been reduced to improve identity preservation. 
This is the Fun Vace 2.2, that is not the official Vace 2.2", + "URLs": "vace_fun_14B_2_2", + "URLs2": "vace_fun_14B_2_2", + "loras": [ + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_CausVid_14B_T2V_lora_rank32_v2.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/DetailEnhancerV1.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors", + "https://huggingface.co/DeepBeepMeep/Wan2.1/resolve/main/loras_accelerators/Wan21_T2V_14B_MoviiGen_lora_rank32_fp16.safetensors" + ], + "loras_multipliers": [ + 1, + 0.2, + 0.5, + 0.5 + ], + "group": "wan2_2" + }, + "guidance_phases": 2, + "num_inference_steps": 10, + "guidance_scale": 1, + "guidance2_scale": 1, + "flow_shift": 2, + "switch_threshold": 875 +} \ No newline at end of file diff --git a/defaults/vace_lynx_14B.json b/defaults/vace_lynx_14B.json new file mode 100644 index 000000000..523d5b3fc --- /dev/null +++ b/defaults/vace_lynx_14B.json @@ -0,0 +1,10 @@ +{ + "model": { + "name": "Vace Lynx 14B", + "architecture": "vace_lynx_14B", + "modules": [ "vace_14B", "lynx"], + "description": "The Vace ControlNet model is a powerful model that allows you to control the content of the generated video based of additional custom data : pose or depth video, images or objects you want to see in the video. The Lynx version is specialized in transferring a Person Identity.", + "URLs": "t2v", + "preload_URLs": "lynx" + } +} \ No newline at end of file diff --git a/docs/AMD-INSTALLATION.md b/docs/AMD-INSTALLATION.md new file mode 100644 index 000000000..4f05589eb --- /dev/null +++ b/docs/AMD-INSTALLATION.md @@ -0,0 +1,146 @@ +# Installation Guide + +This guide covers installation for specific RDNA3 and RDNA3.5 AMD CPUs (APUs) and GPUs +running under Windows. + +tl;dr: Radeon RX 7900 GOOD, RX 9700 BAD, RX 6800 BAD. (I know, life isn't fair). + +Currently supported (but not necessary tested): + +**gfx110x**: + +* Radeon RX 7600 +* Radeon RX 7700 XT +* Radeon RX 7800 XT +* Radeon RX 7900 GRE +* Radeon RX 7900 XT +* Radeon RX 7900 XTX + +**gfx1151**: + +* Ryzen 7000 series APUs (Phoenix) +* Ryzen Z1 (e.g., handheld devices like the ROG Ally) + +**gfx1201**: + +* Ryzen 8000 series APUs (Strix Point) +* A [frame.work](https://frame.work/au/en/desktop) desktop/laptop + + +## Requirements + +- Python 3.11 (3.12 might work, 3.10 definately will not!) + +## Installation Environment + +This installation uses PyTorch 2.7.0 because that's what currently available in +terms of pre-compiled wheels. + +### Installing Python + +Download Python 3.11 from [python.org/downloads/windows](https://www.python.org/downloads/windows/). Hit Ctrl+F and search for "3.11". Dont use this direct link: [https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe](https://www.python.org/ftp/python/3.11.9/python-3.11.9-amd64.exe) -- that was an IQ test. + +After installing, make sure `python --version` works in your terminal and returns 3.11.x + +If not, you probably need to fix your PATH. Go to: + +* Windows + Pause/Break +* Advanced System Settings +* Environment Variables +* Edit your `Path` under User Variables + +Example correct entries: + +```cmd +C:\Users\YOURNAME\AppData\Local\Programs\Python\Launcher\ +C:\Users\YOURNAME\AppData\Local\Programs\Python\Python311\Scripts\ +C:\Users\YOURNAME\AppData\Local\Programs\Python\Python311\ +``` + +If that doesnt work, scream into a bucket. 
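After editing your `Path`, a quick sanity check from a freshly opened terminal confirms the right interpreter is being picked up (illustrative cmd snippet):

```cmd
:: Open a NEW terminal so the updated Path is loaded, then:
python --version
:: Should print Python 3.11.x; 'where python' shows which install wins on the Path
where python
```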
+ +### Installing Git + +Get Git from [git-scm.com/downloads/win](https://git-scm.com/downloads/win). Default install is fine. + + +## Install (Windows, using `venv`) + +### Step 1: Download and Set Up Environment + +```cmd +:: Navigate to your desired install directory +cd \your-path-to-wan2gp + +:: Clone the repository +git clone https://github.com/deepbeepmeep/Wan2GP.git +cd Wan2GP + +:: Create virtual environment using Python 3.10.9 +python -m venv wan2gp-env + +:: Activate the virtual environment +wan2gp-env\Scripts\activate +``` + +### Step 2: Install PyTorch + +The pre-compiled wheels you need are hosted at [scottt's rocm-TheRock releases](https://github.com/scottt/rocm-TheRock/releases). Find the heading that says: + +**Pytorch wheels for gfx110x, gfx1151, and gfx1201** + +Don't click this link: [https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x](https://github.com/scottt/rocm-TheRock/releases/tag/v6.5.0rc-pytorch-gfx110x). It's just here to check if you're skimming. + +Copy the links of the closest binaries to the ones in the example below (adjust if you're not running Python 3.11), then hit enter. + +```cmd +pip install ^ + https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torch-2.7.0a0+rocm_git3f903c3-cp311-cp311-win_amd64.whl ^ + https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchaudio-2.7.0a0+52638ef-cp311-cp311-win_amd64.whl ^ + https://github.com/scottt/rocm-TheRock/releases/download/v6.5.0rc-pytorch-gfx110x/torchvision-0.22.0+9eb57cd-cp311-cp311-win_amd64.whl +``` + +### Step 3: Install Dependencies + +```cmd +:: Install core dependencies +pip install -r requirements.txt +``` + +## Attention Modes + +WanGP supports several attention implementations, only one of which will work for you: + +- **SDPA** (default): Available by default with PyTorch. This uses the built-in aotriton accel library, so is actually pretty fast. + +## Performance Profiles + +Choose a profile based on your hardware: + +- **Profile 3 (LowRAM_HighVRAM)**: Loads entire model in VRAM, requires 24GB VRAM for 8-bit quantized 14B model +- **Profile 4 (LowRAM_LowVRAM)**: Default, loads model parts as needed, slower but lower VRAM requirement + +## Running Wan2GP + +In future, you will have to do this: + +```cmd +cd \path-to\wan2gp +wan2gp\Scripts\activate.bat +python wgp.py +``` + +For now, you should just be able to type `python wgp.py` (because you're already in the virtual environment) + +## Troubleshooting + +- If you use a HIGH VRAM mode, don't be a fool. Make sure you use VAE Tiled Decoding. + +### Memory Issues + +- Use lower resolution or shorter videos +- Enable quantization (default) +- Use Profile 4 for lower VRAM usage +- Consider using 1.3B models instead of 14B models + +For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 5a89d9389..b0eeae34b 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,20 +1,154 @@ # Changelog ## 🔥 Latest News -### July 21 2025: WanGP v7.1 +### August 29 2025: WanGP v8.21 - Here Goes Your Weekend + +- **InfiniteTalk Video to Video**: this feature can be used for Video Dubbing. Keep in mind that it is a *Sparse Video to Video*, that is internally only image is used by Sliding Window. However thanks to the new *Smooth Transition* mode, each new clip is connected to the previous and all the camera work is done by InfiniteTalk. 
If you don't get any transition, increase the number of frames of a Sliding Window (81 frames recommended) + +- **StandIn**: a very light model specialized in Identity Transfer. I have provided two versions of StandIn: a basic one derived from the text 2 video model and another based on Vace. If used with Vace, the last reference frame given to Vace will also be used for StandIn + +- **Flux ESO**: a new Flux derived *Image Editing tool*, but this one is specialized both in *Identity Transfer* and *Style Transfer*. Style has to be understood in its wide meaning: give a reference picture of a person and another one of Sushis and you will turn this person into Sushis + +### August 24 2025: WanGP v8.1 - the RAM Liberator + +- **Reserved RAM entirely freed when switching models**: you should get far fewer RAM-related out of memory errors. I have also added a button in *Configuration / Performance* that will release most of the RAM used by WanGP if you want to use another application without quitting WanGP +- **InfiniteTalk** support: improved version of Multitalk that supposedly supports very long video generations based on an audio track. Exists in two flavors (*Single Speaker* and *Multi Speakers*) but doesn't seem to be compatible with Vace. One key new feature compared to Multitalk is that you can have different visual shots associated with the same audio: each Reference Frame you provide will be associated with a new Sliding Window. If only one Reference Frame is provided, it will be used for all windows. When Continuing a video, you can either continue the current shot (no Reference Frame) or add new shots (one or more Reference Frames).\ +If you are not into audio, you can still use this model to generate infinitely long image2video, just select "no speaker". Last but not least, InfiniteTalk works with all the Loras accelerators. +- **Flux Chroma 1 HD** support: an uncensored Flux based model that is lighter than Flux (8.9B versus 12B) and can fit entirely in VRAM with only 16 GB of VRAM. Unfortunately it is not distilled, so you will need CFG and at least 20 steps + +### August 21 2025: WanGP v8.01 - the killer of seven + +- **Qwen Image Edit**: a Flux Kontext challenger (prompt driven image editing). Best results (including Identity preservation) will be obtained at 720p. Beyond that you may get image outpainting and / or lose identity preservation. Below 720p prompt adherence will be worse. Qwen Image Edit works with Qwen Lora Lightning 4 steps. I have also unlocked all the resolutions for Qwen models. Bonus Zone: support for multiple image compositions, but identity preservation won't be as good. +- **On demand Prompt Enhancer** (needs to be enabled in the Configuration Tab) that you can use to Enhance a Text Prompt before starting a Generation. You can refine the Enhanced Prompt or change the original Prompt. +- Choice of a **Non censored Prompt Enhancer**. Beware, this one is VRAM hungry and will require 12 GB of VRAM to work +- **Memory Profile customizable per model**: useful, for instance, to set Profile 3 (preload the model entirely in VRAM) only for Image Generation models if you have 24 GB of VRAM. In that case Generation will be much faster because with Image generators (contrary to Video generators) a lot of time is otherwise wasted in offloading +- **Expert Guidance Mode**: change the Guidance during the generation up to 2 times. Very useful with Wan 2.2 Lightning to reduce the slow motion effect. The idea is to insert a CFG phase before the 2 accelerated phases that follow and have no Guidance. 
I have added the finetune *Wan2.2 Vace Lightning 3 Phases 14B* with a prebuilt configuration. Please note that it is an 8 step process although the Lightning lora is a 4 step one. This expert guidance mode is also available with Wan 2.1. + +*WanGP 8.01 update: improved Qwen Image Edit Identity Preservation* +### August 12 2025: WanGP v7.7777 - Lucky Day(s) + +This is your lucky day! Thanks to new configuration options that let you store generated Videos and Images in lossless compressed formats, you will find that they look two times better without you doing anything! + +Just kidding, they will only be marginally better, but at least this opens the way to professional editing. + +Support: +- Video: x264, x264 lossless, x265 +- Images: jpeg, png, webp, webp lossless +Generation Settings are stored in each of the above regardless of the format (that was the hard part). + +Also you can now choose different output directories for images and videos. + +Unexpected luck: fixed Lightning 8 steps for Qwen and Lightning 4 steps for Wan 2.2; now you just need a 1x multiplier, no weird numbers. +*update 7.777: oops, got a crash with FastWan? Luck comes and goes, try a new update, maybe you will have a better chance this time* +*update 7.7777: Sometimes good luck seems to last forever. For instance, what if Qwen Lightning 4 steps could also work with WanGP ?* +- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors (Qwen Lightning 4 steps) +- https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors (new improved version of Qwen Lightning 8 steps) + + +### August 10 2025: WanGP v7.76 - Faster than the VAE ... +We have a funny one here today: FastWan 2.2 5B, the Fastest Video Generator, only 20s to generate 121 frames at 720p. The snag is that the VAE is twice as slow... +Thanks to Kijai for extracting the Lora that is used to build the corresponding finetune. + +*WanGP 7.76: fixed the mess I made of the i2v models (the loras path was wrong for Wan 2.2 and Clip was broken)* + +### August 9 2025: WanGP v7.74 - Qwen Rebirth part 2 +Added support for the Qwen Lightning lora for an 8 step generation (https://huggingface.co/lightx2v/Qwen-Image-Lightning/blob/main/Qwen-Image-Lightning-8steps-V1.0.safetensors). The Lora is not normalized and you can use a multiplier around 0.1. + +Mag Cache support for all the Wan 2.2 models. Don't forget to set guidance to 1 and 8 denoising steps, and your gen will be 7x faster! + +### August 8 2025: WanGP v7.73 - Qwen Rebirth +Ever wondered what impact not using Guidance has on a model that expects it? Just look at Qwen Image in WanGP 7.71, whose outputs were erratic. Somehow I had convinced myself that Qwen was a distilled model. In fact Qwen was dying for a negative prompt, and in WanGP 7.72 there is at last one for him. + +As Qwen is not so picky after all, I have also added a quantized text encoder which reduces the RAM requirements of Qwen by 10 GB (the quantized text encoder produced garbage before). + +Unfortunately the Sage bug is still there for older GPU architectures. An Sdpa fallback has been added for these architectures. + +*7.73 update: still the Sage / Sage2 bug for GPUs before RTX40xx. 
I have added a detection mechanism that forces Sdpa attention if that's the case* + + +### August 6 2025: WanGP v7.71 - Picky, picky + +This release comes with two new models: +- Qwen Image: a Commercial grade Image generator capable of injecting full sentences into the generated Image while still offering incredible visuals +- Wan 2.2 TextImage to Video 5B: the last Wan 2.2 needed if you want to complete your Wan 2.2 collection (loras for this model can be stored in "\loras\5B") + +There is a catch though: they are very picky if you want to get good generations. First, they both need lots of steps (50 ?) to show what they have to offer. Then for Qwen Image I had to hardcode the supported resolutions, because if you try anything else, you will get garbage. Likewise Wan 2.2 5B will remind you of Wan 1.0 if you don't ask for at least 720p. + +*7.71 update: Added VAE Tiling for both Qwen Image and Wan 2.2 TextImage to Video 5B, for low VRAM during a whole gen.* + + +### August 4 2025: WanGP v7.6 - Remuxed + +With this new version you won't have any excuse if there is no sound in your video. + +*Continue Video* now works with any video that already has some sound (hint: Multitalk). + +Also, on top of MMaudio and the various sound driven models, I have added the ability to use your own soundtrack. + +As a result you can apply a different sound source on each new video segment when doing a *Continue Video*. + +For instance: +- first video part: use Multitalk with two people speaking +- second video part: apply your own soundtrack, which will gently follow the Multitalk conversation +- third video part: use a Vace effect, and its corresponding control audio will be concatenated to the rest of the audio + +To multiply the combinations I have also implemented *Continue Video* with the various image2video models. + +Also: +- End Frame support added for LTX Video models +- Loras can now be targeted specifically at the High noise or Low noise models with Wan 2.2, check the Loras and Finetune guides +- Flux Krea Dev support + +### July 30 2025: WanGP v7.5: Just another release ... Wan 2.2 part 2 +Here now is Wan 2.2 image2video, a very good model if you want to set Start and End frames. Two Wan 2.2 models delivered, only one to go ... + +Please note that although it is an image2video model, it is structurally very close to Wan 2.2 text2video (same layers with only a different initial projection). Given that Wan 2.1 image2video loras don't work too well (half of their tensors are not supported), I have decided that this model will look for its loras in the text2video loras folder instead of the image2video folder. + +I have also optimized RAM management with Wan 2.2 so that loras and modules will be loaded only once in RAM and Reserved RAM; this saves up to 5 GB of RAM which can make a difference... + +And this time I really removed Vace Cocktail Light which produced blurry results. + +### July 29 2025: WanGP v7.4: Just another release ... Wan 2.2 Preview +Wan 2.2 is here. The good news is that WanGP won't require a single byte of extra VRAM to run it and it will be as fast as Wan 2.1. The bad news is that you will need much more RAM if you want to leverage this new model entirely since it has twice as many parameters. + +So here is a preview version of Wan 2.2 that is without the 5B model and Wan 2.2 image to video for the moment. + +However, as I felt bad to deliver only half of the wares, I gave you instead ..... **Wan 2.2 Vace Experimental Cocktail**! 
Very good surprise indeed, the loras and Vace partially work with Wan 2.2. We will need to wait for the official Vace 2.2 release since some Vace features, like identity preservation, are broken + +Bonus zone: Flux multi image conditions have been added, or maybe not if I broke everything as I have been distracted by Wan... + +7.4 update: I forgot to update the version number. I also removed Vace Cocktail light which didn't work well. + +### July 27 2025: WanGP v7.3 : Interlude +While waiting for Wan 2.2, you will appreciate the model selection hierarchy which is very useful for collecting even more models. You will also appreciate that WanGP remembers which model you used last in each model family. + +### July 26 2025: WanGP v7.2 : Ode to Vace +I am really convinced that Vace can do everything the other models can do, and in a better way, especially as Vace can be combined with Multitalk. + +Here are some new Vace improvements: +- I have provided a default finetune named *Vace Cocktail* which is a model created on the fly using the Wan text 2 video model and the Loras used to build FusioniX. The weight of the *Detail Enhancer* Lora has been reduced to improve identity preservation. Copy the model definition from *defaults/vace_14B_cocktail.json* into the *finetunes/* folder to change the Cocktail composition. Cocktail already contains some Loras accelerators so there is no need to add a Lora Accvid, Causvid or Fusionix on top, ... . The whole point of Cocktail is to be able to build your own FusioniX (which originally is a combination of 4 loras) but without the inconveniences of FusioniX. +- Talking about identity preservation, it tends to go away when one generates a single Frame instead of a Video, which is a shame for our Vace photoshop. But there is a solution: I have added an Advanced Quality option that tells WanGP to generate a little more than a frame (it will still keep only the first frame). It will be a little slower but you will be amazed how Vace Cocktail combined with this option will preserve identities (bye bye *Phantom*). +- As in practice I have observed that one switches frequently between *Vace text2video* and *Vace text2image*, I have put them in the same place: they are now just one tab away, no need to reload the model. Likewise *Wan text2video* and *Wan text2image* have been merged. +- Color fixing when using Sliding Windows. A new postprocessing *Color Correction* applied automatically by default (you can disable it in the *Advanced tab Sliding Window*) will try to match the colors of the new window with those of the previous window. It doesn't fix all the unwanted artifacts of the new window but at least this makes the transition smoother. Thanks to the multitalk team for the original code. + +Also you will enjoy our new real time statistics (CPU / GPU usage, RAM / VRAM used, ... ). Many thanks to **Redtash1** for providing the framework for this new feature! You need to go to the Config tab to enable real time stats. + + +### July 21 2025: WanGP v7.12 - Flux Family Reunion : *Flux Dev* and *Flux Schnell* have been invited aboard WanGP. To celebrate that, Loras support for the Flux *diffusers* format has also been added. -- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video !) in one go without a sliding window. With the distilled model it will take only 5 minutes with a RTX 4090 (you will need 22 GB of VRAM though). 
I have added options to select higher humber frames if you want to experiment +- LTX Video upgraded to version 0.9.8: you can now generate 1800 frames (1 min of video !) in one go without a sliding window. With the distilled model it will take only 5 minutes with a RTX 4090 (you will need 22 GB of VRAM though). I have added options to select higher humber frames if you want to experiment (go to Configuration Tab / General / Increase the Max Number of Frames, change the value and restart the App) - LTX Video ControlNet : it is a Control Net that allows you for instance to transfer a Human motion or Depth from a control video. It is not as powerful as Vace but can produce interesting things especially as now you can generate quickly a 1 min video. Under the scene IC-Loras (see below) for Pose, Depth and Canny are automatically loaded for you, no need to add them. - LTX IC-Lora support: these are special Loras that consumes a conditional image or video Beside the pose, depth and canny IC-Loras transparently loaded there is the *detailer* (https://huggingface.co/Lightricks/LTX-Video-ICLoRA-detailer-13b-0.9.8) which is basically an upsampler. Add the *detailer* as a Lora and use LTX Raw Format as control net choice to use it. -And Also: -- easier way to select video resolution -- started to optimize Matanyone to reduce VRAM requirements +- Matanyone is now also for the GPU Poor as its VRAM requirements have been divided by 2! (7.12 shadow update) +- Easier way to select video resolution ### July 15 2025: WanGP v7.0 is an AI Powered Photoshop This release turns the Wan models into Image Generators. This goes way more than allowing to generate a video made of single frame : diff --git a/docs/INSTALLATION.md b/docs/INSTALLATION.md index 9f66422a1..361f26684 100644 --- a/docs/INSTALLATION.md +++ b/docs/INSTALLATION.md @@ -27,7 +27,7 @@ conda activate wan2gp ### Step 2: Install PyTorch ```shell -# Install PyTorch 2.7.0 with CUDA 12.4 +# Install PyTorch 2.7.0 with CUDA 12.8 pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128 ``` diff --git a/entrypoint.sh b/entrypoint.sh new file mode 100755 index 000000000..9af052d07 --- /dev/null +++ b/entrypoint.sh @@ -0,0 +1,118 @@ +#!/usr/bin/env bash +export HOME=/home/user +export PYTHONUNBUFFERED=1 +export HF_HOME=/home/user/.cache/huggingface + +export OMP_NUM_THREADS=$(nproc) +export MKL_NUM_THREADS=$(nproc) +export OPENBLAS_NUM_THREADS=$(nproc) +export NUMEXPR_NUM_THREADS=$(nproc) + +export TORCH_ALLOW_TF32_CUBLAS=1 +export TORCH_ALLOW_TF32_CUDNN=1 + +# Disable audio warnings in Docker +export SDL_AUDIODRIVER=dummy +export PULSE_RUNTIME_PATH=/tmp/pulse-runtime + +# ═══════════════════════════ CUDA DEBUG CHECKS ═══════════════════════════ + +echo "🔍 CUDA Environment Debug Information:" +echo "═══════════════════════════════════════════════════════════════════════" + +# Check CUDA driver on host (if accessible) +if command -v nvidia-smi >/dev/null 2>&1; then + echo "✅ nvidia-smi available" + echo "📊 GPU Information:" + nvidia-smi --query-gpu=name,driver_version,memory.total,memory.free --format=csv,noheader,nounits 2>/dev/null || echo "❌ nvidia-smi failed to query GPU" + echo "🏃 Running Processes:" + nvidia-smi --query-compute-apps=pid,name,used_memory --format=csv,noheader,nounits 2>/dev/null || echo "ℹ️ No running CUDA processes" +else + echo "❌ nvidia-smi not available in container" +fi + +# Check CUDA runtime libraries +echo "" +echo "🔧 CUDA Runtime Check:" +if ls /usr/local/cuda*/lib*/libcudart.so* 
>/dev/null 2>&1; then + echo "✅ CUDA runtime libraries found:" + ls /usr/local/cuda*/lib*/libcudart.so* 2>/dev/null +else + echo "❌ CUDA runtime libraries not found" +fi + +# Check CUDA devices +echo "" +echo "🖥️ CUDA Device Files:" +if ls /dev/nvidia* >/dev/null 2>&1; then + echo "✅ NVIDIA device files found:" + ls -la /dev/nvidia* 2>/dev/null +else + echo "❌ No NVIDIA device files found - Docker may not have GPU access" +fi + +# Check CUDA environment variables +echo "" +echo "🌍 CUDA Environment Variables:" +echo " CUDA_HOME: ${CUDA_HOME:-not set}" +echo " CUDA_ROOT: ${CUDA_ROOT:-not set}" +echo " CUDA_PATH: ${CUDA_PATH:-not set}" +echo " LD_LIBRARY_PATH: ${LD_LIBRARY_PATH:-not set}" +echo " TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-not set}" +echo " CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES:-not set}" + +# Check PyTorch CUDA availability +echo "" +echo "🐍 PyTorch CUDA Check:" +python3 -c " +import sys +try: + import torch + print('✅ PyTorch imported successfully') + print(f' Version: {torch.__version__}') + print(f' CUDA available: {torch.cuda.is_available()}') + if torch.cuda.is_available(): + print(f' CUDA version: {torch.version.cuda}') + print(f' cuDNN version: {torch.backends.cudnn.version()}') + print(f' Device count: {torch.cuda.device_count()}') + for i in range(torch.cuda.device_count()): + props = torch.cuda.get_device_properties(i) + print(f' Device {i}: {props.name} (SM {props.major}.{props.minor}, {props.total_memory//1024//1024}MB)') + else: + print('❌ CUDA not available to PyTorch') + print(' This could mean:') + print(' - CUDA runtime not properly installed') + print(' - GPU not accessible to container') + print(' - Driver/runtime version mismatch') +except ImportError as e: + print(f'❌ Failed to import PyTorch: {e}') +except Exception as e: + print(f'❌ PyTorch CUDA check failed: {e}') +" 2>&1 + +# Check for common CUDA issues +echo "" +echo "🩺 Common Issue Diagnostics:" + +# Check if running with proper Docker flags +if [ ! -e /dev/nvidia0 ] && [ ! -e /dev/nvidiactl ]; then + echo "❌ No NVIDIA device nodes - container likely missing --gpus all or --runtime=nvidia" +fi + +# Check CUDA library paths +if [ -z "$LD_LIBRARY_PATH" ] || ! echo "$LD_LIBRARY_PATH" | grep -q cuda; then + echo "⚠️ LD_LIBRARY_PATH may not include CUDA libraries" +fi + +# Check permissions on device files +if ls /dev/nvidia* >/dev/null 2>&1; then + if ! ls -la /dev/nvidia* | grep -q "rw-rw-rw-\|rw-r--r--"; then + echo "⚠️ NVIDIA device files may have restrictive permissions" + fi +fi + +echo "═══════════════════════════════════════════════════════════════════════" +echo "🚀 Starting application..." 
+echo "" + +exec su -p user -c "python3 wgp.py --listen $*" diff --git a/models/flux/flux_handler.py b/models/flux/flux_handler.py index 162ec4c69..471c339d4 100644 --- a/models/flux/flux_handler.py +++ b/models/flux/flux_handler.py @@ -13,27 +13,52 @@ def query_model_def(base_model_type, model_def): flux_schnell = flux_model == "flux-schnell" flux_chroma = flux_model == "flux-chroma" flux_uso = flux_model == "flux-dev-uso" - model_def_output = { + flux_umo = flux_model == "flux-dev-umo" + flux_kontext = flux_model == "flux-dev-kontext" + + extra_model_def = { "image_outputs" : True, "no_negative_prompt" : not flux_chroma, } if flux_chroma: - model_def_output["guidance_max_phases"] = 1 + extra_model_def["guidance_max_phases"] = 1 elif not flux_schnell: - model_def_output["embedded_guidance"] = True + extra_model_def["embedded_guidance"] = True if flux_uso : - model_def_output["any_image_refs_relative_size"] = True - model_def_output["no_background_removal"] = True - - model_def_output["image_ref_choices"] = { - "choices":[("No Reference Image", ""),("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "I"), - ("Up to two Images are Style Images", "IJ")], - "default": "I", - "letters_filter": "IJ", + extra_model_def["any_image_refs_relative_size"] = True + extra_model_def["no_background_removal"] = True + extra_model_def["image_ref_choices"] = { + "choices":[("First Image is a Reference Image, and then the next ones (up to two) are Style Images", "KI"), + ("Up to two Images are Style Images", "KIJ")], + "default": "KI", + "letters_filter": "KIJ", "label": "Reference Images / Style Images" } + + if flux_kontext: + extra_model_def["inpaint_support"] = True + extra_model_def["image_ref_choices"] = { + "choices": [ + ("None", ""), + ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"), + ("Conditional Images are People / Objects", "I"), + ], + "letters_filter": "KI", + } + extra_model_def["background_removal_label"]= "Remove Backgrounds only behind People / Objects except main Subject / Landscape" + elif flux_umo: + extra_model_def["image_ref_choices"] = { + "choices": [ + ("Conditional Images are People / Objects", "I"), + ], + "letters_filter": "I", + "visible": False + } + - return model_def_output + extra_model_def["fit_into_canvas_image_refs"] = 0 + + return extra_model_def @staticmethod def query_supported_types(): @@ -82,7 +107,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder ] @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None): from .flux_main import model_factory flux_model = model_factory( @@ -107,15 +132,38 @@ def load_model(model_filename, model_type, base_model_type, model_def, quantizeT pipe["feature_embedder"] = flux_model.feature_embedder return flux_model, pipe + @staticmethod + def fix_settings(base_model_type, settings_version, model_def, ui_defaults): + flux_model = model_def.get("flux-model", "flux-dev") + flux_uso = flux_model == "flux-dev-uso" + if flux_uso and 
settings_version < 2.29: + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if "I" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("I", "KI") + ui_defaults["video_prompt_type"] = video_prompt_type + + if settings_version < 2.34: + ui_defaults["denoising_strength"] = 1. + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): flux_model = model_def.get("flux-model", "flux-dev") flux_uso = flux_model == "flux-dev-uso" + flux_umo = flux_model == "flux-dev-umo" + flux_kontext = flux_model == "flux-dev-kontext" ui_defaults.update({ "embedded_guidance": 2.5, - }) - if model_def.get("reference_image", False): + }) + + if flux_kontext or flux_uso: + ui_defaults.update({ + "video_prompt_type": "KI", + "denoising_strength": 1., + }) + elif flux_umo: ui_defaults.update({ - "video_prompt_type": "I" if flux_uso else "KI", + "video_prompt_type": "I", + "remove_background_images_ref": 0, }) + diff --git a/models/flux/flux_main.py b/models/flux/flux_main.py index 55a2b910a..6746a2375 100644 --- a/models/flux/flux_main.py +++ b/models/flux/flux_main.py @@ -9,6 +9,9 @@ from .sampling import denoise, get_schedule, prepare_kontext, prepare_prompt, prepare_multi_ip, unpack from .modules.layers import get_linear_split_map from transformers import SiglipVisionModel, SiglipImageProcessor +import torchvision.transforms.functional as TVF +import math +from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image from .util import ( aspect_ratio_to_height_width, @@ -20,6 +23,35 @@ ) from PIL import Image +def preprocess_ref(raw_image: Image.Image, long_size: int = 512): + # 获取原始图像的宽度和高度 + image_w, image_h = raw_image.size + + # 计算长边和短边 + if image_w >= image_h: + new_w = long_size + new_h = int((long_size / image_w) * image_h) + else: + new_h = long_size + new_w = int((long_size / image_h) * image_w) + + # 按新的宽高进行等比例缩放 + raw_image = raw_image.resize((new_w, new_h), resample=Image.LANCZOS) + target_w = new_w // 16 * 16 + target_h = new_h // 16 * 16 + + # 计算裁剪的起始坐标以实现中心裁剪 + left = (new_w - target_w) // 2 + top = (new_h - target_h) // 2 + right = left + target_w + bottom = top + target_h + + # 进行中心裁剪 + raw_image = raw_image.crop((left, top, right, bottom)) + + # 转换为 RGB 模式 + raw_image = raw_image.convert("RGB") + return raw_image def stitch_images(img1, img2): # Resize img2 to match img1's height @@ -64,7 +96,7 @@ def __init__( # self.name= "flux-schnell" source = model_def.get("source", None) self.model = load_flow_model(self.name, model_filename[0] if source is None else source, torch_device) - + self.model_def = model_def self.vae = load_ae(self.name, device=torch_device) siglip_processor = siglip_model = feature_embedder = None @@ -106,10 +138,12 @@ def __init__( def generate( self, seed: int | None = None, - input_prompt: str = "replace the logo with the text 'Black Forest Labs'", + input_prompt: str = "replace the logo with the text 'Black Forest Labs'", n_prompt: str = None, sampling_steps: int = 20, input_ref_images = None, + input_frames= None, + input_masks= None, width= 832, height=480, embedded_guidance_scale: float = 2.5, @@ -120,7 +154,8 @@ def generate( batch_size = 1, video_prompt_type = "", joint_pass = False, - image_refs_relative_size = 100, + image_refs_relative_size = 100, + denoising_strength = 1., **bbargs ): if self._interrupt: @@ -129,11 +164,17 @@ def generate( if n_prompt is None or len(n_prompt) == 0: n_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted 
palette, flat colors" device="cuda" flux_dev_uso = self.name in ['flux-dev-uso'] - image_stiching = not self.name in ['flux-dev-uso'] + flux_dev_umo = self.name in ['flux-dev-umo'] + latent_stiching = self.name in ['flux-dev-uso', 'flux-dev-umo'] + + lock_dimensions= False input_ref_images = [] if input_ref_images is None else input_ref_images[:] + if flux_dev_umo: + ref_long_side = 512 if len(input_ref_images) <= 1 else 320 + input_ref_images = [preprocess_ref(img, ref_long_side) for img in input_ref_images] + lock_dimensions = True ref_style_imgs = [] - if "I" in video_prompt_type and len(input_ref_images) > 0: if flux_dev_uso : if "J" in video_prompt_type: @@ -142,33 +183,28 @@ def generate( elif len(input_ref_images) > 1 : ref_style_imgs = input_ref_images[-1:] input_ref_images = input_ref_images[:-1] - if image_stiching: + + if latent_stiching: + # latents stiching with resize + if not lock_dimensions : + for i in range(len(input_ref_images)): + w, h = input_ref_images[i].size + image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, 0) + input_ref_images[i] = input_ref_images[i].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) + else: # image stiching method stiched = input_ref_images[0] - if "K" in video_prompt_type : - w, h = input_ref_images[0].size - height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - for new_img in input_ref_images[1:]: stiched = stitch_images(stiched, new_img) input_ref_images = [stiched] - else: - first_ref = 0 - if "K" in video_prompt_type: - # image latents tiling method - w, h = input_ref_images[0].size - height, width = calculate_new_dimensions(height, width, h, w, fit_into_canvas) - input_ref_images[0] = input_ref_images[0].resize((width, height), resample=Image.Resampling.LANCZOS) - first_ref = 1 - - for i in range(first_ref,len(input_ref_images)): - w, h = input_ref_images[i].size - image_height, image_width = calculate_new_dimensions(int(height*image_refs_relative_size/100), int(width*image_refs_relative_size/100), h, w, fit_into_canvas) - input_ref_images[0] = input_ref_images[0].resize((image_width, image_height), resample=Image.Resampling.LANCZOS) + elif input_frames is not None: + input_ref_images = [convert_tensor_to_image(input_frames) ] else: input_ref_images = None + image_mask = None if input_masks is None else convert_tensor_to_image(input_masks, mask_levels= True) + - if flux_dev_uso : + if self.name in ['flux-dev-uso', 'flux-dev-umo'] : inp, height, width = prepare_multi_ip( ae=self.vae, img_cond_list=input_ref_images, @@ -187,6 +223,7 @@ def generate( bs=batch_size, seed=seed, device=device, + img_mask=image_mask, ) inp.update(prepare_prompt(self.t5, self.clip, batch_size, input_prompt)) @@ -208,13 +245,19 @@ def unpack_latent(x): return unpack(x.float(), height, width) # denoise initial noise - x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass) + x = denoise(self.model, **inp, timesteps=timesteps, guidance=embedded_guidance_scale, real_guidance_scale =guide_scale, callback=callback, pipeline=self, loras_slists= loras_slists, unpack_latent = unpack_latent, joint_pass = joint_pass, denoising_strength = denoising_strength) if x==None: return None # decode latents to pixel space x = unpack_latent(x) with 
torch.autocast(device_type=device, dtype=torch.bfloat16): x = self.vae.decode(x) + if image_mask is not None: + from shared.utils.utils import convert_image_to_tensor + img_msk_rebuilt = inp["img_msk_rebuilt"] + img= input_frames.squeeze(1).unsqueeze(0) # convert_image_to_tensor(image_guide) + x = img * (1 - img_msk_rebuilt) + x.to(img) * img_msk_rebuilt + x = x.clamp(-1, 1) x = x.transpose(0, 1) return x diff --git a/models/flux/model.py b/models/flux/model.py index c4642d0bc..c5f7a24b4 100644 --- a/models/flux/model.py +++ b/models/flux/model.py @@ -190,6 +190,21 @@ def swap_scale_shift(weight): v = swap_scale_shift(v) k = k.replace("norm_out.linear", "final_layer.adaLN_modulation.1") new_sd[k] = v + # elif not first_key.startswith("diffusion_model.") and not first_key.startswith("transformer."): + # for k,v in sd.items(): + # if "double" in k: + # k = k.replace(".processor.proj_lora1.", ".img_attn.proj.lora_") + # k = k.replace(".processor.proj_lora2.", ".txt_attn.proj.lora_") + # k = k.replace(".processor.qkv_lora1.", ".img_attn.qkv.lora_") + # k = k.replace(".processor.qkv_lora2.", ".txt_attn.qkv.lora_") + # else: + # k = k.replace(".processor.qkv_lora.", ".linear1_qkv.lora_") + # k = k.replace(".processor.proj_lora.", ".linear2.lora_") + + # k = "diffusion_model." + k + # new_sd[k] = v + # from mmgp import safetensors2 + # safetensors2.torch_write_file(new_sd, "fff.safetensors") else: new_sd = sd return new_sd diff --git a/models/flux/sampling.py b/models/flux/sampling.py index 5534e9f6b..939543c64 100644 --- a/models/flux/sampling.py +++ b/models/flux/sampling.py @@ -138,10 +138,12 @@ def prepare_kontext( target_width: int | None = None, target_height: int | None = None, bs: int = 1, - + img_mask = None, ) -> tuple[dict[str, Tensor], int, int]: # load and encode the conditioning image + res_match_output = img_mask is not None + img_cond_seq = None img_cond_seq_ids = None if img_cond_list == None: img_cond_list = [] @@ -150,10 +152,11 @@ def prepare_kontext( for cond_no, img_cond in enumerate(img_cond_list): width, height = img_cond.size aspect_ratio = width / height - - # Kontext is trained on specific resolutions, using one of them is recommended - _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) - + if res_match_output: + width, height = target_width, target_height + else: + # Kontext is trained on specific resolutions, using one of them is recommended + _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) width = 2 * int(width / 16) height = 2 * int(height / 16) @@ -194,6 +197,19 @@ def prepare_kontext( "img_cond_seq": img_cond_seq, "img_cond_seq_ids": img_cond_seq_ids, } + if img_mask is not None: + from shared.utils.utils import convert_image_to_tensor, convert_tensor_to_image + # image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) + image_mask_latents = convert_image_to_tensor(img_mask.resize((target_width // 16, target_height // 16), resample=Image.Resampling.LANCZOS)) + image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. 
)[0:1] + image_mask_rebuilt = image_mask_latents.repeat_interleave(16, dim=-1).repeat_interleave(16, dim=-2).unsqueeze(0) + # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") + image_mask_latents = image_mask_latents.reshape(1, -1, 1).to(device) + return_dict.update({ + "img_msk_latents": image_mask_latents, + "img_msk_rebuilt": image_mask_rebuilt, + }) + img = get_noise( bs, target_height, @@ -265,6 +281,9 @@ def denoise( loras_slists=None, unpack_latent = None, joint_pass= False, + img_msk_latents = None, + img_msk_rebuilt = None, + denoising_strength = 1, ): kwargs = {'pipeline': pipeline, 'callback': callback, "img_len" : img.shape[1], "siglip_embedding": siglip_embedding, "siglip_embedding_ids": siglip_embedding_ids} @@ -272,19 +291,38 @@ def denoise( if callback != None: callback(-1, None, True) + original_image_latents = None if img_cond_seq is None else img_cond_seq.clone() + original_timesteps = timesteps + morph, first_step = False, 0 + if img_msk_latents is not None: + randn = torch.randn_like(original_image_latents) + if denoising_strength < 1.: + first_step = int(len(timesteps) * (1. - denoising_strength)) + if not morph: + latent_noise_factor = timesteps[first_step] + latents = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor + img = latents.to(img) + latents = None + timesteps = timesteps[first_step:] + + updated_num_steps= len(timesteps) -1 if callback != None: from shared.utils.loras_mutipliers import update_loras_slists - update_loras_slists(model, loras_slists, updated_num_steps) + update_loras_slists(model, loras_slists, len(original_timesteps)) callback(-1, None, True, override_num_inference_steps = updated_num_steps) from mmgp import offload # this is ignored for schnell guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype) for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])): - offload.set_step_no_for_lora(model, i) + offload.set_step_no_for_lora(model, first_step + i) if pipeline._interrupt: return None + if img_msk_latents is not None and denoising_strength <1. 
and i == first_step and morph: + latent_noise_factor = t_curr/1000 + img = original_image_latents * (1.0 - latent_noise_factor) + img * latent_noise_factor + t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device) img_input = img img_input_ids = img_ids @@ -334,6 +372,14 @@ def denoise( pred = neg_pred + real_guidance_scale * (pred - neg_pred) img += (t_prev - t_curr) * pred + + if img_msk_latents is not None: + latent_noise_factor = t_prev + # noisy_image = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor + noisy_image = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor + img = noisy_image * (1-img_msk_latents) + img_msk_latents * img + noisy_image = None + if callback is not None: preview = unpack_latent(img).transpose(0,1) callback(i, preview, False) diff --git a/models/flux/util.py b/models/flux/util.py index 0f96103de..af75f6269 100644 --- a/models/flux/util.py +++ b/models/flux/util.py @@ -640,6 +640,38 @@ class ModelSpec: shift_factor=0.1159, ), ), + "flux-dev-umo": ModelSpec( + repo_id="", + repo_flow="", + repo_ae="ckpts/flux_vae.safetensors", + params=FluxParams( + in_channels=64, + out_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + eso= True, + ), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), } diff --git a/models/hyvideo/hunyuan.py b/models/hyvideo/hunyuan.py index a38a7bd0d..aa6c3b3e6 100644 --- a/models/hyvideo/hunyuan.py +++ b/models/hyvideo/hunyuan.py @@ -861,16 +861,11 @@ def generate( freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_RIFLEx) else: if self.avatar: - w, h = input_ref_images.size - target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas) - if target_width != w or target_height != h: - input_ref_images = input_ref_images.resize((target_width,target_height), resample=Image.Resampling.LANCZOS) - concat_dict = {'mode': 'timecat', 'bias': -1} freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict) else: if input_frames != None: - target_height, target_width = input_frames.shape[-3:-1] + target_height, target_width = input_frames.shape[-2:] elif input_video != None: target_height, target_width = input_video.shape[-2:] @@ -899,9 +894,10 @@ def generate( pixel_value_bg = input_video.unsqueeze(0) pixel_value_mask = torch.zeros_like(input_video).unsqueeze(0) if input_frames != None: - pixel_value_video_bg = input_frames.permute(-1,0,1,2).unsqueeze(0).float() - pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float() - pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.) + pixel_value_video_bg = input_frames.unsqueeze(0) #.permute(-1,0,1,2).unsqueeze(0).float() + # pixel_value_video_bg = pixel_value_video_bg.div_(127.5).add_(-1.) 
+ # pixel_value_video_mask = input_masks.unsqueeze(-1).repeat(1,1,1,3).permute(-1,0,1,2).unsqueeze(0).float() + pixel_value_video_mask = input_masks.repeat(3,1,1,1).unsqueeze(0) if input_video != None: pixel_value_bg = torch.cat([pixel_value_bg, pixel_value_video_bg], dim=2) pixel_value_mask = torch.cat([ pixel_value_mask, pixel_value_video_mask], dim=2) @@ -913,10 +909,11 @@ def generate( if pixel_value_bg.shape[2] < frame_num: padding_shape = list(pixel_value_bg.shape[0:2]) + [frame_num-pixel_value_bg.shape[2]] + list(pixel_value_bg.shape[3:]) pixel_value_bg = torch.cat([pixel_value_bg, torch.full(padding_shape, -1, dtype=pixel_value_bg.dtype, device= pixel_value_bg.device ) ], dim=2) - pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) + # pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 255, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) + pixel_value_mask = torch.cat([ pixel_value_mask, torch.full(padding_shape, 1, dtype=pixel_value_mask.dtype, device= pixel_value_mask.device ) ], dim=2) bg_latents = self.vae.encode(pixel_value_bg).latent_dist.sample() - pixel_value_mask = pixel_value_mask.div_(127.5).add_(-1.) + pixel_value_mask = pixel_value_mask.mul_(2).add_(-1.) # unmasked pixels is -1 (no 0 as usual) and masked is 1 mask_latents = self.vae.encode(pixel_value_mask).latent_dist.sample() bg_latents = torch.cat([bg_latents, mask_latents], dim=1) bg_latents.mul_(self.vae.config.scaling_factor) diff --git a/models/hyvideo/hunyuan_handler.py b/models/hyvideo/hunyuan_handler.py index d95bd7e35..ebe07d26d 100644 --- a/models/hyvideo/hunyuan_handler.py +++ b/models/hyvideo/hunyuan_handler.py @@ -51,11 +51,38 @@ def query_model_def(base_model_type, model_def): extra_model_def["tea_cache"] = True extra_model_def["mag_cache"] = True - if base_model_type in ["hunyuan_avatar"]: extra_model_def["no_background_removal"] = True - - if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_audio", "hunyuan_avatar"]: + if base_model_type in ["hunyuan_custom_edit"]: + extra_model_def["guide_preprocessing"] = { + "selection": ["MV", "PV"], + } + + extra_model_def["mask_preprocessing"] = { + "selection": ["A", "NA"], + "default" : "NA" + } + + if base_model_type in ["hunyuan_custom_audio", "hunyuan_custom_edit", "hunyuan_custom"]: + extra_model_def["image_ref_choices"] = { + "choices": [("Reference Image", "I")], + "letters_filter":"I", + "visible": False, + } + + if base_model_type in ["hunyuan_avatar"]: + extra_model_def["image_ref_choices"] = { + "choices": [("Start Image", "KI")], + "letters_filter":"KI", + "visible": False, + } + extra_model_def["no_background_removal"] = True + + if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar"]: extra_model_def["one_image_ref_needed"] = True + + if base_model_type in ["hunyuan_i2v"]: + extra_model_def["image_prompt_types_allowed"] = "S" + return extra_model_def @staticmethod @@ -102,7 +129,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder } @staticmethod - def load_model(model_filename, model_type = None, base_model_type = None, model_def = None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + def load_model(model_filename, model_type = None, base_model_type = None, model_def 
= None, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None): from .hunyuan import HunyuanVideoSampler from mmgp import offload @@ -137,6 +164,24 @@ def load_model(model_filename, model_type = None, base_model_type = None, model return hunyuan_model, pipe + @staticmethod + def fix_settings(base_model_type, settings_version, model_def, ui_defaults): + if settings_version<2.33: + if base_model_type in ["hunyuan_custom_edit"]: + video_prompt_type= ui_defaults["video_prompt_type"] + if "P" in video_prompt_type and "M" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("M","") + ui_defaults["video_prompt_type"] = video_prompt_type + + if settings_version < 2.36: + if base_model_type in ["hunyuan_avatar", "hunyuan_custom_audio"]: + audio_prompt_type= ui_defaults["audio_prompt_type"] + if "A" not in audio_prompt_type: + audio_prompt_type += "A" + ui_defaults["audio_prompt_type"] = audio_prompt_type + + + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults["embedded_guidance_scale"]= 6.0 @@ -158,6 +203,7 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "guidance_scale": 7.5, "flow_shift": 13, "video_prompt_type": "I", + "audio_prompt_type": "A", }) elif base_model_type in ["hunyuan_custom_edit"]: ui_defaults.update({ @@ -174,4 +220,5 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "skip_steps_start_step_perc": 25, "video_length": 129, "video_prompt_type": "KI", + "audio_prompt_type": "A", }) diff --git a/models/hyvideo/modules/utils.py b/models/hyvideo/modules/utils.py index 02a733e1b..c26399728 100644 --- a/models/hyvideo/modules/utils.py +++ b/models/hyvideo/modules/utils.py @@ -14,7 +14,7 @@ ) -@lru_cache +# @lru_cache def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False): block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile) return block_mask diff --git a/models/ltx_video/ltxv.py b/models/ltx_video/ltxv.py index e71ac4fce..080860c14 100644 --- a/models/ltx_video/ltxv.py +++ b/models/ltx_video/ltxv.py @@ -300,9 +300,6 @@ def generate( prefix_size, height, width = input_video.shape[-3:] else: if image_start != None: - frame_width, frame_height = image_start.size - if fit_into_canvas != None: - height, width = calculate_new_dimensions(height, width, frame_height, frame_width, fit_into_canvas, 32) conditioning_media_paths.append(image_start.unsqueeze(1)) conditioning_start_frames.append(0) conditioning_control_frames.append(False) @@ -479,14 +476,14 @@ def generate( images = images.sub_(0.5).mul_(2).squeeze(0) return images - def get_loras_transformer(self, get_model_recursive_prop, video_prompt_type, **kwargs): + def get_loras_transformer(self, get_model_recursive_prop, model_type, video_prompt_type, **kwargs): map = { "P" : "pose", "D" : "depth", "E" : "canny", } loras = [] - preloadURLs = get_model_recursive_prop(self.model_type, "preload_URLs") + preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") lora_file_name = "" for letter, signature in map.items(): if letter in video_prompt_type: diff --git a/models/ltx_video/ltxv_handler.py b/models/ltx_video/ltxv_handler.py index d35bcd457..8c322e153 100644 --- a/models/ltx_video/ltxv_handler.py +++ b/models/ltx_video/ltxv_handler.py @@ -24,7 +24,19 @@ def query_model_def(base_model_type, model_def): 
extra_model_def["frames_minimum"] = 17 extra_model_def["frames_steps"] = 8 extra_model_def["sliding_window"] = True + extra_model_def["image_prompt_types_allowed"] = "TSEV" + extra_model_def["guide_preprocessing"] = { + "selection": ["", "PV", "DV", "EV", "V"], + "labels" : { "V": "Use LTXV raw format"} + } + + extra_model_def["mask_preprocessing"] = { + "selection": ["", "A", "NA", "XA", "XNA"], + } + + extra_model_def["extra_control_frames"] = 1 + extra_model_def["dont_cat_preguide"]= True return extra_model_def @staticmethod @@ -64,7 +76,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False): + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized = False, submodel_no_list = None): from .ltxv import LTXV ltxv_model = LTXV( diff --git a/models/qwen/pipeline_qwenimage.py b/models/qwen/pipeline_qwenimage.py index 07bdbd404..c458852fe 100644 --- a/models/qwen/pipeline_qwenimage.py +++ b/models/qwen/pipeline_qwenimage.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from mmgp import offload import inspect from typing import Any, Callable, Dict, List, Optional, Union @@ -28,7 +27,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage from diffusers import FlowMatchEulerDiscreteScheduler from PIL import Image -from shared.utils.utils import calculate_new_dimensions +from shared.utils.utils import calculate_new_dimensions, convert_image_to_tensor, convert_tensor_to_image XLA_AVAILABLE = False @@ -201,7 +200,8 @@ def __init__( self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.tokenizer_max_length = 1024 if processor is not None: - self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + # self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" + self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. 
Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" self.prompt_template_encode_start_idx = 64 else: self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" @@ -232,7 +232,22 @@ def _get_qwen_prompt_embeds( drop_idx = self.prompt_template_encode_start_idx txt = [template.format(e) for e in prompt] - if self.processor is not None and image is not None: + if self.processor is not None and image is not None and len(image) > 0: + img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>" + if isinstance(image, list): + base_img_prompt = "" + for i, img in enumerate(image): + base_img_prompt += img_prompt_template.format(i + 1) + elif image is not None: + base_img_prompt = img_prompt_template.format(1) + else: + base_img_prompt = "" + + template = self.prompt_template_encode + + drop_idx = self.prompt_template_encode_start_idx + txt = [template.format(base_img_prompt + e) for e in prompt] + model_inputs = self.processor( text=txt, images=image, @@ -387,7 +402,8 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): return latent_image_ids.to(device=device, dtype=dtype) @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): + def _pack_latents(latents): + batch_size, num_channels_latents, _, height, width = latents.shape latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) @@ -464,7 +480,7 @@ def disable_vae_tiling(self): def prepare_latents( self, - image, + images, batch_size, num_channels_latents, height, @@ -479,30 +495,33 @@ def prepare_latents( height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) - shape = (batch_size, 1, num_channels_latents, height, width) + shape = (batch_size, num_channels_latents, 1, height, width) image_latents = None - if image is not None: - image = image.to(device=device, dtype=dtype) - if image.shape[1] != self.latent_channels: - image_latents = self._encode_vae_image(image=image, generator=generator) - else: - image_latents = image - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
- ) - else: - image_latents = torch.cat([image_latents], dim=0) - - image_latent_height, image_latent_width = image_latents.shape[3:] - image_latents = self._pack_latents( - image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width - ) + if images is not None and len(images ) > 0: + if not isinstance(images, list): + images = [images] + all_image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + image_latents = self._pack_latents(image_latents) + all_image_latents.append(image_latents) + image_latents = torch.cat(all_image_latents, dim=1) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -511,7 +530,7 @@ def prepare_latents( ) if latents is None: latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + latents = self._pack_latents(latents) else: latents = latents.to(device=device, dtype=dtype) @@ -563,10 +582,15 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, image = None, + image_mask = None, + denoising_strength = 0, callback=None, pipeline=None, loras_slists=None, joint_pass= True, + lora_inpaint = False, + outpainting_dims = None, + qwen_edit_plus = False, ): r""" Function invoked when calling the pipeline for generation. 
@@ -682,33 +706,54 @@ def __call__( batch_size = prompt_embeds.shape[0] device = "cuda" - prompt_image = None - if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): - image = image[0] if isinstance(image, list) else image - image_height, image_width = self.image_processor.get_default_height_width(image) - aspect_ratio = image_width / image_height - if False : - _, image_width, image_height = min( - (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_QWENIMAGE_RESOLUTIONS - ) - image_width = image_width // multiple_of * multiple_of - image_height = image_height // multiple_of * multiple_of - ref_height, ref_width = 1568, 672 - if height * width < ref_height * ref_width: ref_height , ref_width = height , width - if image_height * image_width > ref_height * ref_width: - image_height, image_width = calculate_new_dimensions(ref_height, ref_width, image_height, image_width, False, block_size=multiple_of) - - image = image.resize((image_width,image_height), resample=Image.Resampling.LANCZOS) - prompt_image = image - image = self.image_processor.preprocess(image, image_height, image_width) - image = image.unsqueeze(2) + condition_images = [] + vae_image_sizes = [] + vae_images = [] + image_mask_latents = None + ref_size = 1024 + ref_text_encoder_size = 384 if qwen_edit_plus else 1024 + if image is not None: + if not isinstance(image, list): image = [image] + if height * width < ref_size * ref_size: ref_size = round(math.sqrt(height * width)) + for ref_no, img in enumerate(image): + image_width, image_height = img.size + any_mask = ref_no == 0 and image_mask is not None + if (image_height * image_width > ref_size * ref_size) and not any_mask: + vae_height, vae_width =calculate_new_dimensions(ref_size, ref_size, image_height, image_width, False, block_size=multiple_of) + else: + vae_height, vae_width = image_height, image_width + vae_width = vae_width // multiple_of * multiple_of + vae_height = vae_height // multiple_of * multiple_of + vae_image_sizes.append((vae_width, vae_height)) + condition_height, condition_width =calculate_new_dimensions(ref_text_encoder_size, ref_text_encoder_size, image_height, image_width, False, block_size=multiple_of) + condition_images.append(img.resize((condition_width, condition_height), resample=Image.Resampling.LANCZOS) ) + if img.size != (vae_width, vae_height): + img = img.resize((vae_width, vae_height), resample=Image.Resampling.LANCZOS) + if any_mask : + if lora_inpaint: + image_mask_rebuilt = torch.where(convert_image_to_tensor(image_mask)>-0.5, 1., 0. )[0:1] + img = convert_image_to_tensor(img) + green = torch.tensor([-1.0, 1.0, -1.0]).to(img) + green_image = green[:, None, None] .expand_as(img) + img = torch.where(image_mask_rebuilt > 0, green_image, img) + img = convert_tensor_to_image(img) + else: + image_mask_latents = convert_image_to_tensor(image_mask.resize((vae_width // 8, vae_height // 8), resample=Image.Resampling.LANCZOS)) + image_mask_latents = torch.where(image_mask_latents>-0.5, 1., 0. 
)[0:1] + image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0) + # convert_tensor_to_image( image_mask_rebuilt.squeeze(0).repeat(3,1,1)).save("mmm.png") + image_mask_latents = image_mask_latents.to(device).unsqueeze(0).unsqueeze(0).repeat(1,16,1,1,1) + image_mask_latents = self._pack_latents(image_mask_latents) + # img.save("nnn.png") + vae_images.append( convert_image_to_tensor(img).unsqueeze(0).unsqueeze(2) ) + has_neg_prompt = negative_prompt is not None or ( negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None ) do_true_cfg = true_cfg_scale > 1 and has_neg_prompt prompt_embeds, prompt_embeds_mask = self.encode_prompt( - image=prompt_image, + image=condition_images, prompt=prompt, prompt_embeds=prompt_embeds, prompt_embeds_mask=prompt_embeds_mask, @@ -718,7 +763,7 @@ def __call__( ) if do_true_cfg: negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt( - image=prompt_image, + image=condition_images, prompt=negative_prompt, prompt_embeds=negative_prompt_embeds, prompt_embeds_mask=negative_prompt_embeds_mask, @@ -734,7 +779,7 @@ def __call__( # 4. Prepare latent variables num_channels_latents = self.transformer.in_channels // 4 latents, image_latents = self.prepare_latents( - image, + vae_images, batch_size * num_images_per_prompt, num_channels_latents, height, @@ -744,11 +789,18 @@ def __call__( generator, latents, ) + original_image_latents = None if image_latents is None else image_latents.clone() + if image is not None: img_shapes = [ [ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2), - (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2), + # (1, image_height // self.vae_scale_factor // 2, image_width // self.vae_scale_factor // 2), + *[ + (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2) + for vae_width, vae_height in vae_image_sizes + ], + ] ] * batch_size else: @@ -773,7 +825,7 @@ def __call__( ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - + original_timesteps = timesteps # handle guidance if self.transformer.guidance_embeds: guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) @@ -788,56 +840,80 @@ def __call__( negative_txt_seq_lens = ( negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None ) - + morph, first_step = False, 0 + lanpaint_proc = None + if image_mask_latents is not None: + randn = torch.randn_like(original_image_latents) + if denoising_strength < 1.: + first_step = int(len(timesteps) * (1. - denoising_strength)) + if not morph: + latent_noise_factor = timesteps[first_step]/1000 + # latents = original_image_latents * (1.0 - latent_noise_factor) + torch.randn_like(original_image_latents) * latent_noise_factor + latents = original_image_latents * (1.0 - latent_noise_factor) + randn * latent_noise_factor + timesteps = timesteps[first_step:] + self.scheduler.timesteps = timesteps + self.scheduler.sigmas= self.scheduler.sigmas[first_step:] + # from shared.inpainting.lanpaint import LanPaint + # lanpaint_proc = LanPaint() # 6. 
Denoising loop self.scheduler.set_begin_index(0) updated_num_steps= len(timesteps) if callback != None: from shared.utils.loras_mutipliers import update_loras_slists - update_loras_slists(self.transformer, loras_slists, updated_num_steps) + update_loras_slists(self.transformer, loras_slists, len(original_timesteps)) callback(-1, None, True, override_num_inference_steps = updated_num_steps) + for i, t in enumerate(timesteps): + offload.set_step_no_for_lora(self.transformer, first_step + i) if self.interrupt: continue - self._current_timestep = t # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - latent_model_input = latents - if image_latents is not None: - latent_model_input = torch.cat([latents, image_latents], dim=1) - - if do_true_cfg and joint_pass: - noise_pred, neg_noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], - encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], - img_shapes=img_shapes, - txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens], - attention_kwargs=self.attention_kwargs, - **kwargs - ) - if noise_pred == None: return None - noise_pred = noise_pred[:, : latents.size(1)] - neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - else: - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states_mask_list=[prompt_embeds_mask], - encoder_hidden_states_list=[prompt_embeds], - img_shapes=img_shapes, - txt_seq_lens_list=[txt_seq_lens], - attention_kwargs=self.attention_kwargs, - **kwargs - )[0] - if noise_pred == None: return None - noise_pred = noise_pred[:, : latents.size(1)] + if image_mask_latents is not None and denoising_strength <1. and i == first_step and morph: + latent_noise_factor = t/1000 + latents = original_image_latents * (1.0 - latent_noise_factor) + latents * latent_noise_factor + + + latents_dtype = latents.dtype + + # latent_model_input = latents + def denoise(latent_model_input, true_cfg_scale): + if image_latents is not None: + latent_model_input = torch.cat([latents, image_latents], dim=1) + do_true_cfg = true_cfg_scale > 1 + if do_true_cfg and joint_pass: + noise_pred, neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, #!!!! 
+ encoder_hidden_states_mask_list=[prompt_embeds_mask,negative_prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds, negative_prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[txt_seq_lens, negative_txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + ) + if noise_pred == None: return None, None + noise_pred = noise_pred[:, : latents.size(1)] + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + else: + neg_noise_pred = None + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states_mask_list=[prompt_embeds_mask], + encoder_hidden_states_list=[prompt_embeds], + img_shapes=img_shapes, + txt_seq_lens_list=[txt_seq_lens], + attention_kwargs=self.attention_kwargs, + **kwargs + )[0] + if noise_pred == None: return None, None + noise_pred = noise_pred[:, : latents.size(1)] if do_true_cfg: neg_noise_pred = self.transformer( @@ -851,20 +927,43 @@ def __call__( attention_kwargs=self.attention_kwargs, **kwargs )[0] - if neg_noise_pred == None: return None + if neg_noise_pred == None: return None, None neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + return noise_pred, neg_noise_pred + def cfg_predictions( noise_pred, neg_noise_pred, guidance, t): + if do_true_cfg: + comb_pred = neg_noise_pred + guidance * (noise_pred - neg_noise_pred) + if comb_pred == None: return None + + cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) + noise_pred = comb_pred * (cond_norm / noise_norm) + + return noise_pred - if do_true_cfg: - comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) - if comb_pred == None: return None - cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True) - noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True) - noise_pred = comb_pred * (cond_norm / noise_norm) - neg_noise_pred = None + if lanpaint_proc is not None and i<=3: + latents = lanpaint_proc(denoise, cfg_predictions, true_cfg_scale, 1., latents, original_image_latents, randn, t/1000, image_mask_latents, height=height , width= width, vae_scale_factor= 8) + if latents is None: return None + + noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale) + if noise_pred == None: return None + noise_pred = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t) + neg_noise_pred = None # compute the previous noisy sample x_t -> x_t-1 - latents_dtype = latents.dtype latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + noise_pred = None + + if image_mask_latents is not None: + if lanpaint_proc is not None: + latents = original_image_latents * (1-image_mask_latents) + image_mask_latents * latents + else: + next_t = timesteps[i+1] if i 0: + dim -= 1 + return TensorList([u.squeeze(dim) for u in self.tensors]) + + def type(self, *args, **kwargs): + return TensorList([u.type(*args, **kwargs) for u in self.tensors]) + + def type_as(self, other): + assert isinstance(other, (torch.Tensor, TensorList)) + if isinstance(other, torch.Tensor): + return TensorList([u.type_as(other) for u in self.tensors]) + else: + return TensorList([u.type(other.dtype) for u in self.tensors]) + + @property + def dtype(self): + return self.tensors[0].dtype + + @property + def device(self): + return self.tensors[0].device + + @property + def ndim(self): + return 1 + self.tensors[0].ndim + + def __getitem__(self, index): + return self.tensors[index] + + def __len__(self): + return len(self.tensors) + + def __add__(self, 
other): + return self._apply(other, lambda u, v: u + v) + + def __radd__(self, other): + return self._apply(other, lambda u, v: v + u) + + def __sub__(self, other): + return self._apply(other, lambda u, v: u - v) + + def __rsub__(self, other): + return self._apply(other, lambda u, v: v - u) + + def __mul__(self, other): + return self._apply(other, lambda u, v: u * v) + + def __rmul__(self, other): + return self._apply(other, lambda u, v: v * u) + + def __floordiv__(self, other): + return self._apply(other, lambda u, v: u // v) + + def __truediv__(self, other): + return self._apply(other, lambda u, v: u / v) + + def __rfloordiv__(self, other): + return self._apply(other, lambda u, v: v // u) + + def __rtruediv__(self, other): + return self._apply(other, lambda u, v: v / u) + + def __pow__(self, other): + return self._apply(other, lambda u, v: u ** v) + + def __rpow__(self, other): + return self._apply(other, lambda u, v: v ** u) + + def __neg__(self): + return TensorList([-u for u in self.tensors]) + + def __iter__(self): + for tensor in self.tensors: + yield tensor + + def __repr__(self): + return 'TensorList: \n' + repr(self.tensors) + + def _apply(self, other, op): + if isinstance(other, (list, tuple, TensorList)) or ( + isinstance(other, torch.Tensor) and ( + other.numel() > 1 or other.ndim > 1 + ) + ): + assert len(other) == len(self.tensors) + return TensorList([op(u, v) for u, v in zip(self.tensors, other)]) + elif isinstance(other, numbers.Number) or ( + isinstance(other, torch.Tensor) and ( + other.numel() == 1 and other.ndim <= 1 + ) + ): + return TensorList([op(u, other) for u in self.tensors]) + else: + raise TypeError( + f'unsupported operand for *: "TensorList" and "{type(other)}"' + ) \ No newline at end of file diff --git a/models/wan/animate/face_blocks.py b/models/wan/animate/face_blocks.py new file mode 100644 index 000000000..8ddb829fa --- /dev/null +++ b/models/wan/animate/face_blocks.py @@ -0,0 +1,382 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +from torch import nn +import torch +from typing import Tuple, Optional +from einops import rearrange +import torch.nn.functional as F +import math +from shared.attention import pay_attention + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="torch", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + """ + Perform QKV self attention. + + Args: + q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads. + k (torch.Tensor): Key tensor with shape [b, s1, a, d] + v (torch.Tensor): Value tensor with shape [b, s1, a, d] + mode (str): Attention mode. Choose from 'self_flash', 'cross_flash', 'torch', and 'vanilla'. + drop_rate (float): Dropout rate in attention map. (default: 0) + attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla). + (default: None) + causal (bool): Whether to use causal attention. (default: False) + cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into q. + cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, + used to index into kv. 
+ max_seqlen_q (int): The maximum sequence length in the batch of q. + max_seqlen_kv (int): The maximum sequence length in the batch of k and v. + + Returns: + torch.Tensor: Output tensor after self attention with shape [b, s, ad] + """ + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + + elif mode == "flash": + x = flash_attn_func( + q, + k, + v, + ) + x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d] + elif mode == "vanilla": + scale_factor = 1 / math.sqrt(q.size(-1)) + + b, a, s, _ = q.shape + s1 = k.size(2) + attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device) + if causal: + # Only applied to self attention + assert attn_mask is None, "Causal mask and attn_mask cannot be used together" + temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(q.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + + attn = (q @ k.transpose(-2, -1)) * scale_factor + attn += attn_bias + attn = attn.softmax(dim=-1) + attn = torch.dropout(attn, p=drop_rate, train=True) + x = attn @ v + else: + raise NotImplementedError(f"Unsupported attention mode: {mode}") + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = 
self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, + device=None, + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm 
= ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) N H D") + v = rearrange(v, "B L N H D -> (B L) N H D") + + if use_context_parallel: + q = gather_forward(q, dim=1) + + q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp) + # Compute attention. + # Size([batches, tokens, heads, head_features]) + qkv_list = [q, k, v] + del q,k,v + attn = pay_attention(qkv_list) + # attn = attention( + # q, + # k, + # v, + # max_seqlen_q=q.shape[1], + # batch_size=q.shape[0], + # ) + + attn = attn.reshape(*attn.shape[:2], -1) + attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp) + # if use_context_parallel: + # attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()] + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output \ No newline at end of file diff --git a/models/wan/animate/model_animate.py b/models/wan/animate/model_animate.py new file mode 100644 index 000000000..d07f762ee --- /dev/null +++ b/models/wan/animate/model_animate.py @@ -0,0 +1,31 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import math +import types +from copy import deepcopy +from einops import rearrange +from typing import List +import numpy as np +import torch +import torch.cuda.amp as amp +import torch.nn as nn + +def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values): + pose_latents = self.pose_patch_embedding(pose_latents) + x[:, :, 1:] += pose_latents + + b,c,T,h,w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + return x, motion_vec diff --git a/models/wan/animate/motion_encoder.py b/models/wan/animate/motion_encoder.py new file mode 100644 index 000000000..02b00408c --- /dev/null +++ b/models/wan/animate/motion_encoder.py @@ -0,0 +1,308 @@ +# Modified from ``https://github.com/wyhsirius/LIA`` +# Copyright 2024-2025 The Alibaba Wan Team Authors. 
All rights reserved. +import torch +import torch.nn as nn +from torch.nn import functional as F +import math + +def custom_qr(input_tensor): + original_dtype = input_tensor.dtype + if original_dtype in [torch.bfloat16, torch.float16]: + q, r = torch.linalg.qr(input_tensor.to(torch.float32)) + return q.to(original_dtype), r.to(original_dtype) + return torch.linalg.qr(input_tensor) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + + return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, 
activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})') + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, + bias=bias and not activate)) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class EncoderApp(nn.Module): + def __init__(self, size, w_dim=512): + super(EncoderApp, self).__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256, + 128: 128, + 256: 64, + 512: 32, + 1024: 16 + } + + self.w_dim = w_dim + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False)) + + def forward(self, x): + + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:] + + +class Encoder(nn.Module): + def __init__(self, size, dim=512, dim_motion=20): + super(Encoder, self).__init__() + + # appearance netmork + self.net_app = EncoderApp(size, dim) + + # motion network + fc = [EqualLinear(dim, dim)] + for i in range(3): + fc.append(EqualLinear(dim, dim)) + + fc.append(EqualLinear(dim, dim_motion)) + self.fc = nn.Sequential(*fc) + + def enc_app(self, x): + h_source = self.net_app(x) + return h_source + + def enc_motion(self, x): + h, _ = self.net_app(x) + h_motion = self.fc(h) + return h_motion + + +class Direction(nn.Module): + def __init__(self, motion_dim): + super(Direction, self).__init__() + self.weight = nn.Parameter(torch.randn(512, motion_dim)) + + def forward(self, input): + + weight = self.weight + 1e-8 + Q, R = 
custom_qr(weight) + if input is None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class Synthesis(nn.Module): + def __init__(self, motion_dim): + super(Synthesis, self).__init__() + self.direction = Direction(motion_dim) + + +class Generator(nn.Module): + def __init__(self, size, style_dim=512, motion_dim=20): + super().__init__() + + self.enc = Encoder(size, style_dim, motion_dim) + self.dec = Synthesis(motion_dim) + + def get_motion(self, img): + #motion_feat = self.enc.enc_motion(img) + # motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) + with torch.cuda.amp.autocast(dtype=torch.float32): + motion_feat = self.enc.enc_motion(img) + motion = self.dec.direction(motion_feat) + return motion \ No newline at end of file diff --git a/models/wan/any2video.py b/models/wan/any2video.py index ee58c0a09..5beea8e74 100644 --- a/models/wan/any2video.py +++ b/models/wan/any2video.py @@ -19,7 +19,8 @@ import torchvision.transforms.functional as TF import torch.nn.functional as F from .distributed.fsdp import shard_model -from .modules.model import WanModel, clear_caches +from .modules.model import WanModel +from mmgp.offload import get_cache, clear_caches from .modules.t5 import T5EncoderModel from .modules.vae import WanVAE from .modules.vae2_2 import Wan2_2_VAE @@ -31,9 +32,11 @@ from .modules.posemb_layers import get_rotary_pos_embed, get_nd_rotary_pos_embed from shared.utils.vace_preprocessor import VaceVideoProcessor from shared.utils.basic_flowmatch import FlowMatchScheduler -from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor +from shared.utils.utils import get_outpainting_frame_location, resize_lanczos, calculate_new_dimensions, convert_image_to_tensor, fit_image_into_canvas from .multitalk.multitalk_utils import MomentumBuffer, adaptive_projected_guidance, match_and_blend_colors, match_and_blend_colors_with_mask +from shared.utils.audio_video import save_video from mmgp import safetensors2 +from shared.utils.audio_video import save_video def optimized_scale(positive_flat, negative_flat): @@ -63,6 +66,7 @@ def __init__( config, checkpoint_dir, model_filename = None, + submodel_no_list = None, model_type = None, model_def = None, base_model_type = None, @@ -91,7 +95,7 @@ def __init__( shard_fn= None) # base_model_type = "i2v2_2" - if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"]: + if hasattr(config, "clip_checkpoint") and not base_model_type in ["i2v_2_2", "i2v_2_2_multitalk"] or base_model_type in ["animate"]: self.clip = CLIPModel( dtype=config.clip_dtype, device=self.device, @@ -100,7 +104,7 @@ def __init__( tokenizer_path=os.path.join(checkpoint_dir , "xlm-roberta-large")) - if base_model_type in ["ti2v_2_2"]: + if base_model_type in ["ti2v_2_2", "lucy_edit"]: self.vae_stride = (4, 16, 16) vae_checkpoint = "Wan2.2_VAE.safetensors" vae = Wan2_2_VAE @@ -125,75 +129,82 @@ def __init__( forcedConfigPath = base_config_file if len(model_filename) > 1 else None # forcedConfigPath = base_config_file = f"configs/flf2v_720p.json" # model_filename[1] = xmodel_filename - + self.model = self.model2 = None source = model_def.get("source", None) + source2 = model_def.get("source2", None) module_source = model_def.get("module_source", None) + module_source2 = model_def.get("module_source2", None) if module_source 
is not None: - model_filename = [] + model_filename - model_filename[1] = module_source - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) - elif source is not None: + self.model = offload.fast_load_transformers_model(model_filename[:1] + [module_source], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + if module_source2 is not None: + self.model2 = offload.fast_load_transformers_model(model_filename[1:2] + [module_source2], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + if source is not None: self.model = offload.fast_load_transformers_model(source, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file) - elif self.transformer_switch: - shared_modules= {} - self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules) - self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) - shared_modules = None + if source2 is not None: + self.model2 = offload.fast_load_transformers_model(source2, modelClass=WanModel, writable_tensors= False, forcedConfigPath= base_config_file) + + if self.model is not None or self.model2 is not None: + from wgp import save_model + from mmgp.safetensors2 import torch_load_file else: - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + if self.transformer_switch: + if 0 in submodel_no_list[2:] and 1 in submodel_no_list[2:]: + raise Exception("Shared and non shared modules at the same time across multipe models is not supported") + + if 0 in submodel_no_list[2:]: + shared_modules= {} + self.model = offload.fast_load_transformers_model(model_filename[:1], modules = model_filename[2:], modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath, return_shared_modules= shared_modules) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = shared_modules, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + shared_modules = None + else: + modules_for_1 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==1 ] + modules_for_2 =[ file_name for file_name, submodel_no in zip(model_filename[2:],submodel_no_list[2:] ) if submodel_no ==2 ] + self.model = offload.fast_load_transformers_model(model_filename[:1], modules = modules_for_1, 
modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + self.model2 = offload.fast_load_transformers_model(model_filename[1:2], modules = modules_for_2, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) + + else: + self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer and not save_quantized, writable_tensors= False, defaultConfigPath=base_config_file , forcedConfigPath= forcedConfigPath) - # self.model = offload.load_model_data(self.model, xmodel_filename ) - # offload.load_model_data(self.model, "c:/temp/Phantom-Wan-1.3B.pth") - self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) - offload.change_dtype(self.model, dtype, True) + if self.model is not None: + self.model.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) + offload.change_dtype(self.model, dtype, True) + self.model.eval().requires_grad_(False) if self.model2 is not None: self.model2.lock_layers_dtypes(torch.float32 if mixed_precision_transformer else dtype) offload.change_dtype(self.model2, dtype, True) - - # offload.save_model(self.model, "wan2.1_text2video_1.3B_mbf16.safetensors", do_quantize= False, config_file_path=base_config_file, filter_sd=sd) - # offload.save_model(self.model, "wan2.2_image2video_14B_low_mbf16.safetensors", config_file_path=base_config_file) - # offload.save_model(self.model, "wan2.2_image2video_14B_low_quanto_mbf16_int8.safetensors", do_quantize=True, config_file_path=base_config_file) - self.model.eval().requires_grad_(False) - if self.model2 is not None: self.model2.eval().requires_grad_(False) + if module_source is not None: - from wgp import save_model - from mmgp.safetensors2 import torch_load_file - filter = list(torch_load_file(module_source)) - save_model(self.model, model_type, dtype, None, is_module=True, filter=filter) - elif not source is None: - from wgp import save_model - save_model(self.model, model_type, dtype, None) + save_model(self.model, model_type, dtype, None, is_module=True, filter=list(torch_load_file(module_source)), module_source_no=1) + if module_source2 is not None: + save_model(self.model2, model_type, dtype, None, is_module=True, filter=list(torch_load_file(module_source2)), module_source_no=2) + if not source is None: + save_model(self.model, model_type, dtype, None, submodel_no= 1) + if not source2 is None: + save_model(self.model2, model_type, dtype, None, submodel_no= 2) if save_quantized: from wgp import save_quantized_model - save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) + if self.model is not None: + save_quantized_model(self.model, model_type, model_filename[0], dtype, base_config_file) if self.model2 is not None: save_quantized_model(self.model2, model_type, model_filename[1], dtype, base_config_file, submodel_no=2) self.sample_neg_prompt = config.sample_neg_prompt - if self.model.config.get("vace_in_dim", None) != None: - self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), - min_area=480*832, - max_area=480*832, - min_fps=config.sample_fps, - max_fps=config.sample_fps, - zero_start=True, - seq_len=32760, - keep_last=True) - + if hasattr(self.model, "vace_blocks"): 
self.adapt_vace_model(self.model) if self.model2 is not None: self.adapt_vace_model(self.model2) + if hasattr(self.model, "face_adapter"): + self.adapt_animate_model(self.model) + if self.model2 is not None: self.adapt_animate_model(self.model2) + self.num_timesteps = 1000 self.use_timestep_transform = True def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, overlapped_latents = None): - if ref_images is None: - ref_images = [None] * len(frames) - else: - assert len(frames) == len(ref_images) + ref_images = [ref_images] * len(frames) if masks is None: latents = self.vae.encode(frames, tile_size = tile_size) @@ -225,11 +236,7 @@ def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0, over return cat_latents def vace_encode_masks(self, masks, ref_images=None): - if ref_images is None: - ref_images = [None] * len(masks) - else: - assert len(masks) == len(ref_images) - + ref_images = [ref_images] * len(masks) result_masks = [] for mask, refs in zip(masks, ref_images): c, depth, height, width = mask.shape @@ -257,119 +264,6 @@ def vace_encode_masks(self, masks, ref_images=None): result_masks.append(mask) return result_masks - def vace_latent(self, z, m): - return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - - def fit_image_into_canvas(self, ref_img, image_size, canvas_tf_bg, device, fill_max = False, outpainting_dims = None, return_mask = False): - from shared.utils.utils import save_image - ref_width, ref_height = ref_img.size - if (ref_height, ref_width) == image_size and outpainting_dims == None: - ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) - canvas = torch.zeros_like(ref_img) if return_mask else None - else: - if outpainting_dims != None: - final_height, final_width = image_size - canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8) - else: - canvas_height, canvas_width = image_size - scale = min(canvas_height / ref_height, canvas_width / ref_width) - new_height = int(ref_height * scale) - new_width = int(ref_width * scale) - if fill_max and (canvas_height - new_height) < 16: - new_height = canvas_height - if fill_max and (canvas_width - new_width) < 16: - new_width = canvas_width - top = (canvas_height - new_height) // 2 - left = (canvas_width - new_width) // 2 - ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) - ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) - if outpainting_dims != None: - canvas = torch.full((3, 1, final_height, final_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1] - canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img - else: - canvas = torch.full((3, 1, canvas_height, canvas_width), canvas_tf_bg, dtype= torch.float, device=device) # [-1, 1] - canvas[:, :, top:top + new_height, left:left + new_width] = ref_img - ref_img = canvas - canvas = None - if return_mask: - if outpainting_dims != None: - canvas = torch.ones((3, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1] - canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0 - else: - canvas = torch.ones((3, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1] - canvas[:, :, top:top + new_height, left:left + new_width] = 0 - canvas = canvas.to(device) - return ref_img.to(device), canvas - - def prepare_source(self, 
src_video, src_mask, src_ref_images, total_frames, image_size, device, keep_video_guide_frames= [], start_frame = 0, fit_into_canvas = None, pre_src_video = None, inject_frames = [], outpainting_dims = None, any_background_ref = False): - image_sizes = [] - trim_video_guide = len(keep_video_guide_frames) - def conv_tensor(t, device): - return t.float().div_(127.5).add_(-1).permute(3, 0, 1, 2).to(device) - - for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)): - prepend_count = 0 if sub_pre_src_video == None else sub_pre_src_video.shape[1] - num_frames = total_frames - prepend_count - num_frames = min(num_frames, trim_video_guide) if trim_video_guide > 0 and sub_src_video != None else num_frames - if sub_src_mask is not None and sub_src_video is not None: - src_video[i] = conv_tensor(sub_src_video[:num_frames], device) - src_mask[i] = conv_tensor(sub_src_mask[:num_frames], device) - # src_video is [-1, 1] (at this function output), 0 = inpainting area (in fact 127 in [0, 255]) - # src_mask is [-1, 1] (at this function output), 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255]) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.full_like(sub_pre_src_video, -1.0), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.clamp((src_mask[i][:, :, :, :] + 1) / 2, min=0, max=1) - image_sizes.append(src_video[i].shape[2:]) - elif sub_src_video is None: - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1) - else: - src_video[i] = torch.zeros((3, total_frames, image_size[0], image_size[1]), device=device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - image_sizes.append(image_size) - else: - src_video[i] = conv_tensor(sub_src_video[:num_frames], device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - image_sizes.append(src_video[i].shape[2:]) - for k, keep in enumerate(keep_video_guide_frames): - if not keep: - pos = prepend_count + k - src_video[i][:, pos:pos+1] = 0 - src_mask[i][:, pos:pos+1] = 1 - - for k, frame in enumerate(inject_frames): - if frame != None: - pos = prepend_count + k - src_video[i][:, pos:pos+1], src_mask[i][:, pos:pos+1] = self.fit_image_into_canvas(frame, image_size, 0, device, True, outpainting_dims, return_mask= True) - - - 
self.background_mask = None - for i, ref_images in enumerate(src_ref_images): - if ref_images is not None: - image_size = image_sizes[i] - for j, ref_img in enumerate(ref_images): - if ref_img is not None and not torch.is_tensor(ref_img): - if j==0 and any_background_ref: - if self.background_mask == None: self.background_mask = [None] * len(src_ref_images) - src_ref_images[i][j], self.background_mask[i] = self.fit_image_into_canvas(ref_img, image_size, 0, device, True, outpainting_dims, return_mask= True) - else: - src_ref_images[i][j], _ = self.fit_image_into_canvas(ref_img, image_size, 1, device) - if self.background_mask != None: - self.background_mask = [ item if item != None else self.background_mask[0] for item in self.background_mask ] # deplicate background mask with double control net since first controlnet image ref modifed by ref - return src_video, src_mask, src_ref_images def get_vae_latents(self, ref_images, device, tile_size= 0): ref_vae_latents = [] @@ -380,12 +274,53 @@ def get_vae_latents(self, ref_images, device, tile_size= 0): return torch.cat(ref_vae_latents, dim=1) + def get_i2v_mask(self, lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=None, lat_t =0, device="cuda"): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = F.interpolate(mask_pixel_values.to(device), size=(lat_h, lat_w), mode='nearest') + + if nb_frames_unchanged >0: + msk[:, :nb_frames_unchanged] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1,2)[0] + return msk + + def encode_reference_images(self, ref_images, ref_prompt="image of a face", any_guidance= False, tile_size = None): + ref_images = [convert_image_to_tensor(img).unsqueeze(1).to(device=self.device, dtype=self.dtype) for img in ref_images] + shape = ref_images[0].shape + freqs = get_rotary_pos_embed( (len(ref_images) , shape[-2] // 8, shape[-1] // 8 )) + # batch_ref_image: [B, C, F, H, W] + vae_feat = self.vae.encode(ref_images, tile_size = tile_size) + vae_feat = torch.cat( vae_feat, dim=1).unsqueeze(0) + if any_guidance: + vae_feat_uncond = self.vae.encode([ref_images[0] * 0], tile_size = tile_size) * len(ref_images) + vae_feat_uncond = torch.cat( vae_feat_uncond, dim=1).unsqueeze(0) + context = self.text_encoder([ref_prompt], self.device)[0].to(self.dtype) + context = torch.cat([context, context.new_zeros(self.model.text_len -context.size(0), context.size(1)) ]).unsqueeze(0) + clear_caches() + get_cache("lynx_ref_buffer").update({ 0: {}, 1: {} }) + ref_buffer = self.model( + pipeline =self, + x = [vae_feat, vae_feat_uncond] if any_guidance else [vae_feat], + context = [context, context] if any_guidance else [context], + freqs= freqs, + t=torch.stack([torch.tensor(0, dtype=torch.float)]).to(self.device), + lynx_feature_extractor = True, + ) + clear_caches() + return ref_buffer[0], (ref_buffer[1] if any_guidance else None) def generate(self, input_prompt, input_frames= None, + input_frames2= None, input_masks = None, - input_ref_images = None, + input_masks2 = None, + input_ref_images = None, + input_ref_masks = None, + input_faces = None, input_video = None, image_start = None, image_end = None, @@ -444,6 +379,8 @@ def generate(self, pre_video_frame = None, video_prompt_type= "", original_input_ref_images = [], + face_arc_embeds = None, + control_scale_alt = 1., **bbargs ): @@ -453,7 +390,8 @@ def generate(self, timesteps.append(0.) 
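
The sampler code above builds a uniform step list (`np.linspace(self.num_timesteps, 1, ...)`, plus a trailing 0) and then warps each step with `timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps)`. The transform itself is not part of this hunk; as a point of reference only, here is a minimal sketch of the shift warp commonly used by flow-matching samplers — an assumption about what the repository's helper does, not a copy of it:

```python
# Minimal sketch (assumption): a typical flow-matching "shift" warp.
# The real timestep_transform in the repository is not shown in this hunk.
import torch

def timestep_transform(t: torch.Tensor, shift: float = 5.0, num_timesteps: int = 1000) -> torch.Tensor:
    # Normalize to [0, 1], apply the shift warp, rescale back to [0, num_timesteps].
    t = t / num_timesteps
    t = shift * t / (1.0 + (shift - 1.0) * t)
    return t * num_timesteps

# Example: with shift=5 the midpoint t=500 is pushed to ~833,
# i.e. more denoising steps are spent at high noise levels.
print(timestep_transform(torch.tensor([500.0])))  # tensor([833.3333])
```
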
timesteps = [torch.tensor([t], device=self.device) for t in timesteps] if self.use_timestep_transform: - timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1] + timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps][:-1] + timesteps = torch.tensor(timesteps) sample_scheduler = None elif sample_solver == 'causvid': sample_scheduler = FlowMatchScheduler(num_inference_steps=sampling_steps, shift=shift, sigma_min=0, extra_one_step=True) @@ -489,13 +427,17 @@ def generate(self, # Text Encoder if n_prompt == "": n_prompt = self.sample_neg_prompt - context = self.text_encoder([input_prompt], self.device)[0] - context_null = self.text_encoder([n_prompt], self.device)[0] - context = context.to(self.dtype) - context_null = context_null.to(self.dtype) text_len = self.model.text_len - context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0) - context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0) + any_guidance_at_all = guide_scale > 1 or guide2_scale > 1 and guide_phases >=2 or guide3_scale > 1 and guide_phases >=3 + context = self.text_encoder([input_prompt], self.device)[0].to(self.dtype) + context = torch.cat([context, context.new_zeros(text_len -context.size(0), context.size(1)) ]).unsqueeze(0) + if NAG_scale > 1 or any_guidance_at_all: + context_null = self.text_encoder([n_prompt], self.device)[0].to(self.dtype) + context_null = torch.cat([context_null, context_null.new_zeros(text_len -context_null.size(0), context_null.size(1)) ]).unsqueeze(0) + else: + context_null = None + if input_video is not None: height, width = input_video.shape[-2:] + # NAG_prompt = "static, low resolution, blurry" # context_NAG = self.text_encoder([NAG_prompt], self.device)[0] # context_NAG = context_NAG.to(self.dtype) @@ -509,116 +451,84 @@ def generate(self, # if NAG_scale > 1: context = torch.cat([context, context_NAG], dim=0) if self._interrupt: return None - vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B", "vace_standin_14B"] + vace = model_type in ["vace_1.3B","vace_14B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B"] phantom = model_type in ["phantom_1.3B", "phantom_14B"] fantasy = model_type in ["fantasy"] multitalk = model_type in ["multitalk", "infinitetalk", "vace_multitalk_14B", "i2v_2_2_multitalk"] infinitetalk = model_type in ["infinitetalk"] standin = model_type in ["standin", "vace_standin_14B"] + lynx = model_type in ["lynx_lite", "lynx", "vace_lynx_lite_14B", "vace_lynx_14B"] recam = model_type in ["recam_1.3B"] - ti2v = model_type in ["ti2v_2_2"] + ti2v = model_type in ["ti2v_2_2", "lucy_edit"] + lucy_edit= model_type in ["lucy_edit"] + animate= model_type in ["animate"] start_step_no = 0 ref_images_count = 0 trim_frames = 0 - extended_overlapped_latents = None + extended_overlapped_latents = clip_image_start = clip_image_end = image_mask_latents = None no_noise_latents_injection = infinitetalk timestep_injection = False lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 + extended_input_dim = 0 + ref_images_before = False # image2video if model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "flf2v_720p"]: any_end_frame = False - if image_start is None: - if infinitetalk: - if input_frames is not None: - image_ref = input_frames[:, -1] - if input_video is None: 
input_video = input_frames[:, -1:] - new_shot = "Q" in video_prompt_type - else: - if pre_video_frame is None: - new_shot = True - else: - if input_ref_images is None: - input_ref_images, new_shot = [pre_video_frame], False - else: - input_ref_images, new_shot = [img.resize(pre_video_frame.size, resample=Image.Resampling.LANCZOS) for img in input_ref_images], "Q" in video_prompt_type - if input_ref_images is None: raise Exception("Missing Reference Image") - new_shot = new_shot and window_no <= len(input_ref_images) - image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ]) - if new_shot: - input_video = image_ref.unsqueeze(1) - else: - color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot - _ , preframes_count, height, width = input_video.shape - input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype) - if infinitetalk: - image_for_clip = image_ref.to(input_video) - control_pre_frames_count = 1 - control_video = image_for_clip.unsqueeze(1) + if infinitetalk: + new_shot = "0" in video_prompt_type + if input_frames is not None: + image_ref = input_frames[:, 0] else: - image_for_clip = input_video[:, -1] - control_pre_frames_count = preframes_count - control_video = input_video - lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] - if hasattr(self, "clip"): - clip_image_size = self.clip.model.image_size - clip_image = resize_lanczos(image_for_clip, clip_image_size, clip_image_size)[:, None, :, :] - clip_context = self.clip.visual([clip_image]) if model_type != "flf2v_720p" else self.clip.visual([clip_image , clip_image ]) - clip_image = None + if input_ref_images is None: + if pre_video_frame is None: raise Exception("Missing Reference Image") + input_ref_images, new_shot = [pre_video_frame], False + new_shot = new_shot and window_no <= len(input_ref_images) + image_ref = convert_image_to_tensor(input_ref_images[ min(window_no, len(input_ref_images))-1 ]) + if new_shot or input_video is None: + input_video = image_ref.unsqueeze(1) else: - clip_context = None - enc = torch.concat( [control_video, torch.zeros( (3, frame_num-control_pre_frames_count, height, width), - device=self.device, dtype= self.VAE_dtype)], - dim = 1).to(self.device) - color_reference_frame = image_for_clip.unsqueeze(1).clone() + color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot + _ , preframes_count, height, width = input_video.shape + input_video = input_video.to(device=self.device).to(dtype= self.VAE_dtype) + if infinitetalk: + image_start = image_ref.to(input_video) + control_pre_frames_count = 1 + control_video = image_start.unsqueeze(1) else: - preframes_count = control_pre_frames_count = 1 - any_end_frame = image_end is not None - add_frames_for_end_image = any_end_frame and model_type == "i2v" - if any_end_frame: - if add_frames_for_end_image: - frame_num +=1 - lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) - trim_frames = 1 - - height, width = image_start.shape[1:] - - lat_h = round( - height // self.vae_stride[1] // - self.patch_size[1] * self.patch_size[1]) - lat_w = round( - width // self.vae_stride[2] // - self.patch_size[2] * self.patch_size[2]) - height = lat_h * self.vae_stride[1] - width = lat_w * self.vae_stride[2] - image_start_frame = image_start.unsqueeze(1).to(self.device) - 
color_reference_frame = image_start_frame.clone() - if image_end is not None: - img_end_frame = image_end.unsqueeze(1).to(self.device) - - if hasattr(self, "clip"): - clip_image_size = self.clip.model.image_size - image_start = resize_lanczos(image_start, clip_image_size, clip_image_size) - if image_end is not None: image_end = resize_lanczos(image_end, clip_image_size, clip_image_size) - if model_type == "flf2v_720p": - clip_context = self.clip.visual([image_start[:, None, :, :], image_end[:, None, :, :] if image_end is not None else image_start[:, None, :, :]]) - else: - clip_context = self.clip.visual([image_start[:, None, :, :]]) - else: - clip_context = None - - if any_end_frame: - enc= torch.concat([ - image_start_frame, - torch.zeros( (3, frame_num-2, height, width), device=self.device, dtype= self.VAE_dtype), - img_end_frame, - ], dim=1).to(self.device) - else: - enc= torch.concat([ - image_start_frame, - torch.zeros( (3, frame_num-1, height, width), device=self.device, dtype= self.VAE_dtype) - ], dim=1).to(self.device) + image_start = input_video[:, -1] + control_pre_frames_count = preframes_count + control_video = input_video - image_start = image_end = image_start_frame = img_end_frame = image_for_clip = image_ref = None + color_reference_frame = image_start.unsqueeze(1).clone() + + any_end_frame = image_end is not None + add_frames_for_end_image = any_end_frame and model_type == "i2v" + if any_end_frame: + color_correction_strength = 0 #disable color correction as transition frames between shots may have a complete different color level than the colors of the new shot + if add_frames_for_end_image: + frame_num +=1 + lat_frames = int((frame_num - 2) // self.vae_stride[0] + 2) + trim_frames = 1 + + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + + if image_end is not None: + img_end_frame = image_end.unsqueeze(1).to(self.device) + clip_image_start, clip_image_end = image_start, image_end + + if any_end_frame: + enc= torch.concat([ + control_video, + torch.zeros( (3, frame_num-control_pre_frames_count-1, height, width), device=self.device, dtype= self.VAE_dtype), + img_end_frame, + ], dim=1).to(self.device) + else: + enc= torch.concat([ + control_video, + torch.zeros( (3, frame_num-control_pre_frames_count, height, width), device=self.device, dtype= self.VAE_dtype) + ], dim=1).to(self.device) + + image_start = image_end = img_end_frame = image_ref = control_video = None msk = torch.ones(1, frame_num, lat_h, lat_w, device=self.device) if any_end_frame: @@ -643,42 +553,87 @@ def generate(self, if infinitetalk: lat_y = self.vae.encode([input_video], VAE_tile_size)[0] extended_overlapped_latents = lat_y[:, :overlapped_latents_frames_num].clone().unsqueeze(0) - # if control_pre_frames_count != pre_frames_count: lat_y = input_video = None kwargs.update({ 'y': y}) - if not clip_context is None: - kwargs.update({'clip_fea': clip_context}) - # Recam Master + # Animate + if animate: + pose_pixels = input_frames * input_masks + input_masks = 1. 
- input_masks + pose_pixels -= input_masks + pose_latents = self.vae.encode([pose_pixels], VAE_tile_size)[0].unsqueeze(0) + input_frames = input_frames * input_masks + if not "X" in video_prompt_type: input_frames += input_masks - 1 # masked area should black (-1) in background frames + # input_frames = input_frames[:, :1].expand(-1, input_frames.shape[1], -1, -1) + if prefix_frames_count > 0: + input_frames[:, :prefix_frames_count] = input_video + input_masks[:, :prefix_frames_count] = 1 + # save_video(pose_pixels, "pose.mp4") + # save_video(input_frames, "input_frames.mp4") + # save_video(input_masks, "input_masks.mp4", value_range=(0,1)) + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + msk_ref = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=1,lat_t=1, device=self.device) + msk_control = self.get_i2v_mask(lat_h, lat_w, nb_frames_unchanged=0, mask_pixel_values=input_masks, device=self.device) + msk = torch.concat([msk_ref, msk_control], dim=1) + image_ref = input_ref_images[0].to(self.device) + clip_image_start = image_ref.squeeze(1) + lat_y = torch.concat(self.vae.encode([image_ref, input_frames.to(self.device)], VAE_tile_size), dim=1) + y = torch.concat([msk, lat_y]) + kwargs.update({ 'y': y, 'pose_latents': pose_latents}) + face_pixel_values = input_faces.unsqueeze(0) + lat_y = msk = msk_control = msk_ref = pose_pixels = None + ref_images_before = True + ref_images_count = 1 + lat_frames = int((input_frames.shape[1] - 1) // self.vae_stride[0]) + 1 + + # Clip image + if hasattr(self, "clip") and clip_image_start is not None: + clip_image_size = self.clip.model.image_size + clip_image_start = resize_lanczos(clip_image_start, clip_image_size, clip_image_size) + clip_image_end = resize_lanczos(clip_image_end, clip_image_size, clip_image_size) if clip_image_end is not None else clip_image_start + if model_type == "flf2v_720p": + clip_context = self.clip.visual([clip_image_start[:, None, :, :], clip_image_end[:, None, :, :] if clip_image_end is not None else clip_image_start[:, None, :, :]]) + else: + clip_context = self.clip.visual([clip_image_start[:, None, :, :]]) + clip_image_start = clip_image_end = None + kwargs.update({'clip_fea': clip_context}) + + # Recam Master & Lucy Edit + if recam or lucy_edit: + frame_num, height,width = input_frames.shape[-3:] + lat_frames = int((frame_num - 1) // self.vae_stride[0]) + 1 + frame_num = (lat_frames -1) * self.vae_stride[0] + 1 + input_frames = input_frames[:, :frame_num].to(dtype=self.dtype , device=self.device) + extended_latents = self.vae.encode([input_frames])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) + extended_input_dim = 2 if recam else 1 + del input_frames + if recam: - # should be be in fact in input_frames since it is control video not a video to be extended - target_camera = model_mode - height,width = input_video.shape[-2:] - input_video = input_video.to(dtype=self.dtype , device=self.device) - source_latents = self.vae.encode([input_video])[0].unsqueeze(0) #.to(dtype=self.dtype, device=self.device) - del input_video # Process target camera (recammaster) + target_camera = model_mode from shared.utils.cammmaster_tools import get_camera_embedding cam_emb = get_camera_embedding(target_camera) cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) kwargs['cam_emb'] = cam_emb # Video 2 Video - if denoising_strength < 1. 
and input_frames != None: + if "G" in video_prompt_type and input_frames != None: height, width = input_frames.shape[-2:] source_latents = self.vae.encode([input_frames])[0].unsqueeze(0) injection_denoising_step = 0 inject_from_start = False if input_frames != None and denoising_strength < 1 : color_reference_frame = input_frames[:, -1:].clone() - if overlapped_latents != None: - overlapped_latents_frames_num = overlapped_latents.shape[2] - overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1 + if prefix_frames_count > 0: + overlapped_frames_num = prefix_frames_count + overlapped_latents_frames_num = (overlapped_frames_num -1 // 4) + 1 + # overlapped_latents_frames_num = overlapped_latents.shape[2] + # overlapped_frames_num = (overlapped_latents_frames_num-1) * 4 + 1 else: overlapped_latents_frames_num = overlapped_frames_num = 0 if len(keep_frames_parsed) == 0 or image_outputs or (overlapped_frames_num + len(keep_frames_parsed)) == input_frames.shape[1] and all(keep_frames_parsed) : keep_frames_parsed = [] - injection_denoising_step = int(sampling_steps * (1. - denoising_strength) ) + injection_denoising_step = int( round(sampling_steps * (1. - denoising_strength),4) ) latent_keep_frames = [] if source_latents.shape[2] < lat_frames or len(keep_frames_parsed) > 0: inject_from_start = True @@ -694,14 +649,21 @@ def generate(self, if hasattr(sample_scheduler, "sigmas"): sample_scheduler.sigmas= sample_scheduler.sigmas[injection_denoising_step:] injection_denoising_step = 0 + if input_masks is not None and not "U" in video_prompt_type: + image_mask_latents = torch.nn.functional.interpolate(input_masks, size= source_latents.shape[-2:], mode="nearest").unsqueeze(0) + if image_mask_latents.shape[2] !=1: + image_mask_latents = torch.cat([ image_mask_latents[:,:, :1], torch.nn.functional.interpolate(image_mask_latents, size= (source_latents.shape[-3]-1, *source_latents.shape[-2:]), mode="nearest") ], dim=2) + image_mask_latents = torch.where(image_mask_latents>=0.5, 1., 0. 
)[:1].to(self.device) + # save_video(image_mask_latents.squeeze(0), "mama.mp4", value_range=(0,1) ) + # image_mask_rebuilt = image_mask_latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2).unsqueeze(0) + # Phantom if phantom: - input_ref_images_neg = None - if input_ref_images != None: # Phantom Ref images - input_ref_images = self.get_vae_latents(input_ref_images, self.device) - input_ref_images_neg = torch.zeros_like(input_ref_images) - ref_images_count = input_ref_images.shape[1] if input_ref_images != None else 0 - trim_frames = input_ref_images.shape[1] + lat_input_ref_images_neg = None + if input_ref_images is not None: # Phantom Ref images + lat_input_ref_images = self.get_vae_latents(input_ref_images, self.device) + lat_input_ref_images_neg = torch.zeros_like(lat_input_ref_images) + ref_images_count = trim_frames = lat_input_ref_images.shape[1] if ti2v: if input_video is None: @@ -710,28 +672,31 @@ def generate(self, height, width = input_video.shape[-2:] source_latents = self.vae.encode([input_video], tile_size = VAE_tile_size)[0].unsqueeze(0) timestep_injection = True + if extended_input_dim > 0: + extended_latents[:, :, :source_latents.shape[2]] = source_latents # Vace if vace : # vace context encode - input_frames = [u.to(self.device) for u in input_frames] - input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] - input_masks = [u.to(self.device) for u in input_masks] - if self.background_mask != None: self.background_mask = [m.to(self.device) for m in self.background_mask] + input_frames = [input_frames.to(self.device)] +([] if input_frames2 is None else [input_frames2.to(self.device)]) + input_masks = [input_masks.to(self.device)] + ([] if input_masks2 is None else [input_masks2.to(self.device)]) + if model_type in ["vace_lynx_14B"]: + input_ref_images = input_ref_masks = None + input_ref_images = None if input_ref_images is None else [ u.to(self.device) for u in input_ref_images] + input_ref_masks = None if input_ref_masks is None else [ None if u is None else u.to(self.device) for u in input_ref_masks] + ref_images_before = True z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size, overlapped_latents = overlapped_latents ) m0 = self.vace_encode_masks(input_masks, input_ref_images) - if self.background_mask != None: - color_reference_frame = input_ref_images[0][0].clone() - zbg = self.vace_encode_frames([ref_img[0] for ref_img in input_ref_images], None, masks=self.background_mask, tile_size = VAE_tile_size ) - mbg = self.vace_encode_masks(self.background_mask, None) + if input_ref_masks is not None and len(input_ref_masks) > 0 and input_ref_masks[0] is not None: + color_reference_frame = input_ref_images[0].clone() + zbg = self.vace_encode_frames( input_ref_images[:1] * len(input_frames), None, masks=input_ref_masks[0], tile_size = VAE_tile_size ) + mbg = self.vace_encode_masks(input_ref_masks[:1] * len(input_frames), None) for zz0, mm0, zzbg, mmbg in zip(z0, m0, zbg, mbg): zz0[:, 0:1] = zzbg mm0[:, 0:1] = mmbg - - self.background_mask = zz0 = mm0 = zzbg = mmbg = None - z = self.vace_latent(z0, m0) - - ref_images_count = len(input_ref_images[0]) if input_ref_images != None and input_ref_images[0] != None else 0 + zz0 = mm0 = zzbg = mmbg = None + z = [torch.cat([zz, mm], dim=0) for zz, mm in zip(z0, m0)] + ref_images_count = len(input_ref_images) if input_ref_images is not None and input_ref_images is not None else 0 context_scale = context_scale if context_scale 
!= None else [1.0] * len(z) kwargs.update({'vace_context' : z, 'vace_context_scale' : context_scale, "ref_images_count": ref_images_count }) if overlapped_latents != None : @@ -739,17 +704,12 @@ def generate(self, extended_overlapped_latents = z[0][:16, :overlapped_latents_size + ref_images_count].clone().unsqueeze(0) if prefix_frames_count > 0: color_reference_frame = input_frames[0][:, prefix_frames_count -1:prefix_frames_count].clone() + lat_h, lat_w = height // self.vae_stride[1], width // self.vae_stride[2] + target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, lat_h, lat_w) - target_shape = list(z0[0].shape) - target_shape[0] = int(target_shape[0] / 2) - lat_h, lat_w = target_shape[-2:] - height = self.vae_stride[1] * lat_h - width = self.vae_stride[2] * lat_w - - else: - target_shape = (self.vae.model.z_dim, lat_frames + ref_images_count, height // self.vae_stride[1], width // self.vae_stride[2]) - - if multitalk and audio_proj != None: + if multitalk: + if audio_proj is None: + audio_proj = [ torch.zeros( (1, 1, 5, 12, 768 ), dtype=self.dtype, device=self.device), torch.zeros( (1, (frame_num - 1) // 4, 8, 12, 768 ), dtype=self.dtype, device=self.device) ] from .multitalk.multitalk import get_target_masks audio_proj = [audio.to(self.dtype) for audio in audio_proj] human_no = len(audio_proj[0]) @@ -764,15 +724,47 @@ def generate(self, expand_shape = [batch_size] + [-1] * len(target_shape) # Ropes - if target_camera != None: + if extended_input_dim>=2: shape = list(target_shape[1:]) - shape[0] *= 2 + shape[extended_input_dim-2] *= 2 freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) else: freqs = get_rotary_pos_embed(target_shape[1:], enable_RIFLEx= enable_RIFLEx) kwargs["freqs"] = freqs + # Lynx + if lynx : + if original_input_ref_images is None or len(original_input_ref_images) == 0: + lynx = False + else: + from .lynx.resampler import Resampler + from accelerate import init_empty_weights + lynx_lite = model_type in ["lynx_lite", "vace_lynx_lite_14B"] + ip_hidden_states = ip_hidden_states_uncond = None + if True: + with init_empty_weights(): + arc_resampler = Resampler( depth=4, dim=1280, dim_head=64, embedding_dim=512, ff_mult=4, heads=20, num_queries=16, output_dim=2048 if lynx_lite else 5120 ) + offload.load_model_data(arc_resampler, os.path.join("ckpts", "wan2.1_lynx_lite_arc_resampler.safetensors" if lynx_lite else "wan2.1_lynx_full_arc_resampler.safetensors")) + arc_resampler.to(self.device) + arcface_embed = face_arc_embeds[None,None,:].to(device=self.device, dtype=torch.float) + ip_hidden_states = arc_resampler(arcface_embed).to(self.dtype) + ip_hidden_states_uncond = arc_resampler(torch.zeros_like(arcface_embed)).to(self.dtype) + arc_resampler = None + if not lynx_lite: + standin_ref_pos = -1 + image_ref = original_input_ref_images[standin_ref_pos] + from preprocessing.face_preprocessor import FaceProcessor + face_processor = FaceProcessor() + lynx_ref = face_processor.process(image_ref, resize_to = 256 ) + lynx_ref_buffer, lynx_ref_buffer_uncond = self.encode_reference_images([lynx_ref], tile_size=VAE_tile_size, any_guidance= any_guidance_at_all) + lynx_ref = None + gc.collect() + torch.cuda.empty_cache() + vace_lynx = model_type in ["vace_lynx_14B"] + kwargs["lynx_ip_scale"] = control_scale_alt + kwargs["lynx_ref_scale"] = control_scale_alt + #Standin if standin: from preprocessing.face_preprocessor import FaceProcessor @@ -809,18 +801,27 @@ def generate(self, if callback != None: callback(-1, None, True) + + clear_caches() 
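# A minimal sketch (not part of the patch itself) of the lazy-loading pattern the Lynx branch
# above uses for its arc resampler: build the module on the meta device, stream the weights in
# from the checkpoint, then move it to the GPU and run it on the ArcFace embedding. The
# Resampler arguments, checkpoint name and offload.load_model_data call are the ones visible in
# the hunk; the import path, device string and zero-valued embedding are illustrative assumptions.
import os
import torch
from accelerate import init_empty_weights             # creates the module structure without real tensors
from mmgp import offload
from models.wan.lynx.resampler import Resampler        # file added by this patch; adjust to your package root

with init_empty_weights():                             # 1) structure only, no RAM/VRAM cost yet
    arc_resampler = Resampler(depth=4, dim=1280, dim_head=64, embedding_dim=512,
                              ff_mult=4, heads=20, num_queries=16, output_dim=5120)
offload.load_model_data(                               # 2) fill the empty module straight from the safetensors file
    arc_resampler, os.path.join("ckpts", "wan2.1_lynx_full_arc_resampler.safetensors"))
arc_resampler.to("cuda")                               # 3) only now does it occupy GPU memory

arcface_embed = torch.zeros(1, 1, 512, device="cuda")  # stands in for face_arc_embeds[None, None, :]
ip_hidden_states = arc_resampler(arcface_embed)        # (1, 16, 5120) tokens fed to the IP cross-attention path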
offload.shared_state["_chipmunk"] = False chipmunk = offload.shared_state.get("_chipmunk", False) if chipmunk: self.model.setup_chipmunk() + offload.shared_state["_radial"] = offload.shared_state["_attention"]=="radial" + radial = offload.shared_state.get("_radial", False) + if radial: + radial_cache = get_cache("radial") + from shared.radial_attention.attention import fill_radial_cache + fill_radial_cache(radial_cache, len(self.model.blocks), *target_shape[1:]) + # init denoising updated_num_steps= len(timesteps) denoising_extra = "" from shared.utils.loras_mutipliers import update_loras_slists, get_model_switch_steps - phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(timesteps, updated_num_steps, guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold ) + phase_switch_step, phase_switch_step2, phases_description = get_model_switch_steps(original_timesteps,guide_phases, 0 if self.model2 is None else model_switch_phase, switch_threshold, switch2_threshold ) if len(phases_description) > 0: set_header_text(phases_description) guidance_switch_done = guidance_switch2_done = False if guide_phases > 1: denoising_extra = f"Phase 1/{guide_phases} High Noise" if self.model2 is not None else f"Phase 1/{guide_phases}" @@ -831,8 +832,8 @@ def update_guidance(step_no, t, guide_scale, new_guide_scale, guidance_switch_do denoising_extra = f"Phase {phase_no}/{guide_phases} {'Low Noise' if trans == self.model2 else 'High Noise'}" if self.model2 is not None else f"Phase {phase_no}/{guide_phases}" callback(step_no-1, denoising_extra = denoising_extra) return guide_scale, guidance_switch_done, trans, denoising_extra - update_loras_slists(self.model, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) - if self.model2 is not None: update_loras_slists(self.model2, loras_slists, updated_num_steps, phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) + update_loras_slists(self.model, loras_slists, len(original_timesteps), phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) + if self.model2 is not None: update_loras_slists(self.model2, loras_slists, len(original_timesteps), phase_switch_step= phase_switch_step, phase_switch_step2= phase_switch_step2) callback(-1, None, True, override_num_inference_steps = updated_num_steps, denoising_extra = denoising_extra) def clear(): @@ -845,19 +846,22 @@ def clear(): scheduler_kwargs = {} if isinstance(sample_scheduler, FlowMatchScheduler) else {"generator": seed_g} # b, c, lat_f, lat_h, lat_w latents = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) + if "G" in video_prompt_type: randn = latents if apg_switch != 0: apg_momentum = -0.75 apg_norm_threshold = 55 text_momentumbuffer = MomentumBuffer(apg_momentum) audio_momentumbuffer = MomentumBuffer(apg_momentum) - + input_frames = input_frames2 = input_masks =input_masks2 = input_video = input_ref_images = input_ref_masks = pre_video_frame = None + gc.collect() + torch.cuda.empty_cache() # denoising trans = self.model for i, t in enumerate(tqdm(timesteps)): guide_scale, guidance_switch_done, trans, denoising_extra = update_guidance(i, t, guide_scale, guide2_scale, guidance_switch_done, switch_threshold, trans, 2, denoising_extra) guide_scale, guidance_switch2_done, trans, denoising_extra = update_guidance(i, t, guide_scale, guide3_scale, guidance_switch2_done, switch2_threshold, trans, 3, 
denoising_extra) - offload.set_step_no_for_lora(trans, i) + offload.set_step_no_for_lora(trans, start_step_no + i) timestep = torch.stack([t]) if timestep_injection: @@ -865,44 +869,43 @@ def clear(): timestep = torch.full((target_shape[-3],), t, dtype=torch.int64, device=latents.device) timestep[:source_latents.shape[2]] = 0 - kwargs.update({"t": timestep, "current_step": start_step_no + i}) + kwargs.update({"t": timestep, "current_step_no": i, "real_step_no": start_step_no + i }) kwargs["slg_layers"] = slg_layers if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps) else None - if denoising_strength < 1 and input_frames != None and i <= injection_denoising_step: + if denoising_strength < 1 and i <= injection_denoising_step: sigma = t / 1000 - noise = torch.randn(batch_size, *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) if inject_from_start: - new_latents = latents.clone() - new_latents[:,:, :source_latents.shape[2] ] = noise[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents + noisy_image = latents.clone() + noisy_image[:,:, :source_latents.shape[2] ] = randn[:, :, :source_latents.shape[2] ] * sigma + (1 - sigma) * source_latents for latent_no, keep_latent in enumerate(latent_keep_frames): if not keep_latent: - new_latents[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1] - latents = new_latents - new_latents = None + noisy_image[:, :, latent_no:latent_no+1 ] = latents[:, :, latent_no:latent_no+1] + latents = noisy_image + noisy_image = None else: - latents = noise * sigma + (1 - sigma) * source_latents - noise = None + latents = randn * sigma + (1 - sigma) * source_latents if extended_overlapped_latents != None: if no_noise_latents_injection: latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents else: latent_noise_factor = t / 1000 + latents[:, :, :extended_overlapped_latents.shape[2]] = extended_overlapped_latents * (1.0 - latent_noise_factor) + torch.randn_like(extended_overlapped_latents ) * latent_noise_factor if vace: overlap_noise_factor = overlap_noise / 1000 for zz in z: zz[0:16, ref_images_count:extended_overlapped_latents.shape[2] ] = extended_overlapped_latents[0, :, ref_images_count:] * (1.0 - overlap_noise_factor) + torch.randn_like(extended_overlapped_latents[0, :, ref_images_count:] ) * overlap_noise_factor - if target_camera != None: - latent_model_input = torch.cat([latents, source_latents.expand(*expand_shape)], dim=2) + if extended_input_dim > 0: + latent_model_input = torch.cat([latents, extended_latents.expand(*expand_shape)], dim=extended_input_dim) else: latent_model_input = latents any_guidance = guide_scale != 1 if phantom: gen_args = { - "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + - [ torch.cat([latent_model_input[:,:, :-ref_images_count], input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]), + "x" : ([ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images.unsqueeze(0).expand(*expand_shape)], dim=2) ] * 2 + + [ torch.cat([latent_model_input[:,:, :-ref_images_count], lat_input_ref_images_neg.unsqueeze(0).expand(*expand_shape)], dim=2)]), "context": [context, context_null, context_null] , } elif fantasy: @@ -911,6 +914,21 @@ def clear(): "context" : [context, context_null, context_null], "audio_scale": [audio_scale, None, None ] } + elif animate: + gen_args = { + "x" : [latent_model_input, latent_model_input], + "context" : 
[context, context_null], + "face_pixel_values": [face_pixel_values, None] + } + elif lynx: + gen_args = { + "x" : [latent_model_input, latent_model_input], + "context" : [context, context_null], + "lynx_ip_embeds": [ip_hidden_states, ip_hidden_states_uncond] + } + if model_type in ["lynx", "vace_lynx_14B"]: + gen_args["lynx_ref_buffer"] = [lynx_ref_buffer, lynx_ref_buffer_uncond] + elif multitalk and audio_proj != None: if guide_scale == 1: gen_args = { @@ -1011,8 +1029,8 @@ def clear(): if sample_solver == "euler": dt = timesteps[i] if i == len(timesteps)-1 else (timesteps[i] - timesteps[i + 1]) - dt = dt / self.num_timesteps - latents = latents - noise_pred * dt[:, None, None, None, None] + dt = dt.item() / self.num_timesteps + latents = latents - noise_pred * dt else: latents = sample_scheduler.step( noise_pred[:, :, :target_shape[1]], @@ -1020,9 +1038,16 @@ def clear(): latents, **scheduler_kwargs)[0] + + if image_mask_latents is not None: + sigma = 0 if i == len(timesteps)-1 else timesteps[i+1]/1000 + noisy_image = randn * sigma + (1 - sigma) * source_latents + latents = noisy_image * (1-image_mask_latents) + image_mask_latents * latents + + if callback is not None: latents_preview = latents - if vace and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] + if ref_images_before and ref_images_count > 0: latents_preview = latents_preview[:, :, ref_images_count: ] if trim_frames > 0: latents_preview= latents_preview[:, :,:-trim_frames] if image_outputs: latents_preview= latents_preview[:, :,:1] if len(latents_preview) > 1: latents_preview = latents_preview.transpose(0,2) @@ -1033,7 +1058,7 @@ def clear(): if timestep_injection: latents[:, :, :source_latents.shape[2]] = source_latents - if vace and ref_images_count > 0: latents = latents[:, :, ref_images_count:] + if ref_images_before and ref_images_count > 0: latents = latents[:, :, ref_images_count:] if trim_frames > 0: latents= latents[:, :,:-trim_frames] if return_latent_slice != None: latent_slice = latents[:, :, return_latent_slice].clone() @@ -1070,4 +1095,20 @@ def adapt_vace_model(self, model): delattr(model, "vace_blocks") + def adapt_animate_model(self, model): + modules_dict= { k: m for k, m in model.named_modules()} + for animate_layer in range(8): + module = modules_dict[f"face_adapter.fuser_blocks.{animate_layer}"] + model_layer = animate_layer * 5 + target = modules_dict[f"blocks.{model_layer}"] + setattr(target, "face_adapter_fuser_blocks", module ) + delattr(model, "face_adapter") + + def get_loras_transformer(self, get_model_recursive_prop, base_model_type, model_type, video_prompt_type, model_mode, **kwargs): + if base_model_type == "animate": + if "#" in video_prompt_type and "1" in video_prompt_type: + preloadURLs = get_model_recursive_prop(model_type, "preload_URLs") + if len(preloadURLs) > 0: + return [os.path.join("ckpts", os.path.basename(preloadURLs[0]))] , [1] + return [], [] diff --git a/models/wan/df_handler.py b/models/wan/df_handler.py index bc79e2ebc..31d5da6c8 100644 --- a/models/wan/df_handler.py +++ b/models/wan/df_handler.py @@ -21,11 +21,24 @@ def query_model_def(base_model_type, model_def): extra_model_def["fps"] =fps extra_model_def["frames_minimum"] = 17 extra_model_def["frames_steps"] = 20 + extra_model_def["latent_size"] = 4 extra_model_def["sliding_window"] = True extra_model_def["skip_layer_guidance"] = True extra_model_def["tea_cache"] = True extra_model_def["guidance_max_phases"] = 1 + extra_model_def["model_modes"] = { + "choices": [ + ("Synchronous", 0), + 
("Asynchronous (better quality but around 50% extra steps added)", 5), + ], + "default": 0, + "label" : "Generation Type" + } + + extra_model_def["image_prompt_types_allowed"] = "TSV" + + return extra_model_def @staticmethod @@ -54,7 +67,11 @@ def query_model_family(): def query_family_infos(): return {} - + @staticmethod + def get_rgb_factors(base_model_type ): + from shared.RGB_factors import get_rgb_factors + latent_rgb_factors, latent_rgb_factors_bias = get_rgb_factors("wan", base_model_type) + return latent_rgb_factors, latent_rgb_factors_bias @staticmethod def query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization): @@ -62,7 +79,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder return family_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization) @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None): from .configs import WAN_CONFIGS from .wan_handler import family_handler cfg = WAN_CONFIGS['t2v-14B'] diff --git a/models/wan/diffusion_forcing copy.py b/models/wan/diffusion_forcing copy.py deleted file mode 100644 index 753fd4561..000000000 --- a/models/wan/diffusion_forcing copy.py +++ /dev/null @@ -1,479 +0,0 @@ -import math -import os -from typing import List -from typing import Optional -from typing import Tuple -from typing import Union -import logging -import numpy as np -import torch -from diffusers.image_processor import PipelineImageInput -from diffusers.utils.torch_utils import randn_tensor -from diffusers.video_processor import VideoProcessor -from tqdm import tqdm -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from wan.modules.posemb_layers import get_rotary_pos_embed -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler - -class DTT2V: - - - def __init__( - self, - config, - checkpoint_dir, - rank=0, - model_filename = None, - text_encoder_filename = None, - quantizeTransformer = False, - dtype = torch.bfloat16, - ): - self.device = torch.device(f"cuda") - self.config = config - self.rank = rank - self.dtype = dtype - self.num_train_timesteps = config.num_train_timesteps - self.param_dtype = config.param_dtype - - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - device=torch.device('cpu'), - checkpoint_path=text_encoder_filename, - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn= None) - - self.vae_stride = config.vae_stride - self.patch_size = config.patch_size - - - self.vae = WanVAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) - - logging.info(f"Creating WanModel from {model_filename}") - from mmgp import offload - - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= quantizeTransformer, writable_tensors= False, forcedConfigPath="config.json") - # 
offload.load_model_data(self.model, "recam.ckpt") - # self.model.cpu() - # offload.save_model(self.model, "recam.safetensors") - if self.dtype == torch.float16 and not "fp16" in model_filename: - self.model.to(self.dtype) - # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True) - if self.dtype == torch.float16: - self.vae.model.to(self.dtype) - self.model.eval().requires_grad_(False) - - self.scheduler = FlowUniPCMultistepScheduler() - - @property - def do_classifier_free_guidance(self) -> bool: - return self._guidance_scale > 1 - - def encode_image( - self, image: PipelineImageInput, height: int, width: int, num_frames: int, tile_size = 0, causal_block_size = 0 - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - # prefix_video - prefix_video = np.array(image.resize((width, height))).transpose(2, 0, 1) - prefix_video = torch.tensor(prefix_video).unsqueeze(1) # .to(image_embeds.dtype).unsqueeze(1) - if prefix_video.dtype == torch.uint8: - prefix_video = (prefix_video.float() / (255.0 / 2.0)) - 1.0 - prefix_video = prefix_video.to(self.device) - prefix_video = [self.vae.encode(prefix_video.unsqueeze(0), tile_size = tile_size)[0]] # [(c, f, h, w)] - if prefix_video[0].shape[1] % causal_block_size != 0: - truncate_len = prefix_video[0].shape[1] % causal_block_size - print("the length of prefix video is truncated for the casual block size alignment.") - prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] - predix_video_latent_length = prefix_video[0].shape[1] - return prefix_video, predix_video_latent_length - - def prepare_latents( - self, - shape: Tuple[int], - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - ) -> torch.Tensor: - return randn_tensor(shape, generator, device=device, dtype=dtype) - - def generate_timestep_matrix( - self, - num_frames, - step_template, - base_num_frames, - ar_step=5, - num_pre_ready=0, - casual_block_size=1, - shrink_interval_with_mask=False, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: - step_matrix, step_index = [], [] - update_mask, valid_interval = [], [] - num_iterations = len(step_template) + 1 - num_frames_block = num_frames // casual_block_size - base_num_frames_block = base_num_frames // casual_block_size - if base_num_frames_block < num_frames_block: - infer_step_num = len(step_template) - gen_block = base_num_frames_block - min_ar_step = infer_step_num / gen_block - assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting" - # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) - step_template = torch.cat( - [ - torch.tensor([999], dtype=torch.int64, device=step_template.device), - step_template.long(), - torch.tensor([0], dtype=torch.int64, device=step_template.device), - ] - ) # to handle the counter in row works starting from 1 - pre_row = torch.zeros(num_frames_block, dtype=torch.long) - if num_pre_ready > 0: - pre_row[: num_pre_ready // casual_block_size] = num_iterations - - while torch.all(pre_row >= (num_iterations - 1)) == False: - new_row = torch.zeros(num_frames_block, dtype=torch.long) - for i in range(num_frames_block): - if i == 0 or pre_row[i - 1] >= ( - num_iterations - 1 - ): # the first frame or the last frame is completely denoised - new_row[i] = pre_row[i] + 1 - else: - new_row[i] = new_row[i - 1] - ar_step - new_row = 
new_row.clamp(0, num_iterations) - - update_mask.append( - (new_row != pre_row) & (new_row != num_iterations) - ) # False: no need to update, True: need to update - step_index.append(new_row) - step_matrix.append(step_template[new_row]) - pre_row = new_row - - # for long video we split into several sequences, base_num_frames is set to the model max length (for training) - terminal_flag = base_num_frames_block - if shrink_interval_with_mask: - idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) - update_mask = update_mask[0] - update_mask_idx = idx_sequence[update_mask] - last_update_idx = update_mask_idx[-1].item() - terminal_flag = last_update_idx + 1 - # for i in range(0, len(update_mask)): - for curr_mask in update_mask: - if terminal_flag < num_frames_block and curr_mask[terminal_flag]: - terminal_flag += 1 - valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) - - step_update_mask = torch.stack(update_mask, dim=0) - step_index = torch.stack(step_index, dim=0) - step_matrix = torch.stack(step_matrix, dim=0) - - if casual_block_size > 1: - step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] - - return step_matrix, step_index, step_update_mask, valid_interval - - @torch.no_grad() - def generate( - self, - prompt: Union[str, List[str]], - negative_prompt: Union[str, List[str]] = "", - image: PipelineImageInput = None, - height: int = 480, - width: int = 832, - num_frames: int = 97, - num_inference_steps: int = 50, - shift: float = 1.0, - guidance_scale: float = 5.0, - seed: float = 0.0, - overlap_history: int = 17, - addnoise_condition: int = 0, - base_num_frames: int = 97, - ar_step: int = 5, - causal_block_size: int = 1, - causal_attention: bool = False, - fps: int = 24, - VAE_tile_size = 0, - joint_pass = False, - callback = None, - ): - generator = torch.Generator(device=self.device) - generator.manual_seed(seed) - # if base_num_frames > base_num_frames: - # causal_block_size = 0 - self._guidance_scale = guidance_scale - - i2v_extra_kwrags = {} - prefix_video = None - predix_video_latent_length = 0 - if image: - frame_width, frame_height = image.size - scale = min(height / frame_height, width / frame_width) - height = (int(frame_height * scale) // 16) * 16 - width = (int(frame_width * scale) // 16) * 16 - - prefix_video, predix_video_latent_length = self.encode_image(image, height, width, num_frames, tile_size=VAE_tile_size, causal_block_size=causal_block_size) - - latent_length = (num_frames - 1) // 4 + 1 - latent_height = height // 8 - latent_width = width // 8 - - prompt_embeds = self.text_encoder([prompt], self.device) - prompt_embeds = [u.to(self.dtype).to(self.device) for u in prompt_embeds] - if self.do_classifier_free_guidance: - negative_prompt_embeds = self.text_encoder([negative_prompt], self.device) - negative_prompt_embeds = [u.to(self.dtype).to(self.device) for u in negative_prompt_embeds] - - - - self.scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) - init_timesteps = self.scheduler.timesteps - fps_embeds = [fps] * prompt_embeds[0].shape[0] - fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] - transformer_dtype = self.dtype - # with 
torch.cuda.amp.autocast(dtype=self.dtype), torch.no_grad(): - if overlap_history is None or base_num_frames is None or num_frames <= base_num_frames: - # short video generation - latent_shape = [16, latent_length, latent_height, latent_width] - latents = self.prepare_latents( - latent_shape, dtype=torch.float32, device=self.device, generator=generator - ) - latents = [latents] - if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - latent_length, init_timesteps, base_num_frames, ar_step, predix_video_latent_length, causal_block_size - ) - sample_schedulers = [] - for _ in range(latent_length): - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) - sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * latent_length - - if callback != None: - callback(-1, None, True) - - freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False) - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] * (1.0 - noise_factor) - + torch.randn_like(latent_model_input[0][:, valid_interval_start:predix_video_latent_length]) - * noise_factor - ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - kwrags = { - "x" : torch.stack([latent_model_input[0]]), - "t" : timestep, - "freqs" :freqs, - "fps" : fps_embeds, - # "causal_block_size" : causal_block_size, - "callback" : callback, - "pipeline" : self - } - kwrags.update(i2v_extra_kwrags) - - - if not self.do_classifier_free_guidance: - noise_pred = self.model( - context=prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred= noise_pred.to(torch.float32) - else: - if joint_pass: - noise_pred_cond, noise_pred_uncond = self.model( - context=prompt_embeds, - context2=negative_prompt_embeds, - **kwrags, - ) - if self._interrupt: - return None - else: - noise_pred_cond = self.model( - context=prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred_uncond = self.model( - context=negative_prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred_cond= noise_pred_cond.to(torch.float32) - noise_pred_uncond= noise_pred_uncond.to(torch.float32) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) - del noise_pred_cond, noise_pred_uncond - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], 
- return_dict=False, - generator=generator, - )[0] - sample_schedulers_counter[idx] += 1 - if callback is not None: - callback(i, latents[0], False) - - x0 = latents[0].unsqueeze(0) - videos = self.vae.decode(x0, tile_size= VAE_tile_size) - videos = (videos / 2 + 0.5).clamp(0, 1) - videos = [video for video in videos] - videos = [video.permute(1, 2, 3, 0) * 255 for video in videos] - videos = [video.cpu().numpy().astype(np.uint8) for video in videos] - return videos - else: - # long video generation - base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length - overlap_history_frames = (overlap_history - 1) // 4 + 1 - n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1 - print(f"n_iter:{n_iter}") - output_video = None - for i in range(n_iter): - if output_video is not None: # i !=0 - prefix_video = output_video[:, -overlap_history:].to(self.device) - prefix_video = [self.vae.encode(prefix_video.unsqueeze(0))[0]] # [(c, f, h, w)] - if prefix_video[0].shape[1] % causal_block_size != 0: - truncate_len = prefix_video[0].shape[1] % causal_block_size - print("the length of prefix video is truncated for the casual block size alignment.") - prefix_video[0] = prefix_video[0][:, : prefix_video[0].shape[1] - truncate_len] - predix_video_latent_length = prefix_video[0].shape[1] - finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames - left_frame_num = latent_length - finished_frame_num - base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames) - else: # i == 0 - base_num_frames_iter = base_num_frames - latent_shape = [16, base_num_frames_iter, latent_height, latent_width] - latents = self.prepare_latents( - latent_shape, dtype=torch.float32, device=self.device, generator=generator - ) - latents = [latents] - if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - base_num_frames_iter, - init_timesteps, - base_num_frames_iter, - ar_step, - predix_video_latent_length, - causal_block_size, - ) - sample_schedulers = [] - for _ in range(base_num_frames_iter): - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) - sample_scheduler.set_timesteps(num_inference_steps, device=self.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * base_num_frames_iter - if callback != None: - callback(-1, None, True) - - freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= False) - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - * (1.0 - noise_factor) - + torch.randn_like( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - ) - * noise_factor - 
) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - kwrags = { - "x" : torch.stack([latent_model_input[0]]), - "t" : timestep, - "freqs" :freqs, - "fps" : fps_embeds, - "causal_block_size" : causal_block_size, - "causal_attention" : causal_attention, - "callback" : callback, - "pipeline" : self - } - kwrags.update(i2v_extra_kwrags) - - if not self.do_classifier_free_guidance: - noise_pred = self.model( - context=prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred= noise_pred.to(torch.float32) - else: - if joint_pass: - noise_pred_cond, noise_pred_uncond = self.model( - context=prompt_embeds, - context2=negative_prompt_embeds, - **kwrags, - ) - if self._interrupt: - return None - else: - noise_pred_cond = self.model( - context=prompt_embeds, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred_uncond = self.model( - context=negative_prompt_embeds, - )[0] - if self._interrupt: - return None - noise_pred_cond= noise_pred_cond.to(torch.float32) - noise_pred_uncond= noise_pred_uncond.to(torch.float32) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) - del noise_pred_cond, noise_pred_uncond - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], - return_dict=False, - generator=generator, - )[0] - sample_schedulers_counter[idx] += 1 - if callback is not None: - callback(i, latents[0].squeeze(0), False) - - x0 = latents[0].unsqueeze(0) - videos = [self.vae.decode(x0, tile_size= VAE_tile_size)[0]] - if output_video is None: - output_video = videos[0].clamp(-1, 1).cpu() # c, f, h, w - else: - output_video = torch.cat( - [output_video, videos[0][:, overlap_history:].clamp(-1, 1).cpu()], 1 - ) # c, f, h, w - return output_video diff --git a/models/wan/lynx/__init__.py b/models/wan/lynx/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/models/wan/lynx/attention_processor.py b/models/wan/lynx/attention_processor.py new file mode 100644 index 000000000..c64291d1d --- /dev/null +++ b/models/wan/lynx/attention_processor.py @@ -0,0 +1,41 @@ +# Copyright (c) 2025 The Wan Team and The HuggingFace Team. +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# +# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025. +# +# Original file was released under Apache License 2.0, with the full license text +# available at https://github.com/huggingface/diffusers/blob/v0.30.3/LICENSE and https://github.com/Wan-Video/Wan2.1/blob/main/LICENSE.txt. +# +# This modified file is released under the same license. 
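# For reference, a standalone sketch of the frame stitching that the long-video branch of the
# DTT2V pipeline removed above ("diffusion_forcing copy.py") performs inline: each new chunk is
# conditioned on the last `overlap_history` frames already generated, so those frames are dropped
# again when the chunk is appended. Tensor layout is (c, f, h, w) as in the removed code; the
# chunk list and sizes below are hypothetical inputs, not values taken from the patch.
import torch

def stitch_overlapping_chunks(chunks, overlap_history):
    # keep the first chunk whole; every later chunk repeats `overlap_history` frames of the
    # previous output at its start, so only chunk[:, overlap_history:] is new material
    output_video = chunks[0]
    for chunk in chunks[1:]:
        output_video = torch.cat([output_video, chunk[:, overlap_history:]], dim=1)
    return output_video

# e.g. three 97-frame chunks with a 17-frame overlap yield 97 + 2 * (97 - 17) = 257 frames
assert stitch_overlapping_chunks([torch.zeros(3, 97, 8, 8)] * 3, overlap_history=17).shape[1] == 257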
+ + +from typing import Optional, List +import torch +import torch.nn as nn +from diffusers.models.normalization import RMSNorm + +def setup_lynx_attention_layers(blocks, lynx_full, dim): + if lynx_full: + lynx_cross_dim = 5120 + lynx_layers = len(blocks) + else: + lynx_cross_dim = 2048 + lynx_layers = 20 + for i, block in enumerate(blocks): + if i < lynx_layers: + block.cross_attn.to_k_ip = nn.Linear(lynx_cross_dim, dim , bias=lynx_full) + block.cross_attn.to_v_ip = nn.Linear(lynx_cross_dim, dim , bias=lynx_full) + else: + block.cross_attn.to_k_ip = None + block.cross_attn.to_v_ip = None + if lynx_full: + block.cross_attn.registers = nn.Parameter(torch.randn(1, 16, lynx_cross_dim) / dim**0.5) + block.cross_attn.norm_rms_k = None + block.self_attn.to_k_ref = nn.Linear(dim, dim, bias=True) + block.self_attn.to_v_ref = nn.Linear(dim, dim, bias=True) + else: + block.cross_attn.registers = None + block.cross_attn.norm_rms_k = RMSNorm(dim, eps=1e-5, elementwise_affine=False) + + + diff --git a/models/wan/lynx/navit_utils.py b/models/wan/lynx/navit_utils.py new file mode 100644 index 000000000..593e7f0d1 --- /dev/null +++ b/models/wan/lynx/navit_utils.py @@ -0,0 +1,16 @@ +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 + +import torch + +def vector_to_list(tensor, lens, dim): + return list(torch.split(tensor, lens, dim=dim)) + +def list_to_vector(tensor_list, dim): + lens = [tensor.shape[dim] for tensor in tensor_list] + tensor = torch.cat(tensor_list, dim) + return tensor, lens + +def merge_token_lists(list1, list2, dim): + assert(len(list1) == len(list2)) + return [torch.cat((t1, t2), dim) for t1, t2 in zip(list1, list2)] \ No newline at end of file diff --git a/models/wan/lynx/resampler.py b/models/wan/lynx/resampler.py new file mode 100644 index 000000000..e225a062b --- /dev/null +++ b/models/wan/lynx/resampler.py @@ -0,0 +1,185 @@ + +# Copyright (c) 2022 Phil Wang +# Copyright (c) 2023 Anas Awadalla, Irena Gao, Joshua Gardner, Jack Hessel, Yusuf Hanafy, Wanrong Zhu, Kalyani Marathe, Yonatan Bitton, Samir Gadre, Jenia Jitsev, Simon Kornblith, Pang Wei Koh, Gabriel Ilharco, Mitchell Wortsman, Ludwig Schmidt. +# Copyright (c) 2023 Tencent AI Lab +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates + +# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025. +# SPDX-License-Identifier: Apache-2.0 + +# Original file (IP-Adapter) was released under Apache License 2.0, with the full license text +# available at https://github.com/tencent-ailab/IP-Adapter/blob/main/LICENSE. + +# Original file (open_flamingo and flamingo-pytorch) was released under MIT License: + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +import math + +import torch +import torch.nn as nn + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + #(bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1024, + depth=8, + dim_head=64, + heads=16, + num_queries=8, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + max_seq_len: int = 257, # CLIP tokens + CLS token + apply_pos_emb: bool = False, + num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence + **kwargs, + ): + super().__init__() + self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.to_latents_from_mean_pooled_seq = ( + nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * num_latents_mean_pooled), + Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), + ) + if num_latents_mean_pooled > 0 + else None + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + self.pretrained = kwargs.pop("pretrained", None) + self.freeze = kwargs.pop("freeze", False) + self.skip_last_layer = kwargs.pop("skip_last_layer", False) + + if 
self.freeze and self.pretrained is None: + raise ValueError("Trying to freeze the model but no pretrained provided!") + + + def forward(self, x): + if self.pos_emb is not None: + n, device = x.shape[1], x.device + pos_emb = self.pos_emb(torch.arange(n, device=device)) + x = x + pos_emb + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + if self.to_latents_from_mean_pooled_seq: + meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) + meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) + latents = torch.cat((meanpooled_latents, latents), dim=-2) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) diff --git a/models/wan/modules/model.py b/models/wan/modules/model.py index b9c98bf81..30f6172f0 100644 --- a/models/wan/modules/model.py +++ b/models/wan/modules/model.py @@ -12,29 +12,17 @@ import numpy as np from typing import Union,Optional from mmgp import offload +from mmgp.offload import get_cache, clear_caches from shared.attention import pay_attention from torch.backends.cuda import sdp_kernel from ..multitalk.multitalk_utils import get_attn_map_with_target +from ..animate.motion_encoder import Generator +from ..animate.face_blocks import FaceAdapter, FaceEncoder +from ..animate.model_animate import after_patch_embedding __all__ = ['WanModel'] -def get_cache(cache_name): - all_cache = offload.shared_state.get("_cache", None) - if all_cache is None: - all_cache = {} - offload.shared_state["_cache"]= all_cache - cache = offload.shared_state.get(cache_name, None) - if cache is None: - cache = {} - offload.shared_state[cache_name] = cache - return cache - -def clear_caches(): - all_cache = offload.shared_state.get("_cache", None) - if all_cache is not None: - all_cache.clear() - def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 @@ -143,7 +131,7 @@ def __init__(self, dim, eps=1e-5): self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) - def forward(self, x): + def forward(self, x, in_place= True): r""" Args: x(Tensor): Shape [B, L, C] @@ -153,7 +141,10 @@ def forward(self, x): y = y.mean(dim=-1, keepdim=True) y += self.eps y.rsqrt_() - x *= y + if in_place: + x *= y + else: + x = x * y x *= self.weight return x # return self._norm(x).type_as(x) * self.weight @@ -286,7 +277,7 @@ def text_cross_attention(self, xlist, context, return_q = False): else: return x, None - def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0, standin_phase =-1): + def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks = None, ref_images_count = 0, standin_phase =-1, lynx_ref_buffer = None, lynx_ref_scale = 0, sub_x_no=0): r""" Args: x(Tensor): Shape [B, L, num_heads, C / num_heads] @@ -299,7 +290,21 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim # query, key, value function - q, k, v = self.q(x), self.k(x), self.v(x) + q = self.q(x) + ref_hidden_states = None + if not lynx_ref_buffer is None: + lynx_ref_features = lynx_ref_buffer[self.block_no] + if self.norm_q is not None: ref_query = self.norm_q(q, in_place = False) + ref_key = self.to_k_ref(lynx_ref_features) + ref_value = self.to_v_ref(lynx_ref_features) + if self.norm_k is not None: ref_key = self.norm_k(ref_key) + ref_query, ref_key, 
ref_value = ref_query.unflatten(2, (self.num_heads, -1)), ref_key.unflatten(2, (self.num_heads, -1)), ref_value.unflatten(2, (self.num_heads, -1)) + qkv_list = [ref_query, ref_key, ref_value ] + del ref_query, ref_key, ref_value + ref_hidden_states = pay_attention(qkv_list) + + k, v = self.k(x), self.v(x) + if standin_phase == 1: q += self.q_loras(x) k += self.k_loras(x) @@ -327,6 +332,8 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks x_ref_attn_map = None chipmunk = offload.shared_state.get("_chipmunk", False) + radial = offload.shared_state.get("_radial", False) + if chipmunk and self.__class__ == WanSelfAttention: q = q.transpose(1,2) k = k.transpose(1,2) @@ -334,12 +341,17 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks attn_layers = offload.shared_state["_chipmunk_layers"] x = attn_layers[self.block_no](q, k, v) x = x.transpose(1,2) + elif radial and self.__class__ == WanSelfAttention: + qkv_list = [q,k,v] + del q,k,v + radial_cache = get_cache("radial") + no_step_no = offload.shared_state["step_no"] + x = radial_cache[self.block_no](qkv_list=qkv_list, timestep_no=no_step_no) elif block_mask == None: qkv_list = [q,k,v] del q,k,v - x = pay_attention( - qkv_list, - window_size=self.window_size) + + x = pay_attention( qkv_list, window_size=self.window_size) else: with sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): @@ -352,6 +364,8 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks ) del q,k,v + if ref_hidden_states is not None: + x.add_(ref_hidden_states, alpha= lynx_ref_scale) x = x.flatten(2) x = self.o(x) @@ -360,13 +374,38 @@ def forward(self, xlist, grid_sizes, freqs, block_mask = None, ref_target_masks class WanT2VCrossAttention(WanSelfAttention): - def forward(self, xlist, context, grid_sizes, *args, **kwargs): + def forward(self, xlist, context, grid_sizes, lynx_ip_embeds = None, lynx_ip_scale = 0, *args, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] """ - x, _ = self.text_cross_attention( xlist, context) + x, q = self.text_cross_attention( xlist, context, return_q=lynx_ip_embeds is not None) + if lynx_ip_embeds is not None and self.to_k_ip is not None: + if self.registers is not None: + from ..lynx.navit_utils import vector_to_list, merge_token_lists, list_to_vector + ip_hidden_states_list = vector_to_list(lynx_ip_embeds, lynx_ip_embeds.shape[1], 1) + ip_hidden_states_list = merge_token_lists(ip_hidden_states_list, [self.registers] * len(ip_hidden_states_list), 1) + lynx_ip_embeds, ip_lens = list_to_vector(ip_hidden_states_list, 1) + ip_hidden_states_list = None + ip_key = self.to_k_ip(lynx_ip_embeds) + ip_value = self.to_v_ip(lynx_ip_embeds) + # if self.norm_q is not None: ip_query = self.norm_q(q) + if self.norm_rms_k is None: + ip_key = self.norm_k(ip_key) + else: + ip_key = self.norm_rms_k(ip_key) + ip_inner_dim = ip_key.shape[-1] + ip_head_dim = ip_inner_dim // self.num_heads + batch_size = q.shape[0] + # ip_query = ip_query.view(batch_size, -1, attn.heads, ip_head_dim) + ip_key = ip_key.view(batch_size, -1, self.num_heads, ip_head_dim) + ip_value = ip_value.view(batch_size, -1, self.num_heads, ip_head_dim) + qkv_list = [q, ip_key, ip_value] + del q, ip_key, ip_value + ip_hidden_states = pay_attention(qkv_list).reshape(*x.shape) + x.add_(ip_hidden_states, alpha= lynx_ip_scale) + x = self.o(x) return x @@ -387,7 +426,7 @@ def __init__(self, # self.alpha = nn.Parameter(torch.zeros((1, ))) self.norm_k_img = 
WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() - def forward(self, xlist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens ): + def forward(self, xlist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens, *args, **kwargs ): r""" Args: x(Tensor): Shape [B, L1, C] @@ -514,6 +553,13 @@ def forward( multitalk_masks=None, ref_images_count=0, standin_phase=-1, + motion_vec = None, + lynx_ip_embeds = None, + lynx_ip_scale = 0, + lynx_ref_scale = 0, + lynx_feature_extractor = False, + lynx_ref_buffer = None, + sub_x_no =0, ): r""" Args: @@ -548,6 +594,7 @@ def forward( x_mod *= 1 + e[1] x_mod += e[0] x_mod = restore_latent_shape(x_mod) + if cam_emb != None: cam_emb = self.cam_encoder(cam_emb) cam_emb = cam_emb.repeat(1, 2, 1) @@ -556,8 +603,9 @@ def forward( x_mod += cam_emb xlist = [x_mod.to(attention_dtype)] + if lynx_feature_extractor: get_cache("lynx_ref_buffer")[sub_x_no][self.block_no] = xlist[0] del x_mod - y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count, standin_phase= standin_phase, ) + y, x_ref_attn_map = self.self_attn( xlist, grid_sizes, freqs, block_mask = block_mask, ref_target_masks = multitalk_masks, ref_images_count = ref_images_count, standin_phase= standin_phase, lynx_ref_buffer = lynx_ref_buffer, lynx_ref_scale = lynx_ref_scale, sub_x_no = sub_x_no) y = y.to(dtype) if cam_emb != None: y = self.projector(y) @@ -572,26 +620,30 @@ def forward( y = y.to(attention_dtype) ylist= [y] del y - x += self.cross_attn(ylist, context, grid_sizes, audio_proj, audio_scale, audio_context_lens).to(dtype) + x += self.cross_attn(ylist, context, grid_sizes, audio_proj = audio_proj, audio_scale = audio_scale, audio_context_lens = audio_context_lens, lynx_ip_embeds=lynx_ip_embeds, lynx_ip_scale=lynx_ip_scale).to(dtype) if multitalk_audio != None: # cross attn of multitalk audio y = self.norm_x(x) y = y.to(attention_dtype) if ref_images_count == 0: - x += self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map) + ylist= [y] + del y + x += self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes, x_ref_attn_map=x_ref_attn_map) else: y_shape = y.shape y = y.reshape(y_shape[0], grid_sizes[0], -1) y = y[:, ref_images_count:] y = y.reshape(y_shape[0], -1, y_shape[-1]) grid_sizes_alt = [grid_sizes[0]-ref_images_count, *grid_sizes[1:]] - y = self.audio_cross_attn(y, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map) + ylist= [y] + y = None + y = self.audio_cross_attn(ylist, encoder_hidden_states=multitalk_audio, shape=grid_sizes_alt, x_ref_attn_map=x_ref_attn_map) y = y.reshape(y_shape[0], grid_sizes[0]-ref_images_count, -1) x = x.reshape(y_shape[0], grid_sizes[0], -1) x[:, ref_images_count:] += y x = x.reshape(y_shape[0], -1, y_shape[-1]) - del y + del y y = self.norm2(x) @@ -627,6 +679,10 @@ def forward( x.add_(hint) else: x.add_(hint, alpha= scale) + + if motion_vec is not None and self.block_no % 5 == 0: + x += self.face_adapter_fuser_blocks(x.to(self.face_adapter_fuser_blocks.linear1_kv.weight.dtype), motion_vec, None, False) + return x class AudioProjModel(ModelMixin, ConfigMixin): @@ -909,6 +965,8 @@ def __init__(self, norm_input_visual=True, norm_output_audio=True, standin= False, + motion_encoder_dim=0, + lynx=None, ): super().__init__() @@ -933,14 +991,15 @@ def __init__(self, self.flag_causal_attention = False self.block_mask = 
None self.inject_sample_info = inject_sample_info - + self.motion_encoder_dim = motion_encoder_dim self.norm_output_audio = norm_output_audio self.audio_window = audio_window self.intermediate_dim = intermediate_dim self.vae_scale = vae_scale multitalk = multitalk_output_dim > 0 - self.multitalk = multitalk + self.multitalk = multitalk + animate = motion_encoder_dim > 0 # embeddings self.patch_embedding = nn.Conv3d( @@ -1038,6 +1097,30 @@ def __init__(self, block.self_attn.k_loras = LoRALinearLayer(dim, dim, rank=128) block.self_attn.v_loras = LoRALinearLayer(dim, dim, rank=128) + if lynx is not None: + from ..lynx.attention_processor import setup_lynx_attention_layers + lynx_full = lynx=="full" + setup_lynx_attention_layers(self.blocks, lynx_full, dim) + + if animate: + self.pose_patch_embedding = nn.Conv3d( + 16, dim, kernel_size=patch_size, stride=patch_size + ) + + self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20) + self.face_adapter = FaceAdapter( + heads_num=self.num_heads, + hidden_dim=self.dim, + num_adapter_layers=self.num_layers // 5, + ) + + self.face_encoder = FaceEncoder( + in_dim=motion_encoder_dim, + hidden_dim=self.dim, + num_heads=4, + ) + + def lock_layers_dtypes(self, hybrid_dtype = None, dtype = torch.float32): layer_list = [self.head, self.head.head, self.patch_embedding] target_dype= dtype @@ -1202,7 +1285,8 @@ def forward( y=None, freqs = None, pipeline = None, - current_step = 0, + current_step_no = 0, + real_step_no = 0, x_id= 0, max_steps = 0, slg_layers=None, @@ -1219,6 +1303,14 @@ def forward( ref_images_count = 0, standin_freqs = None, standin_ref = None, + pose_latents=None, + face_pixel_values=None, + lynx_ip_embeds = None, + lynx_ip_scale = 0, + lynx_ref_scale = 0, + lynx_feature_extractor = False, + lynx_ref_buffer = None, + ): # patch_dtype = self.patch_embedding.weight.dtype modulation_dtype = self.time_projection[1].weight.dtype @@ -1251,9 +1343,20 @@ def forward( if bz > 1: y = y.expand(bz, -1, -1, -1, -1) x = torch.cat([x, y], dim=1) # embeddings - # x = self.patch_embedding(x.unsqueeze(0)).to(modulation_dtype) x = self.patch_embedding(x).to(modulation_dtype) grid_sizes = x.shape[2:] + x_list[i] = x + y = None + + motion_vec_list = [] + if face_pixel_values is None: face_pixel_values = [None] * len(x_list) + for i, (x, one_face_pixel_values) in enumerate(zip(x_list, face_pixel_values)): + # animate embeddings + motion_vec = None + if pose_latents is not None: + # x, motion_vec = after_patch_embedding(self, x, pose_latents, torch.zeros_like(face_pixel_values[0]) if one_face_pixel_values is None else one_face_pixel_values) + x, motion_vec = after_patch_embedding(self, x, pose_latents, face_pixel_values[0]) # if one_face_pixel_values is None else one_face_pixel_values) + motion_vec_list.append(motion_vec) if chipmunk: x = x.unsqueeze(-1) x_og_shape = x.shape @@ -1261,7 +1364,7 @@ def forward( else: x = x.flatten(2).transpose(1, 2) x_list[i] = x - x, y = None, None + x = None block_mask = None @@ -1280,9 +1383,8 @@ def forward( del causal_mask offload.shared_state["embed_sizes"] = grid_sizes - offload.shared_state["step_no"] = current_step + offload.shared_state["step_no"] = current_step_no offload.shared_state["max_steps"] = max_steps - if current_step == 0 and x_id == 0: clear_caches() # arguments kwargs = dict( @@ -1293,6 +1395,9 @@ def forward( audio_proj=audio_proj, audio_context_lens=audio_context_lens, ref_images_count=ref_images_count, + lynx_ip_scale= lynx_ip_scale, + lynx_ref_scale = lynx_ref_scale, + lynx_feature_extractor 
= lynx_feature_extractor, ) _flag_df = t.dim() == 2 @@ -1306,11 +1411,22 @@ def forward( if standin_ref is not None: standin_cache_enabled = False kwargs["standin_phase"] = 2 - if current_step == 0 or not standin_cache_enabled : + if current_step_no == 0 or not standin_cache_enabled : standin_x = self.patch_embedding(standin_ref).to(modulation_dtype).flatten(2).transpose(1, 2) standin_e = self.time_embedding( sinusoidal_embedding_1d(self.freq_dim, torch.zeros_like(t)).to(modulation_dtype) ) standin_e0 = self.time_projection(standin_e).unflatten(1, (6, self.dim)).to(e.dtype) standin_e = standin_ref = None + + if lynx_ip_embeds is None: + lynx_ip_embeds_list = [None] * len(x_list) + else: + lynx_ip_embeds_list = lynx_ip_embeds + + if lynx_ref_buffer is None: + lynx_ref_buffer_list = [None] * len(x_list) + else: + lynx_ref_buffer_list = lynx_ref_buffer + if self.inject_sample_info and fps!=None: fps = torch.tensor(fps, dtype=torch.long, device=device) @@ -1371,7 +1487,7 @@ def forward( skips_steps_cache = self.cache if skips_steps_cache != None: if skips_steps_cache.cache_type == "mag": - if current_step <= skips_steps_cache.start_step: + if real_step_no <= skips_steps_cache.start_step: should_calc = True elif skips_steps_cache.one_for_all and x_id != 0: # not joint pass, not main pas, one for all assert len(x_list) == 1 @@ -1380,7 +1496,7 @@ def forward( x_should_calc = [] for i in range(1 if skips_steps_cache.one_for_all else len(x_list)): cur_x_id = i if joint_pass else x_id - cur_mag_ratio = skips_steps_cache.mag_ratios[current_step * 2 + cur_x_id] # conditional and unconditional in one list + cur_mag_ratio = skips_steps_cache.mag_ratios[real_step_no * 2 + cur_x_id] # conditional and unconditional in one list skips_steps_cache.accumulated_ratio[cur_x_id] *= cur_mag_ratio # magnitude ratio between current step and the cached step skips_steps_cache.accumulated_steps[cur_x_id] += 1 # skip steps plus 1 cur_skip_err = np.abs(1-skips_steps_cache.accumulated_ratio[cur_x_id]) # skip error of current steps @@ -1400,7 +1516,7 @@ def forward( if x_id != 0: should_calc = skips_steps_cache.should_calc else: - if current_step <= skips_steps_cache.start_step or current_step == skips_steps_cache.num_steps-1: + if real_step_no <= skips_steps_cache.start_step or real_step_no == skips_steps_cache.num_steps-1: should_calc = True skips_steps_cache.accumulated_rel_l1_distance = 0 else: @@ -1461,11 +1577,11 @@ def forward( continue x_list[0] = block(x_list[0], context = context_list[0], audio_scale= audio_scale_list[0], e= e0, **kwargs) else: - for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc)): + for i, (x, context, hints, audio_scale, multitalk_audio, multitalk_masks, should_calc, motion_vec, lynx_ip_embeds,lynx_ref_buffer) in enumerate(zip(x_list, context_list, hints_list, audio_scale_list, multitalk_audio_list, multitalk_masks_list, x_should_calc,motion_vec_list, lynx_ip_embeds_list,lynx_ref_buffer_list)): if should_calc: - x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, **kwargs) + x_list[i] = block(x, context = context, hints= hints, audio_scale= audio_scale, multitalk_audio = multitalk_audio, multitalk_masks =multitalk_masks, e= e0, motion_vec = motion_vec, lynx_ip_embeds= lynx_ip_embeds, lynx_ref_buffer = lynx_ref_buffer, sub_x_no =i, 
**kwargs) del x - context = hints = audio_embedding = None + context = hints = None if skips_steps_cache != None: if joint_pass: @@ -1489,7 +1605,9 @@ def forward( torch.sub(x_list[0], ori_hidden_states[0], out=residual) skips_steps_cache.previous_residual[x_id] = residual residual, ori_hidden_states = None, None - + if lynx_feature_extractor: + return get_cache("lynx_ref_buffer") + for i, x in enumerate(x_list): if chipmunk: x = reverse_voxel_chunk_no_padding(x.transpose(1, 2).unsqueeze(-1), x_og_shape, voxel_shape).squeeze(-1) diff --git a/models/wan/multitalk/attention.py b/models/wan/multitalk/attention.py index 27d488f6d..669b5c1a7 100644 --- a/models/wan/multitalk/attention.py +++ b/models/wan/multitalk/attention.py @@ -221,13 +221,16 @@ def __init__( self.add_q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.add_k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() - def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor: + def forward(self, xlist: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, enable_sp=False, kv_seq=None) -> torch.Tensor: N_t, N_h, N_w = shape + x = xlist[0] + xlist.clear() x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) # get q for hidden_state B, N, C = x.shape q = self.q_linear(x) + del x q_shape = (B, N, self.num_heads, self.head_dim) q = q.view(q_shape).permute((0, 2, 1, 3)) @@ -247,9 +250,6 @@ def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=No q = rearrange(q, "B H M K -> B M H K") encoder_k = rearrange(encoder_k, "B H M K -> B M H K") encoder_v = rearrange(encoder_v, "B H M K -> B M H K") - - attn_bias = None - # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=attn_bias, op=None,) qkv_list = [q, encoder_k, encoder_v] q = encoder_k = encoder_v = None x = pay_attention(qkv_list) @@ -302,7 +302,7 @@ def __init__( self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim) def forward(self, - x: torch.Tensor, + xlist: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None, x_ref_attn_map=None, @@ -310,14 +310,17 @@ def forward(self, encoder_hidden_states = encoder_hidden_states.squeeze(0) if x_ref_attn_map == None: - return super().forward(x, encoder_hidden_states, shape) + return super().forward(xlist, encoder_hidden_states, shape) N_t, _, _ = shape + x = xlist[0] + xlist.clear() x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) # get q for hidden_state B, N, C = x.shape q = self.q_linear(x) + del x q_shape = (B, N, self.num_heads, self.head_dim) q = q.view(q_shape).permute((0, 2, 1, 3)) @@ -339,7 +342,9 @@ def forward(self, normalized_pos = normalized_map[range(x_ref_attn_map.size(1)), max_indices] # N q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) - q = self.rope_1d(q, normalized_pos) + qlist = [q] + del q + q = self.rope_1d(qlist, normalized_pos, "q") q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) _, N_a, _ = encoder_hidden_states.shape @@ -347,7 +352,7 @@ def forward(self, encoder_kv_shape = (B, N_a, 2, self.num_heads, self.head_dim) encoder_kv = encoder_kv.view(encoder_kv_shape).permute((2, 0, 3, 1, 4)) encoder_k, encoder_v = encoder_kv.unbind(0) - + del encoder_kv if self.qk_norm: encoder_k = self.add_k_norm(encoder_k) @@ -356,13 +361,14 @@ def forward(self, per_frame[per_frame.size(0)//2:] = (self.rope_h2[0] + self.rope_h2[1]) / 2 encoder_pos = torch.concat([per_frame]*N_t, dim=0) encoder_k = rearrange(encoder_k, "(B N_t) H S C 
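Several hunks in this diff switch from passing tensors directly to passing them inside a single-element list that the callee empties (`x = xlist[0]; xlist.clear()`, and `qkv_list = [q, encoder_k, encoder_v]` followed by setting the locals to `None` before `pay_attention`). The point is to drop the caller's last Python reference so each tensor can be freed before the attention output is allocated. A sketch of the callee side, assuming a plain scaled-dot-product backend rather than whatever `pay_attention` actually dispatches to:

    import torch
    import torch.nn.functional as F

    def pay_attention_sketch(qkv_list):
        q, k, v = qkv_list
        qkv_list.clear()   # the caller already cleared its locals, so these are the last references
        out = F.scaled_dot_product_attention(q, k, v)
        del q, k, v        # release the inputs before any post-processing allocates more memory
        return out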
-> B H (N_t S) C", N_t=N_t) - encoder_k = self.rope_1d(encoder_k, encoder_pos) + enclist = [encoder_k] + del encoder_k + encoder_k = self.rope_1d(enclist, encoder_pos, "encoder_k") encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) q = rearrange(q, "B H M K -> B M H K") encoder_k = rearrange(encoder_k, "B H M K -> B M H K") encoder_v = rearrange(encoder_v, "B H M K -> B M H K") - # x = xformers.ops.memory_efficient_attention(q, encoder_k, encoder_v, attn_bias=None, op=None,) qkv_list = [q, encoder_k, encoder_v] q = encoder_k = encoder_v = None x = pay_attention(qkv_list) diff --git a/models/wan/multitalk/multitalk.py b/models/wan/multitalk/multitalk.py index fbf9175ae..52d2dd995 100644 --- a/models/wan/multitalk/multitalk.py +++ b/models/wan/multitalk/multitalk.py @@ -59,7 +59,30 @@ def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=160 audio_emb = audio_emb.cpu().detach() return audio_emb - + +def extract_audio_from_video(filename, sample_rate): + raw_audio_path = filename.split('/')[-1].split('.')[0]+'.wav' + ffmpeg_command = [ + "ffmpeg", + "-y", + "-i", + str(filename), + "-vn", + "-acodec", + "pcm_s16le", + "-ar", + "16000", + "-ac", + "2", + str(raw_audio_path), + ] + subprocess.run(ffmpeg_command, check=True) + human_speech_array, sr = librosa.load(raw_audio_path, sr=sample_rate) + human_speech_array = loudness_norm(human_speech_array, sr) + os.remove(raw_audio_path) + + return human_speech_array + def audio_prepare_single(audio_path, sample_rate=16000, duration = 0): ext = os.path.splitext(audio_path)[1].lower() if ext in ['.mp4', '.mov', '.avi', '.mkv']: @@ -191,18 +214,20 @@ def process_tts_multi(text, save_dir, voice1, voice2): return s1, s2, save_path_sum -def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0): +def get_full_audio_embeddings(audio_guide1 = None, audio_guide2 = None, combination_type ="add", num_frames = 0, fps = 25, sr = 16000, padded_frames_for_embeddings = 0, min_audio_duration = 0, return_sum_only = False): wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/chinese-wav2vec2-base") # wav2vec_feature_extractor, audio_encoder= custom_init('cpu', "ckpts/wav2vec") pad = int(padded_frames_for_embeddings/ fps * sr) new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration= num_frames / fps, pad = pad, min_audio_duration = min_audio_duration ) - audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) - audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) - full_audio_embs = [] - if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) - # if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) - if audio_guide2 != None: full_audio_embs.append(audio_embedding_2) - if audio_guide2 == None and not duration_changed: sum_human_speechs = None + if return_sum_only: + full_audio_embs = None + else: + audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) + audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps= fps) + full_audio_embs = [] + if audio_guide1 != None: full_audio_embs.append(audio_embedding_1) + if audio_guide2 != None: 
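The new `extract_audio_from_video` helper above shells out to ffmpeg (which therefore has to be on PATH), loads the temporary wav with librosa and loudness-normalizes it. Condensed, and leaving out the `loudness_norm` step (the input path is hypothetical):

    import os, subprocess
    import librosa

    def extract_audio_sketch(video_path, sample_rate=16000):
        wav_path = os.path.splitext(os.path.basename(video_path))[0] + ".wav"
        # ffmpeg writes a 16 kHz stereo PCM wav into the current working directory; note the
        # -ar value is hard-coded even though librosa resamples again to sample_rate afterwards.
        subprocess.run(["ffmpeg", "-y", "-i", video_path, "-vn", "-acodec", "pcm_s16le",
                        "-ar", "16000", "-ac", "2", wav_path], check=True)
        speech, sr = librosa.load(wav_path, sr=sample_rate)
        os.remove(wav_path)
        return speech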
full_audio_embs.append(audio_embedding_2) + if audio_guide2 == None and not duration_changed: sum_human_speechs = None return full_audio_embs, sum_human_speechs diff --git a/models/wan/multitalk/multitalk_utils.py b/models/wan/multitalk/multitalk_utils.py index 6e2b2c307..7851c452b 100644 --- a/models/wan/multitalk/multitalk_utils.py +++ b/models/wan/multitalk/multitalk_utils.py @@ -16,7 +16,7 @@ import binascii import os.path as osp from skimage import color - +from mmgp.offload import get_cache, clear_caches VID_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv") ASPECT_RATIO_627 = { @@ -73,42 +73,70 @@ def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): # @torch.compile -def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count, mode='mean', attn_bias=None): - - ref_k = ref_k.to(visual_q.dtype).to(visual_q.device) +def calculate_x_ref_attn_map_per_head(visual_q, ref_k, ref_target_masks, ref_images_count, attn_bias=None): + dtype = visual_q.dtype + ref_k = ref_k.to(dtype).to(visual_q.device) scale = 1.0 / visual_q.shape[-1] ** 0.5 visual_q = visual_q * scale visual_q = visual_q.transpose(1, 2) ref_k = ref_k.transpose(1, 2) - attn = visual_q @ ref_k.transpose(-2, -1) + visual_q_shape = visual_q.shape + visual_q = visual_q.view(-1, visual_q_shape[-1] ) + number_chunks = visual_q_shape[-2]*ref_k.shape[-2] / 53090100 * 2 + chunk_size = int(visual_q_shape[-2] / number_chunks) + chunks =torch.split(visual_q, chunk_size) + maps_lists = [ [] for _ in ref_target_masks] + for q_chunk in chunks: + attn = q_chunk @ ref_k.transpose(-2, -1) + x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens + del attn + ref_target_masks = ref_target_masks.to(dtype) + x_ref_attn_map_source = x_ref_attn_map_source.to(dtype) + + for class_idx, ref_target_mask in enumerate(ref_target_masks): + ref_target_mask = ref_target_mask[None, None, None, ...] + x_ref_attnmap = x_ref_attn_map_source * ref_target_mask + x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens + maps_lists[class_idx].append(x_ref_attnmap) + + del x_ref_attn_map_source - if attn_bias is not None: attn += attn_bias + x_ref_attn_maps = [] + for class_idx, maps_list in enumerate(maps_lists): + attn_map_fuse = torch.concat(maps_list, dim= -1) + attn_map_fuse = attn_map_fuse.view(1, visual_q_shape[1], -1).squeeze(1) + x_ref_attn_maps.append( attn_map_fuse ) - x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens + return torch.concat(x_ref_attn_maps, dim=0) + +def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, ref_images_count): + dtype = visual_q.dtype + ref_k = ref_k.to(dtype).to(visual_q.device) + scale = 1.0 / visual_q.shape[-1] ** 0.5 + visual_q = visual_q * scale + visual_q = visual_q.transpose(1, 2) + ref_k = ref_k.transpose(1, 2) + attn = visual_q @ ref_k.transpose(-2, -1) + + x_ref_attn_map_source = attn.softmax(-1) # B, H, x_seqlens, ref_seqlens + del attn x_ref_attn_maps = [] - ref_target_masks = ref_target_masks.to(visual_q.dtype) - x_ref_attn_map_source = x_ref_attn_map_source.to(visual_q.dtype) + ref_target_masks = ref_target_masks.to(dtype) + x_ref_attn_map_source = x_ref_attn_map_source.to(dtype) for class_idx, ref_target_mask in enumerate(ref_target_masks): ref_target_mask = ref_target_mask[None, None, None, ...] 
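The new `calculate_x_ref_attn_map_per_head` processes query rows in chunks so that the full `x_seqlens x ref_seqlens` attention matrix never has to be materialised at once; the `53090100` constant acts as a rough per-chunk element budget. Stripped of the head and batch bookkeeping, the idea is:

    import torch

    def masked_attn_mass_chunked(q, ref_k, ref_mask, chunk_rows):
        # q: (N_q, C) already-scaled queries, ref_k: (N_ref, C),
        # ref_mask: (N_ref,) with 1s over the reference tokens of one target.
        pieces = []
        for q_chunk in torch.split(q, chunk_rows):
            attn = (q_chunk @ ref_k.T).softmax(-1)               # (chunk, N_ref), one chunk alive at a time
            pieces.append((attn * ref_mask).sum(-1) / ref_mask.sum())
        return torch.cat(pieces)                                 # (N_q,) attention mass on that target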
x_ref_attnmap = x_ref_attn_map_source * ref_target_mask x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens x_ref_attnmap = x_ref_attnmap.permute(0, 2, 1) # B, x_seqlens, H - - if mode == 'mean': - x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens - elif mode == 'max': - x_ref_attnmap = x_ref_attnmap.max(-1) # B, x_seqlens - + x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens (mean of heads) x_ref_attn_maps.append(x_ref_attnmap) - del attn del x_ref_attn_map_source return torch.concat(x_ref_attn_maps, dim=0) - def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=10, ref_images_count = 0): """Args: query (torch.tensor): B M H K @@ -120,6 +148,11 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli N_t, N_h, N_w = shape x_seqlens = N_h * N_w + if x_seqlens <= 1508: + split_num = 10 # 540p + else: + split_num = 20 if x_seqlens <= 3600 else 40 # 720p / 1080p + ref_k = ref_k[:, :x_seqlens] if ref_images_count > 0 : visual_q_shape = visual_q.shape @@ -133,9 +166,14 @@ def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, spli split_chunk = heads // split_num - for i in range(split_num): - x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count) - x_ref_attn_maps += x_ref_attn_maps_perhead + if split_chunk == 1: + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map_per_head(visual_q[:, :, i:(i+1), :], ref_k[:, :, i:(i+1), :], ref_target_masks, ref_images_count) + x_ref_attn_maps += x_ref_attn_maps_perhead + else: + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map(visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], ref_target_masks, ref_images_count) + x_ref_attn_maps += x_ref_attn_maps_perhead x_ref_attn_maps /= split_num return x_ref_attn_maps @@ -158,7 +196,6 @@ def __init__(self, self.base = 10000 - @lru_cache(maxsize=32) def precompute_freqs_cis_1d(self, pos_indices): freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) @@ -167,7 +204,7 @@ def precompute_freqs_cis_1d(self, pos_indices): freqs = repeat(freqs, "... n -> ... (n r)", r=2) return freqs - def forward(self, x, pos_indices): + def forward(self, qlist, pos_indices, cache_entry = None): """1D RoPE. Args: @@ -176,16 +213,26 @@ def forward(self, x, pos_indices): Returns: query with the same shape as input. 
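`get_attn_map_with_target` now scales the head splitting with the spatial token count instead of always using `split_num=10`, and switches to the chunked per-head routine when each split holds a single head. The threshold logic restated on its own:

    def pick_split_num(x_seqlens):
        # x_seqlens = N_h * N_w latent tokens per frame; the resolution comments match the hunk.
        if x_seqlens <= 1508:
            return 10                              # ~540p
        return 20 if x_seqlens <= 3600 else 40     # ~720p / 1080p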
""" - freqs_cis = self.precompute_freqs_cis_1d(pos_indices) - - x_ = x.float() - - freqs_cis = freqs_cis.float().to(x.device) - cos, sin = freqs_cis.cos(), freqs_cis.sin() - cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') - x_ = (x_ * cos) + (rotate_half(x_) * sin) - - return x_.type_as(x) + xq= qlist[0] + qlist.clear() + cache = get_cache("multitalk_rope") + freqs_cis= cache.get(cache_entry, None) + if freqs_cis is None: + freqs_cis = cache[cache_entry] = self.precompute_freqs_cis_1d(pos_indices) + cos, sin = freqs_cis.cos().unsqueeze(0).unsqueeze(0), freqs_cis.sin().unsqueeze(0).unsqueeze(0) + # cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + # real * cos - imag * sin + # imag * cos + real * sin + xq_dtype = xq.dtype + xq_out = xq.to(torch.float) + xq = None + xq_rot = rotate_half(xq_out) + xq_out *= cos + xq_rot *= sin + xq_out += xq_rot + del xq_rot + xq_out = xq_out.to(xq_dtype) + return xq_out diff --git a/models/wan/text2video fuse attempt.py b/models/wan/text2video fuse attempt.py deleted file mode 100644 index 8af94583d..000000000 --- a/models/wan/text2video fuse attempt.py +++ /dev/null @@ -1,698 +0,0 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. -import gc -import logging -import math -import os -import random -import sys -import types -from contextlib import contextmanager -from functools import partial -from mmgp import offload -import torch -import torch.nn as nn -import torch.cuda.amp as amp -import torch.distributed as dist -from tqdm import tqdm -from PIL import Image -import torchvision.transforms.functional as TF -import torch.nn.functional as F -from .distributed.fsdp import shard_model -from .modules.model import WanModel -from .modules.t5 import T5EncoderModel -from .modules.vae import WanVAE -from .utils.fm_solvers import (FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, retrieve_timesteps) -from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from wan.modules.posemb_layers import get_rotary_pos_embed -from .utils.vace_preprocessor import VaceVideoProcessor - - -def optimized_scale(positive_flat, negative_flat): - - # Calculate dot production - dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True) - - # Squared norm of uncondition - squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8 - - # st_star = v_cond^T * v_uncond / ||v_uncond||^2 - st_star = dot_product / squared_norm - - return st_star - - -class WanT2V: - - def __init__( - self, - config, - checkpoint_dir, - rank=0, - model_filename = None, - text_encoder_filename = None, - quantizeTransformer = False, - dtype = torch.bfloat16 - ): - self.device = torch.device(f"cuda") - self.config = config - self.rank = rank - self.dtype = dtype - self.num_train_timesteps = config.num_train_timesteps - self.param_dtype = config.param_dtype - - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - device=torch.device('cpu'), - checkpoint_path=text_encoder_filename, - tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer), - shard_fn= None) - - self.vae_stride = config.vae_stride - self.patch_size = config.patch_size - - - self.vae = WanVAE( - vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint), - device=self.device) - - logging.info(f"Creating WanModel from {model_filename}") - from mmgp import offload - - self.model = offload.fast_load_transformers_model(model_filename, modelClass=WanModel,do_quantize= 
quantizeTransformer, writable_tensors= False) - # offload.load_model_data(self.model, "recam.ckpt") - # self.model.cpu() - # offload.save_model(self.model, "recam.safetensors") - if self.dtype == torch.float16 and not "fp16" in model_filename: - self.model.to(self.dtype) - # offload.save_model(self.model, "t2v_fp16.safetensors",do_quantize=True) - if self.dtype == torch.float16: - self.vae.model.to(self.dtype) - self.model.eval().requires_grad_(False) - - - self.sample_neg_prompt = config.sample_neg_prompt - - if "Vace" in model_filename: - self.vid_proc = VaceVideoProcessor(downsample=tuple([x * y for x, y in zip(config.vae_stride, self.patch_size)]), - min_area=480*832, - max_area=480*832, - min_fps=config.sample_fps, - max_fps=config.sample_fps, - zero_start=True, - seq_len=32760, - keep_last=True) - - self.adapt_vace_model() - - self.scheduler = FlowUniPCMultistepScheduler() - - def vace_encode_frames(self, frames, ref_images, masks=None, tile_size = 0): - if ref_images is None: - ref_images = [None] * len(frames) - else: - assert len(frames) == len(ref_images) - - if masks is None: - latents = self.vae.encode(frames, tile_size = tile_size) - else: - inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)] - reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)] - inactive = self.vae.encode(inactive, tile_size = tile_size) - reactive = self.vae.encode(reactive, tile_size = tile_size) - latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)] - - cat_latents = [] - for latent, refs in zip(latents, ref_images): - if refs is not None: - if masks is None: - ref_latent = self.vae.encode(refs, tile_size = tile_size) - else: - ref_latent = self.vae.encode(refs, tile_size = tile_size) - ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent] - assert all([x.shape[1] == 1 for x in ref_latent]) - latent = torch.cat([*ref_latent, latent], dim=1) - cat_latents.append(latent) - return cat_latents - - def vace_encode_masks(self, masks, ref_images=None): - if ref_images is None: - ref_images = [None] * len(masks) - else: - assert len(masks) == len(ref_images) - - result_masks = [] - for mask, refs in zip(masks, ref_images): - c, depth, height, width = mask.shape - new_depth = int((depth + 3) // self.vae_stride[0]) - height = 2 * (int(height) // (self.vae_stride[1] * 2)) - width = 2 * (int(width) // (self.vae_stride[2] * 2)) - - # reshape - mask = mask[0, :, :, :] - mask = mask.view( - depth, height, self.vae_stride[1], width, self.vae_stride[1] - ) # depth, height, 8, width, 8 - mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width - mask = mask.reshape( - self.vae_stride[1] * self.vae_stride[2], depth, height, width - ) # 8*8, depth, height, width - - # interpolation - mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode='nearest-exact').squeeze(0) - - if refs is not None: - length = len(refs) - mask_pad = torch.zeros_like(mask[:, :length, :, :]) - mask = torch.cat((mask_pad, mask), dim=1) - result_masks.append(mask) - return result_masks - - def vace_latent(self, z, m): - return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)] - - def prepare_source(self, src_video, src_mask, src_ref_images, total_frames, image_size, device, original_video = False, keep_frames= [], start_frame = 0, pre_src_video = None): - image_sizes = [] - trim_video = len(keep_frames) - - for i, (sub_src_video, sub_src_mask, sub_pre_src_video) in enumerate(zip(src_video, src_mask,pre_src_video)): - prepend_count = 0 if sub_pre_src_video 
== None else sub_pre_src_video.shape[1] - num_frames = total_frames - prepend_count - if sub_src_mask is not None and sub_src_video is not None: - src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) - # src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255]) - # src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255]) - src_video[i] = src_video[i].to(device) - src_mask[i] = src_mask[i].to(device) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1) - image_sizes.append(src_video[i].shape[2:]) - elif sub_src_video is None: - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), torch.ones((3, num_frames, image_size[0], image_size[1]), device=device)] ,1) - else: - src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device) - src_mask[i] = torch.ones_like(src_video[i], device=device) - image_sizes.append(image_size) - else: - src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video - prepend_count, start_frame = start_frame) - src_video[i] = src_video[i].to(device) - src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device) - if prepend_count > 0: - src_video[i] = torch.cat( [sub_pre_src_video, src_video[i]], dim=1) - src_mask[i] = torch.cat( [torch.zeros_like(sub_pre_src_video), src_mask[i]] ,1) - src_video_shape = src_video[i].shape - if src_video_shape[1] != total_frames: - src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], total_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1) - image_sizes.append(src_video[i].shape[2:]) - for k, keep in enumerate(keep_frames): - if not keep: - src_video[i][:, k:k+1] = 0 - src_mask[i][:, k:k+1] = 1 - - for i, ref_images in enumerate(src_ref_images): - if ref_images is not None: - image_size = image_sizes[i] - for j, ref_img in enumerate(ref_images): - if ref_img is not None: - ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) - if ref_img.shape[-2:] != image_size: - canvas_height, canvas_width = image_size - ref_height, ref_width = ref_img.shape[-2:] - white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1] - scale = min(canvas_height / ref_height, canvas_width / ref_width) - new_height = int(ref_height * scale) - new_width = int(ref_width * scale) - resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, 
new_width), mode='bilinear', align_corners=False).squeeze(0).unsqueeze(1) - top = (canvas_height - new_height) // 2 - left = (canvas_width - new_width) // 2 - white_canvas[:, :, top:top + new_height, left:left + new_width] = resized_image - ref_img = white_canvas - src_ref_images[i][j] = ref_img.to(device) - return src_video, src_mask, src_ref_images - - def decode_latent(self, zs, ref_images=None, tile_size= 0 ): - if ref_images is None: - ref_images = [None] * len(zs) - else: - assert len(zs) == len(ref_images) - - trimed_zs = [] - for z, refs in zip(zs, ref_images): - if refs is not None: - z = z[:, len(refs):, :, :] - trimed_zs.append(z) - - return self.vae.decode(trimed_zs, tile_size= tile_size) - - def generate_timestep_matrix( - self, - num_frames, - step_template, - base_num_frames, - ar_step=5, - num_pre_ready=0, - casual_block_size=1, - shrink_interval_with_mask=False, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]: - step_matrix, step_index = [], [] - update_mask, valid_interval = [], [] - num_iterations = len(step_template) + 1 - num_frames_block = num_frames // casual_block_size - base_num_frames_block = base_num_frames // casual_block_size - if base_num_frames_block < num_frames_block: - infer_step_num = len(step_template) - gen_block = base_num_frames_block - min_ar_step = infer_step_num / gen_block - assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting" - # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block) - step_template = torch.cat( - [ - torch.tensor([999], dtype=torch.int64, device=step_template.device), - step_template.long(), - torch.tensor([0], dtype=torch.int64, device=step_template.device), - ] - ) # to handle the counter in row works starting from 1 - pre_row = torch.zeros(num_frames_block, dtype=torch.long) - if num_pre_ready > 0: - pre_row[: num_pre_ready // casual_block_size] = num_iterations - - while torch.all(pre_row >= (num_iterations - 1)) == False: - new_row = torch.zeros(num_frames_block, dtype=torch.long) - for i in range(num_frames_block): - if i == 0 or pre_row[i - 1] >= ( - num_iterations - 1 - ): # the first frame or the last frame is completely denoised - new_row[i] = pre_row[i] + 1 - else: - new_row[i] = new_row[i - 1] - ar_step - new_row = new_row.clamp(0, num_iterations) - - update_mask.append( - (new_row != pre_row) & (new_row != num_iterations) - ) # False: no need to update, True: need to update - step_index.append(new_row) - step_matrix.append(step_template[new_row]) - pre_row = new_row - - # for long video we split into several sequences, base_num_frames is set to the model max length (for training) - terminal_flag = base_num_frames_block - if shrink_interval_with_mask: - idx_sequence = torch.arange(num_frames_block, dtype=torch.int64) - update_mask = update_mask[0] - update_mask_idx = idx_sequence[update_mask] - last_update_idx = update_mask_idx[-1].item() - terminal_flag = last_update_idx + 1 - # for i in range(0, len(update_mask)): - for curr_mask in update_mask: - if terminal_flag < num_frames_block and curr_mask[terminal_flag]: - terminal_flag += 1 - valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag)) - - step_update_mask = torch.stack(update_mask, dim=0) - step_index = torch.stack(step_index, dim=0) - step_matrix = torch.stack(step_matrix, dim=0) - - if casual_block_size > 1: - step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, 
casual_block_size).flatten(1).contiguous() - step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous() - valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval] - - return step_matrix, step_index, step_update_mask, valid_interval - - def generate(self, - input_prompt, - input_frames= None, - input_masks = None, - input_ref_images = None, - source_video=None, - target_camera=None, - context_scale=1.0, - size=(1280, 720), - frame_num=81, - shift=5.0, - sample_solver='unipc', - sampling_steps=50, - guide_scale=5.0, - n_prompt="", - seed=-1, - offload_model=True, - callback = None, - enable_RIFLEx = None, - VAE_tile_size = 0, - joint_pass = False, - slg_layers = None, - slg_start = 0.0, - slg_end = 1.0, - cfg_star_switch = True, - cfg_zero_step = 5, - ): - r""" - Generates video frames from text prompt using diffusion process. - - Args: - input_prompt (`str`): - Text prompt for content generation - size (tupele[`int`], *optional*, defaults to (1280,720)): - Controls video resolution, (width,height). - frame_num (`int`, *optional*, defaults to 81): - How many frames to sample from a video. The number should be 4n+1 - shift (`float`, *optional*, defaults to 5.0): - Noise schedule shift parameter. Affects temporal dynamics - sample_solver (`str`, *optional*, defaults to 'unipc'): - Solver used to sample the video. - sampling_steps (`int`, *optional*, defaults to 40): - Number of diffusion sampling steps. Higher values improve quality but slow generation - guide_scale (`float`, *optional*, defaults 5.0): - Classifier-free guidance scale. Controls prompt adherence vs. creativity - n_prompt (`str`, *optional*, defaults to ""): - Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` - seed (`int`, *optional*, defaults to -1): - Random seed for noise generation. If -1, use random seed. - offload_model (`bool`, *optional*, defaults to True): - If True, offloads models to CPU during generation to save VRAM - - Returns: - torch.Tensor: - Generated video frames tensor. Dimensions: (C, N H, W) where: - - C: Color channels (3 for RGB) - - N: Number of frames (81) - - H: Frame height (from size) - - W: Frame width from size) - """ - # preprocess - - if n_prompt == "": - n_prompt = self.sample_neg_prompt - seed = seed if seed >= 0 else random.randint(0, sys.maxsize) - seed_g = torch.Generator(device=self.device) - seed_g.manual_seed(seed) - - frame_num = max(17, frame_num) # must match causal_block_size for value of 5 - frame_num = int( round( (frame_num - 17) / 20)* 20 + 17 ) - num_frames = frame_num - addnoise_condition = 20 - causal_attention = True - fps = 16 - ar_step = 5 - - - - context = self.text_encoder([input_prompt], self.device) - context_null = self.text_encoder([n_prompt], self.device) - if target_camera != None: - size = (source_video.shape[2], source_video.shape[1]) - source_video = source_video.to(dtype=self.dtype , device=self.device) - source_video = source_video.permute(3, 0, 1, 2).div_(127.5).sub_(1.) 
- source_latents = self.vae.encode([source_video]) #.to(dtype=self.dtype, device=self.device) - del source_video - # Process target camera (recammaster) - from wan.utils.cammmaster_tools import get_camera_embedding - cam_emb = get_camera_embedding(target_camera) - cam_emb = cam_emb.to(dtype=self.dtype, device=self.device) - - if input_frames != None: - # vace context encode - input_frames = [u.to(self.device) for u in input_frames] - input_ref_images = [ None if u == None else [v.to(self.device) for v in u] for u in input_ref_images] - input_masks = [u.to(self.device) for u in input_masks] - - z0 = self.vace_encode_frames(input_frames, input_ref_images, masks=input_masks, tile_size = VAE_tile_size) - m0 = self.vace_encode_masks(input_masks, input_ref_images) - z = self.vace_latent(z0, m0) - - target_shape = list(z0[0].shape) - target_shape[0] = int(target_shape[0] / 2) - else: - F = frame_num - target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1, - size[1] // self.vae_stride[1], - size[0] // self.vae_stride[2]) - - seq_len = math.ceil((target_shape[2] * target_shape[3]) / - (self.patch_size[1] * self.patch_size[2]) * - target_shape[1]) - - context = [u.to(self.dtype) for u in context] - context_null = [u.to(self.dtype) for u in context_null] - - noise = [ torch.randn( *target_shape, dtype=torch.float32, device=self.device, generator=seed_g) ] - - # evaluation mode - - # if sample_solver == 'unipc': - # sample_scheduler = FlowUniPCMultistepScheduler( - # num_train_timesteps=self.num_train_timesteps, - # shift=1, - # use_dynamic_shifting=False) - # sample_scheduler.set_timesteps( - # sampling_steps, device=self.device, shift=shift) - # timesteps = sample_scheduler.timesteps - # elif sample_solver == 'dpm++': - # sample_scheduler = FlowDPMSolverMultistepScheduler( - # num_train_timesteps=self.num_train_timesteps, - # shift=1, - # use_dynamic_shifting=False) - # sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) - # timesteps, _ = retrieve_timesteps( - # sample_scheduler, - # device=self.device, - # sigmas=sampling_sigmas) - # else: - # raise NotImplementedError("Unsupported solver.") - - # sample videos - latents = noise - del noise - batch_size =len(latents) - if target_camera != None: - shape = list(latents[0].shape[1:]) - shape[0] *= 2 - freqs = get_rotary_pos_embed(shape, enable_RIFLEx= False) - else: - freqs = get_rotary_pos_embed(latents[0].shape[1:], enable_RIFLEx= enable_RIFLEx) - # arg_c = {'context': context, 'freqs': freqs, 'pipeline': self, 'callback': callback} - # arg_null = {'context': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback} - # arg_both = {'context': context, 'context2': context_null, 'freqs': freqs, 'pipeline': self, 'callback': callback} - - i2v_extra_kwrags = {} - - if target_camera != None: - recam_dict = {'cam_emb': cam_emb} - i2v_extra_kwrags.update(recam_dict) - - if input_frames != None: - vace_dict = {'vace_context' : z, 'vace_context_scale' : context_scale} - i2v_extra_kwrags.update(vace_dict) - - - latent_length = (num_frames - 1) // 4 + 1 - latent_height = height // 8 - latent_width = width // 8 - if ar_step == 0: - causal_block_size = 1 - fps_embeds = [fps] #* prompt_embeds[0].shape[0] - fps_embeds = [0 if i == 16 else 1 for i in fps_embeds] - - self.scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) - init_timesteps = self.scheduler.timesteps - base_num_frames_iter = latent_length - latent_shape = [16, base_num_frames_iter, latent_height, latent_width] - - prefix_video = None - 
predix_video_latent_length = 0 - - if prefix_video is not None: - latents[0][:, :predix_video_latent_length] = prefix_video[0].to(torch.float32) - step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix( - base_num_frames_iter, - init_timesteps, - base_num_frames_iter, - ar_step, - predix_video_latent_length, - causal_block_size, - ) - sample_schedulers = [] - for _ in range(base_num_frames_iter): - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=1000, shift=1, use_dynamic_shifting=False - ) - sample_scheduler.set_timesteps(sampling_steps, device=self.device, shift=shift) - sample_schedulers.append(sample_scheduler) - sample_schedulers_counter = [0] * base_num_frames_iter - - updated_num_steps= len(step_matrix) - - if callback != None: - callback(-1, None, True, override_num_inference_steps = updated_num_steps) - if self.model.enable_teacache: - self.model.compute_teacache_threshold(self.model.teacache_start_step, timesteps, self.model.teacache_multiplier) - # if callback != None: - # callback(-1, None, True) - - for i, timestep_i in enumerate(tqdm(step_matrix)): - update_mask_i = step_update_mask[i] - valid_interval_i = valid_interval[i] - valid_interval_start, valid_interval_end = valid_interval_i - timestep = timestep_i[None, valid_interval_start:valid_interval_end].clone() - latent_model_input = [latents[0][:, valid_interval_start:valid_interval_end, :, :].clone()] - if addnoise_condition > 0 and valid_interval_start < predix_video_latent_length: - noise_factor = 0.001 * addnoise_condition - timestep_for_noised_condition = addnoise_condition - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] = ( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - * (1.0 - noise_factor) - + torch.randn_like( - latent_model_input[0][:, valid_interval_start:predix_video_latent_length] - ) - * noise_factor - ) - timestep[:, valid_interval_start:predix_video_latent_length] = timestep_for_noised_condition - kwrags = { - "x" : torch.stack([latent_model_input[0]]), - "t" : timestep, - "freqs" :freqs, - "fps" : fps_embeds, - "causal_block_size" : causal_block_size, - "causal_attention" : causal_attention, - "callback" : callback, - "pipeline" : self, - "current_step" : i, - } - kwrags.update(i2v_extra_kwrags) - - if not self.do_classifier_free_guidance: - noise_pred = self.model( - context=context, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred= noise_pred.to(torch.float32) - else: - if joint_pass: - noise_pred_cond, noise_pred_uncond = self.model( - context=context, - context2=context_null, - **kwrags, - ) - if self._interrupt: - return None - else: - noise_pred_cond = self.model( - context=context, - **kwrags, - )[0] - if self._interrupt: - return None - noise_pred_uncond = self.model( - context=context_null, - )[0] - if self._interrupt: - return None - noise_pred_cond= noise_pred_cond.to(torch.float32) - noise_pred_uncond= noise_pred_uncond.to(torch.float32) - noise_pred = noise_pred_uncond + guide_scale * (noise_pred_cond - noise_pred_uncond) - del noise_pred_cond, noise_pred_uncond - for idx in range(valid_interval_start, valid_interval_end): - if update_mask_i[idx].item(): - latents[0][:, idx] = sample_schedulers[idx].step( - noise_pred[:, idx - valid_interval_start], - timestep_i[idx], - latents[0][:, idx], - return_dict=False, - generator=seed_g, - )[0] - sample_schedulers_counter[idx] += 1 - if callback is not None: - callback(i, latents[0].squeeze(0), False) - - # for i, t in 
enumerate(tqdm(timesteps)): - # if target_camera != None: - # latent_model_input = [torch.cat([u,v], dim=1) for u,v in zip(latents,source_latents )] - # else: - # latent_model_input = latents - # slg_layers_local = None - # if int(slg_start * sampling_steps) <= i < int(slg_end * sampling_steps): - # slg_layers_local = slg_layers - # timestep = [t] - # offload.set_step_no_for_lora(self.model, i) - # timestep = torch.stack(timestep) - - # if joint_pass: - # noise_pred_cond, noise_pred_uncond = self.model( - # latent_model_input, t=timestep, current_step=i, slg_layers=slg_layers_local, **arg_both) - # if self._interrupt: - # return None - # else: - # noise_pred_cond = self.model( - # latent_model_input, t=timestep,current_step=i, is_uncond = False, **arg_c)[0] - # if self._interrupt: - # return None - # noise_pred_uncond = self.model( - # latent_model_input, t=timestep,current_step=i, is_uncond = True, slg_layers=slg_layers_local, **arg_null)[0] - # if self._interrupt: - # return None - - # # del latent_model_input - - # # CFG Zero *. Thanks to https://github.com/WeichenFan/CFG-Zero-star/ - # noise_pred_text = noise_pred_cond - # if cfg_star_switch: - # positive_flat = noise_pred_text.view(batch_size, -1) - # negative_flat = noise_pred_uncond.view(batch_size, -1) - - # alpha = optimized_scale(positive_flat,negative_flat) - # alpha = alpha.view(batch_size, 1, 1, 1) - - # if (i <= cfg_zero_step): - # noise_pred = noise_pred_text*0. # it would be faster not to compute noise_pred... - # else: - # noise_pred_uncond *= alpha - # noise_pred = noise_pred_uncond + guide_scale * (noise_pred_text - noise_pred_uncond) - # del noise_pred_uncond - - # temp_x0 = sample_scheduler.step( - # noise_pred[:, :target_shape[1]].unsqueeze(0), - # t, - # latents[0].unsqueeze(0), - # return_dict=False, - # generator=seed_g)[0] - # latents = [temp_x0.squeeze(0)] - # del temp_x0 - - # if callback is not None: - # callback(i, latents[0], False) - - x0 = latents - - if input_frames == None: - videos = self.vae.decode(x0, VAE_tile_size) - else: - videos = self.decode_latent(x0, input_ref_images, VAE_tile_size) - - del latents - del sample_scheduler - - return videos[0] if self.rank == 0 else None - - def adapt_vace_model(self): - model = self.model - modules_dict= { k: m for k, m in model.named_modules()} - for model_layer, vace_layer in model.vace_layers_mapping.items(): - module = modules_dict[f"vace_blocks.{vace_layer}"] - target = modules_dict[f"blocks.{model_layer}"] - setattr(target, "vace", module ) - delattr(model, "vace_blocks") - - \ No newline at end of file diff --git a/models/wan/wan_handler.py b/models/wan/wan_handler.py index c3ad01294..95a43ee18 100644 --- a/models/wan/wan_handler.py +++ b/models/wan/wan_handler.py @@ -1,8 +1,12 @@ import torch import numpy as np +import gradio as gr def test_class_i2v(base_model_type): - return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk" ] + return base_model_type in ["i2v", "i2v_2_2", "fun_inp_1.3B", "fun_inp", "flf2v_720p", "fantasy", "multitalk", "infinitetalk", "i2v_2_2_multitalk", "animate" ] + +def text_oneframe_overlap(base_model_type): + return test_class_i2v(base_model_type) and not (test_multitalk(base_model_type) or base_model_type in ["animate"]) or test_wan_5B(base_model_type) def test_class_1_3B(base_model_type): return base_model_type in [ "vace_1.3B", "t2v_1.3B", "recam_1.3B","phantom_1.3B","fun_inp_1.3B"] @@ -13,6 +17,11 @@ def test_multitalk(base_model_type): 
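The new `text_oneframe_overlap` predicate in the wan_handler hunk above tags the model families whose sliding windows may only share a single frame; later in this handler it pins `sliding_window_overlap` to 1 and disables colour correction for them. For example:

    # One-frame overlap applies to the plain i2v classes and to the Wan 2.2 5B models,
    # but not to the audio-driven or Animate variants:
    assert text_oneframe_overlap("i2v")
    assert text_oneframe_overlap("ti2v_2_2")        # via test_wan_5B
    assert not text_oneframe_overlap("multitalk")
    assert not text_oneframe_overlap("animate")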
def test_standin(base_model_type): return base_model_type in ["standin", "vace_standin_14B"] +def test_lynx(base_model_type): + return base_model_type in ["lynx_lite", "vace_lynx_lite_14B", "lynx", "vace_lynx_14B"] + +def test_wan_5B(base_model_type): + return base_model_type in ["ti2v_2_2", "lucy_edit"] class family_handler(): @staticmethod @@ -32,7 +41,7 @@ def set_cache_parameters(cache_type, base_model_type, model_def, inputs, skip_st def_mag_ratios = [1.00124, 1.00155, 0.99822, 0.99851, 0.99696, 0.99687, 0.99703, 0.99732, 0.9966, 0.99679, 0.99602, 0.99658, 0.99578, 0.99664, 0.99484, 0.9949, 0.99633, 0.996, 0.99659, 0.99683, 0.99534, 0.99549, 0.99584, 0.99577, 0.99681, 0.99694, 0.99563, 0.99554, 0.9944, 0.99473, 0.99594, 0.9964, 0.99466, 0.99461, 0.99453, 0.99481, 0.99389, 0.99365, 0.99391, 0.99406, 0.99354, 0.99361, 0.99283, 0.99278, 0.99268, 0.99263, 0.99057, 0.99091, 0.99125, 0.99126, 0.65523, 0.65252, 0.98808, 0.98852, 0.98765, 0.98736, 0.9851, 0.98535, 0.98311, 0.98339, 0.9805, 0.9806, 0.97776, 0.97771, 0.97278, 0.97286, 0.96731, 0.96728, 0.95857, 0.95855, 0.94385, 0.94385, 0.92118, 0.921, 0.88108, 0.88076, 0.80263, 0.80181] elif base_model_type in ["i2v_2_2"]: def_mag_ratios = [0.99191, 0.99144, 0.99356, 0.99337, 0.99326, 0.99285, 0.99251, 0.99264, 0.99393, 0.99366, 0.9943, 0.9943, 0.99276, 0.99288, 0.99389, 0.99393, 0.99274, 0.99289, 0.99316, 0.9931, 0.99379, 0.99377, 0.99268, 0.99271, 0.99222, 0.99227, 0.99175, 0.9916, 0.91076, 0.91046, 0.98931, 0.98933, 0.99087, 0.99088, 0.98852, 0.98855, 0.98895, 0.98896, 0.98806, 0.98808, 0.9871, 0.98711, 0.98613, 0.98618, 0.98434, 0.98435, 0.983, 0.98307, 0.98185, 0.98187, 0.98131, 0.98131, 0.9783, 0.97835, 0.97619, 0.9762, 0.97264, 0.9727, 0.97088, 0.97098, 0.96568, 0.9658, 0.96045, 0.96055, 0.95322, 0.95335, 0.94579, 0.94594, 0.93297, 0.93311, 0.91699, 0.9172, 0.89174, 0.89202, 0.8541, 0.85446, 0.79823, 0.79902] - elif base_model_type in ["ti2v_2_2"]: + elif test_wan_5B(base_model_type): if inputs.get("image_start", None) is not None and inputs.get("video_source", None) is not None : # t2v def_mag_ratios = [0.99505, 0.99389, 0.99441, 0.9957, 0.99558, 0.99551, 0.99499, 0.9945, 0.99534, 0.99548, 0.99468, 0.9946, 0.99463, 0.99458, 0.9946, 0.99453, 0.99408, 0.99404, 0.9945, 0.99441, 0.99409, 0.99398, 0.99403, 0.99397, 0.99382, 0.99377, 0.99349, 0.99343, 0.99377, 0.99378, 0.9933, 0.99328, 0.99303, 0.99301, 0.99217, 0.99216, 0.992, 0.99201, 0.99201, 0.99202, 0.99133, 0.99132, 0.99112, 0.9911, 0.99155, 0.99155, 0.98958, 0.98957, 0.98959, 0.98958, 0.98838, 0.98835, 0.98826, 0.98825, 0.9883, 0.98828, 0.98711, 0.98709, 0.98562, 0.98561, 0.98511, 0.9851, 0.98414, 0.98412, 0.98284, 0.98282, 0.98104, 0.98101, 0.97981, 0.97979, 0.97849, 0.97849, 0.97557, 0.97554, 0.97398, 0.97395, 0.97171, 0.97166, 0.96917, 0.96913, 0.96511, 0.96507, 0.96263, 0.96257, 0.95839, 0.95835, 0.95483, 0.95475, 0.94942, 0.94936, 0.9468, 0.94678, 0.94583, 0.94594, 0.94843, 0.94872, 0.96949, 0.97015] else: # i2v @@ -76,14 +85,17 @@ def query_model_def(base_model_type, model_def): extra_model_def["i2v_class"] = i2v extra_model_def["multitalk_class"] = test_multitalk(base_model_type) extra_model_def["standin_class"] = test_standin(base_model_type) - vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B"] + extra_model_def["lynx_class"] = test_lynx(base_model_type) + vace_class = base_model_type in ["vace_14B", "vace_1.3B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B"] extra_model_def["vace_class"] = vace_class - if 
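`test_wan_5B` groups the Wan 2.2 5B checkpoints so that `lucy_edit` inherits every code path that used to test for `ti2v_2_2` alone: the 5B mag-cache ratio table, the 24 fps default, the 32-pixel VAE block size and the shared model downloads. In other words:

    for model in ("ti2v_2_2", "lucy_edit"):
        assert test_wan_5B(model)
        assert family_handler.get_vae_block_size(model) == 32   # other Wan models return 16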
test_multitalk(base_model_type): + if base_model_type in ["animate"]: + fps = 30 + elif test_multitalk(base_model_type): fps = 25 elif base_model_type in ["fantasy"]: fps = 23 - elif base_model_type in ["ti2v_2_2"]: + elif test_wan_5B(base_model_type): fps = 24 else: fps = 16 @@ -96,14 +108,14 @@ def query_model_def(base_model_type, model_def): extra_model_def.update({ "frames_minimum" : frames_minimum, "frames_steps" : frames_steps, - "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy"] or test_class_i2v(base_model_type) or vace_class, #"ti2v_2_2", + "sliding_window" : base_model_type in ["multitalk", "infinitetalk", "t2v", "fantasy", "animate"] or test_class_i2v(base_model_type) or test_wan_5B(base_model_type) or vace_class, #"ti2v_2_2", "multiple_submodels" : multiple_submodels, "guidance_max_phases" : 3, "skip_layer_guidance" : True, "cfg_zero" : True, "cfg_star" : True, "adaptive_projected_guidance" : True, - "tea_cache" : not (base_model_type in ["i2v_2_2", "ti2v_2_2" ] or multiple_submodels), + "tea_cache" : not (base_model_type in ["i2v_2_2"] or test_wan_5B(base_model_type) or multiple_submodels), "mag_cache" : True, "keep_frames_video_guide_not_supported": base_model_type in ["infinitetalk"], "sample_solvers":[ @@ -112,18 +124,238 @@ def query_model_def(base_model_type, model_def): ("dpm++", "dpm++"), ("flowmatch causvid", "causvid"), ] }) + + + if base_model_type in ["t2v"]: + extra_model_def["guide_custom_choices"] = { + "choices":[("Use Text Prompt Only", ""), + ("Video to Video guided by Text Prompt", "GUV"), + ("Video to Video guided by Text Prompt and Restricted to the Area of the Video Mask", "GVA")], + "default": "", + "letters_filter": "GUVA", + "label": "Video to Video" + } + + extra_model_def["mask_preprocessing"] = { + "selection":[ "", "A"], + "visible": False + } + extra_model_def["v2i_switch_supported"] = True + + if base_model_type in ["infinitetalk"]: extra_model_def["no_background_removal"] = True + extra_model_def["all_image_refs_are_background_ref"] = True + extra_model_def["guide_custom_choices"] = { + "choices":[ + ("Images to Video, each Reference Image will start a new shot with a new Sliding Window", "KI"), + ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window", "RUV"), + ("Video to Video, amount of motion transferred depends on Denoising Strength", "GUV"), + ], + "default": "KI", + "letters_filter": "RGUVKI", + "label": "Video to Video", + "scale": 3, + "show_label" : False, + } + + extra_model_def["custom_video_selection"] = { + "choices":[ + ("Smooth Transitions", ""), + ("Sharp Transitions", "0"), + ], + "trigger": "", + "label": "Custom Process", + "letters_filter": "0", + "show_label" : False, + "scale": 1, + } + + # extra_model_def["at_least_one_image_ref_needed"] = True + if base_model_type in ["lucy_edit"]: + extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["guide_preprocessing"] = { + "selection": ["UV"], + "labels" : { "UV": "Control Video"}, + "visible": False, + } + + if base_model_type in ["animate"]: + extra_model_def["guide_custom_choices"] = { + "choices":[ + ("Animate Person in Reference Image using Motion of Whole Control Video", "PVBKI"), + ("Animate Person in Reference Image using Motion of Targeted Person in Control Video", "PVBXAKI"), + ("Replace Person in Control Video by Person in Ref Image", "PVBAIH#"), + ("Replace Person in Control Video by Person in Ref Image. 
See Through Mask", "PVBAI#"), + ], + "default": "PVBKI", + "letters_filter": "PVBXAKIH#", + "label": "Type of Process", + "scale": 3, + "show_label" : False, + } + + extra_model_def["custom_video_selection"] = { + "choices":[ + ("None", ""), + ("Apply Relighting", "1"), + ], + "trigger": "#", + "label": "Custom Process", + "type": "checkbox", + "letters_filter": "1", + "show_label" : False, + "scale": 1, + } + + extra_model_def["mask_preprocessing"] = { + "selection":[ "", "A", "XA"], + "visible": False + } + + extra_model_def["video_guide_outpainting"] = [0,1] + extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["extract_guide_from_window_start"] = True + extra_model_def["forced_guide_mask_inputs"] = True + extra_model_def["background_removal_label"]= "Remove Backgrounds behind People (Animate Mode Only)" + extra_model_def["background_ref_outpainted"] = False + extra_model_def["return_image_refs_tensor"] = True + extra_model_def["guide_inpaint_color"] = 0 + + + + if vace_class: + extra_model_def["guide_preprocessing"] = { + "selection": ["", "UV", "PV", "DV", "SV", "LV", "CV", "MV", "V", "PDV", "PSV", "PLV" , "DSV", "DLV", "SLV"], + "labels" : { "V": "Use Vace raw format"} + } + extra_model_def["mask_preprocessing"] = { + "selection": ["", "A", "NA", "XA", "XNA", "YA", "YNA", "WA", "WNA", "ZA", "ZNA"], + } + + extra_model_def["image_ref_choices"] = { + "choices": [("None", ""), + ("People / Objects", "I"), + ("Landscape followed by People / Objects (if any)", "KI"), + ("Positioned Frames followed by People / Objects (if any)", "FI"), + ], + "letters_filter": "KFI", + } + + extra_model_def["background_removal_label"]= "Remove Backgrounds behind People / Objects, keep it for Landscape or Positioned Frames" + extra_model_def["video_guide_outpainting"] = [0,1] + extra_model_def["pad_guide_video"] = True + extra_model_def["guide_inpaint_color"] = 127.5 + extra_model_def["forced_guide_mask_inputs"] = True + extra_model_def["return_image_refs_tensor"] = True + extra_model_def["v2i_switch_supported"] = True + if base_model_type in ["vace_lynx_14B"]: + extra_model_def["set_video_prompt_type"]="Q" + extra_model_def["control_net_weight_alt_name"] = "Lynx" + extra_model_def["image_ref_choices"] = { "choices": [("None", ""),("Person Face", "I")], "letters_filter": "I"} + extra_model_def["no_background_removal"] = True + extra_model_def["fit_into_canvas_image_refs"] = 0 + + + + if base_model_type in ["standin"]: + extra_model_def["v2i_switch_supported"] = True + + if base_model_type in ["lynx_lite", "lynx"]: + extra_model_def["fit_into_canvas_image_refs"] = 0 + extra_model_def["guide_custom_choices"] = { + "choices":[("Use Reference Image which is a Person Face", ""), + ("Video to Video guided by Text Prompt & Reference Image", "GUV"), + ("Video to Video on the Area of the Video Mask", "GVA")], + "default": "", + "letters_filter": "GUVA", + "label": "Video to Video", + "show_label" : False, + } + + extra_model_def["mask_preprocessing"] = { + "selection":[ "", "A"], + "visible": False + } + + extra_model_def["image_ref_choices"] = { + "choices": [ + ("No Reference Image", ""), + ("Reference Image is a Person Face", "I"), + ], + "visible": False, + "letters_filter":"I", + } + + extra_model_def["set_video_prompt_type"]= "Q" + extra_model_def["no_background_removal"] = True + extra_model_def["v2i_switch_supported"] = True + extra_model_def["control_net_weight_alt_name"] = "Lynx" + + + if base_model_type in ["phantom_1.3B", "phantom_14B"]: + 
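The `guide_custom_choices` blocks above map each UI label to the letter flags it contributes to `video_prompt_type` (for instance "PVBKI" for the default Animate mode), while `letters_filter` appears to list the letters that particular control owns; digits such as "0"/"1" and the "#" trigger ride along in the same string. Elsewhere in this handler `video_prompt_type` is only ever treated as a plain flag string, so a hypothetical resolver for one of these dropdowns could be as simple as the sketch below (this is an illustration, not WanGP's actual implementation):

    def apply_dropdown_choice(video_prompt_type, letters_filter, selected_flags):
        # Strip the letters this control owns, then append the flags of the selected choice,
        # e.g. "PVBXAKI" for targeted-person animation.
        kept = "".join(ch for ch in video_prompt_type if ch not in letters_filter)
        return kept + selected_flags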
extra_model_def["image_ref_choices"] = { + "choices": [("Reference Image", "I")], + "letters_filter":"I", + "visible": False, + } + + if base_model_type in ["recam_1.3B"]: + extra_model_def["keep_frames_video_guide_not_supported"] = True + extra_model_def["model_modes"] = { + "choices": [ + ("Pan Right", 1), + ("Pan Left", 2), + ("Tilt Up", 3), + ("Tilt Down", 4), + ("Zoom In", 5), + ("Zoom Out", 6), + ("Translate Up (with rotation)", 7), + ("Translate Down (with rotation)", 8), + ("Arc Left (with rotation)", 9), + ("Arc Right (with rotation)", 10), + ], + "default": 1, + "label" : "Camera Movement Type" + } + extra_model_def["guide_preprocessing"] = { + "selection": ["UV"], + "labels" : { "UV": "Control Video"}, + "visible" : False, + } + + if vace_class or base_model_type in ["animate"]: + image_prompt_types_allowed = "TVL" + elif base_model_type in ["infinitetalk"]: + image_prompt_types_allowed = "TSVL" + elif base_model_type in ["ti2v_2_2"]: + image_prompt_types_allowed = "TSVL" + elif base_model_type in ["lucy_edit"]: + image_prompt_types_allowed = "TVL" + elif test_multitalk(base_model_type) or base_model_type in ["fantasy"]: + image_prompt_types_allowed = "SVL" + elif i2v: + image_prompt_types_allowed = "SEVL" + else: + image_prompt_types_allowed = "" + extra_model_def["image_prompt_types_allowed"] = image_prompt_types_allowed + + if text_oneframe_overlap(base_model_type): + extra_model_def["sliding_window_defaults"] = { "overlap_min" : 1, "overlap_max" : 1, "overlap_step": 0, "overlap_default": 1} + + # if base_model_type in ["phantom_1.3B", "phantom_14B"]: + # extra_model_def["one_image_ref_needed"] = True + return extra_model_def @staticmethod def query_supported_types(): - return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B", - "t2v_1.3B", "standin", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", - "recam_1.3B", - "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] + return ["multitalk", "infinitetalk", "fantasy", "vace_14B", "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_14B", + "t2v_1.3B", "standin", "lynx_lite", "lynx", "t2v", "vace_1.3B", "phantom_1.3B", "phantom_14B", + "recam_1.3B", "animate", + "i2v", "i2v_2_2", "i2v_2_2_multitalk", "ti2v_2_2", "lucy_edit", "flf2v_720p", "fun_inp_1.3B", "fun_inp"] @staticmethod @@ -132,11 +364,13 @@ def query_family_maps(): models_eqv_map = { "flf2v_720p" : "i2v", "t2v_1.3B" : "t2v", + "vace_standin_14B" : "vace_14B", + "vace_lynx_14B" : "vace_14B", } models_comp_map = { - "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B"], - "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin"], + "vace_14B" : [ "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_14B"], + "t2v" : [ "vace_14B", "vace_1.3B" "vace_multitalk_14B", "vace_standin_14B", "vace_lynx_lite_14B", "vace_lynx_14B", "t2v_1.3B", "phantom_1.3B","phantom_14B", "standin", "lynx_lite", "lynx"], "i2v" : [ "fantasy", "multitalk", "flf2v_720p" ], "i2v_2_2" : ["i2v_2_2_multitalk"], "fantasy": ["multitalk"], @@ -153,11 +387,12 @@ def query_family_infos(): @staticmethod def get_vae_block_size(base_model_type): - return 32 if base_model_type == "ti2v_2_2" else 16 + return 32 if test_wan_5B(base_model_type) else 16 @staticmethod def get_rgb_factors(base_model_type ): from shared.RGB_factors import get_rgb_factors + if test_wan_5B(base_model_type): base_model_type = "ti2v_2_2" latent_rgb_factors, latent_rgb_factors_bias 
= get_rgb_factors("wan", base_model_type) return latent_rgb_factors, latent_rgb_factors_bias @@ -171,7 +406,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder "fileList" : [ [ "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "sentencepiece.bpe.model", "special_tokens_map.json", "tokenizer.json", "tokenizer_config.json"], ["special_tokens_map.json", "spiece.model", "tokenizer.json", "tokenizer_config.json"] + computeList(text_encoder_filename) , ["Wan2.1_VAE.safetensors", "fantasy_proj_model.safetensors" ] + computeList(model_filename) ] }] - if base_model_type == "ti2v_2_2": + if test_wan_5B(base_model_type): download_def += [ { "repoId" : "DeepBeepMeep/Wan2.2", "sourceFolderList" : [""], @@ -182,7 +417,7 @@ def query_model_files(computeList, base_model_type, model_filename, text_encoder @staticmethod - def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False): + def load_model(model_filename, model_type, base_model_type, model_def, quantizeTransformer = False, text_encoder_quantization = None, dtype = torch.bfloat16, VAE_dtype = torch.float32, mixed_precision_transformer = False, save_quantized= False, submodel_no_list = None): from .configs import WAN_CONFIGS if test_class_i2v(base_model_type): @@ -195,6 +430,7 @@ def load_model(model_filename, model_type, base_model_type, model_def, quantizeT config=cfg, checkpoint_dir="ckpts", model_filename=model_filename, + submodel_no_list = submodel_no_list, model_type = model_type, model_def = model_def, base_model_type=base_model_type, @@ -233,17 +469,59 @@ def fix_settings(base_model_type, settings_version, model_def, ui_defaults): ui_defaults["audio_guidance_scale"]= 1 video_prompt_type = ui_defaults.get("video_prompt_type", "") if "I" in video_prompt_type: - video_prompt_type = video_prompt_type.replace("KI", "QKI") + video_prompt_type = video_prompt_type.replace("KI", "0KI") ui_defaults["video_prompt_type"] = video_prompt_type + + if settings_version < 2.28: + if base_model_type in "infinitetalk": + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if "U" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("U", "RU") + ui_defaults["video_prompt_type"] = video_prompt_type + + if settings_version < 2.31: + if base_model_type in ["recam_1.3B"]: + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if not "V" in video_prompt_type: + video_prompt_type += "UV" + ui_defaults["video_prompt_type"] = video_prompt_type + ui_defaults["image_prompt_type"] = "" + + if text_oneframe_overlap(base_model_type): + ui_defaults["sliding_window_overlap"] = 1 + + if settings_version < 2.32: + image_prompt_type = ui_defaults.get("image_prompt_type", "") + if test_class_i2v(base_model_type) and len(image_prompt_type) == 0: + ui_defaults["image_prompt_type"] = "S" + + + if settings_version < 2.37: + if base_model_type in ["animate"]: + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if "1" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("1", "#1") + ui_defaults["video_prompt_type"] = video_prompt_type + + if settings_version < 2.38: + if base_model_type in ["infinitetalk"]: + video_prompt_type = ui_defaults.get("video_prompt_type", "") + if "Q" in video_prompt_type: + video_prompt_type = video_prompt_type.replace("Q", "0") + 
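`fix_settings` applies cumulative, version-gated rewrites so that settings files saved by older releases keep loading: each bump re-reads `video_prompt_type` (or `image_prompt_type`), tests for a flag letter and rewrites it in place. The recurring shape of these migrations, with a deliberately made-up version number, model and letters:

    # Sketch of the migration pattern only; 2.99, "some_model", "X" and "Y" are placeholders.
    if settings_version < 2.99:
        if base_model_type in ["some_model"]:
            video_prompt_type = ui_defaults.get("video_prompt_type", "")
            if "X" in video_prompt_type:
                ui_defaults["video_prompt_type"] = video_prompt_type.replace("X", "Y")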
ui_defaults["video_prompt_type"] = video_prompt_type + @staticmethod def update_default_settings(base_model_type, model_def, ui_defaults): ui_defaults.update({ "sample_solver": "unipc", }) + if test_class_i2v(base_model_type) and "S" in model_def["image_prompt_types_allowed"]: + ui_defaults["image_prompt_type"] = "S" + if base_model_type in ["fantasy"]: ui_defaults.update({ "audio_guidance_scale": 5.0, - "sliding_window_size": 1, + "sliding_window_overlap" : 1, }) elif base_model_type in ["multitalk"]: @@ -260,8 +538,9 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "guidance_scale": 5.0, "flow_shift": 7, # 11 for 720p "sliding_window_overlap" : 9, + "sliding_window_size": 81, "sample_solver" : "euler", - "video_prompt_type": "QKI", + "video_prompt_type": "0KI", "remove_background_images_ref" : 0, "adaptive_switch" : 1, }) @@ -272,8 +551,19 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "flow_shift": 7, # 11 for 720p "sliding_window_overlap" : 9, "video_prompt_type": "I", - "remove_background_images_ref" : 1, + "remove_background_images_ref" : 1 , + }) + + elif base_model_type in ["lynx_lite", "lynx"]: + ui_defaults.update({ + "guidance_scale": 5.0, + "flow_shift": 7, # 11 for 720p + "sliding_window_overlap" : 9, + "video_prompt_type": "I", + "denoising_strength": 0.8, + "remove_background_images_ref" : 0, }) + elif base_model_type in ["phantom_1.3B", "phantom_14B"]: ui_defaults.update({ "guidance_scale": 7.5, @@ -293,6 +583,21 @@ def update_default_settings(base_model_type, model_def, ui_defaults): "image_prompt_type": "T", }) + if base_model_type in ["recam_1.3B", "lucy_edit"]: + ui_defaults.update({ + "video_prompt_type": "UV", + }) + elif base_model_type in ["animate"]: + ui_defaults.update({ + "video_prompt_type": "PVBKI", + "mask_expand": 20, + "audio_prompt_type": "R", + }) + + if text_oneframe_overlap(base_model_type): + ui_defaults["sliding_window_overlap"] = 1 + ui_defaults["color_correction_strength"]= 0 + if test_multitalk(base_model_type): ui_defaults["audio_guidance_scale"] = 4 @@ -309,3 +614,11 @@ def validate_generative_settings(base_model_type, model_def, inputs): if ("V" in image_prompt_type or "L" in image_prompt_type) and image_refs is None: video_prompt_type = video_prompt_type.replace("I", "").replace("K","") inputs["video_prompt_type"] = video_prompt_type + + + if base_model_type in ["vace_standin_14B"]: + image_refs = inputs["image_refs"] + video_prompt_type = inputs["video_prompt_type"] + if image_refs is not None and len(image_refs) == 1 and "K" in video_prompt_type: + gr.Info("Warning, Ref Image for Standin Missing: if 'Landscape and then People or Objects' is selected beside the Landscape Image Ref there should be another Image Ref that contains a Face.") + diff --git a/preprocessing/arc/face_encoder.py b/preprocessing/arc/face_encoder.py new file mode 100644 index 000000000..0fb916219 --- /dev/null +++ b/preprocessing/arc/face_encoder.py @@ -0,0 +1,97 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torchvision.transforms as T + +import numpy as np + +from insightface.utils import face_align +from insightface.app import FaceAnalysis +from facexlib.recognition import init_recognition_model + + +__all__ = [ + "FaceEncoderArcFace", + "get_landmarks_from_image", +] + + +detector = None + +def get_landmarks_from_image(image): + """ + Detect landmarks with insightface. + + Args: + image (np.ndarray or PIL.Image): + The input image in RGB format. 
+ + Returns: + 5 2D keypoints, only one face will be returned. + """ + global detector + if detector is None: + detector = FaceAnalysis() + detector.prepare(ctx_id=0, det_size=(640, 640)) + + in_image = np.array(image).copy() + + faces = detector.get(in_image) + if len(faces) == 0: + raise ValueError("No face detected in the image") + + # Get the largest face + face = max(faces, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1])) + + # Return the 5 keypoints directly + keypoints = face.kps # 5 x 2 + + return keypoints + + +class FaceEncoderArcFace(): + """ Official ArcFace, no_grad-only """ + + def __repr__(self): + return "ArcFace" + + + def init_encoder_model(self, device, eval_mode=True): + self.device = device + self.encoder_model = init_recognition_model('arcface', device=device) + + if eval_mode: + self.encoder_model.eval() + + + @torch.no_grad() + def input_preprocessing(self, in_image, landmarks, image_size=112): + assert landmarks is not None, "landmarks are not provided!" + + in_image = np.array(in_image) + landmark = np.array(landmarks) + + face_aligned = face_align.norm_crop(in_image, landmark=landmark, image_size=image_size) + + image_transform = T.Compose([ + T.ToTensor(), + T.Normalize([0.5], [0.5]), + ]) + face_aligned = image_transform(face_aligned).unsqueeze(0).to(self.device) + + return face_aligned + + + @torch.no_grad() + def __call__(self, in_image, need_proc=False, landmarks=None, image_size=112): + + if need_proc: + in_image = self.input_preprocessing(in_image, landmarks, image_size) + else: + assert isinstance(in_image, torch.Tensor) + + in_image = in_image[:, [2, 1, 0], :, :].contiguous() + image_embeds = self.encoder_model(in_image) # [B, 512], normalized + + return image_embeds \ No newline at end of file diff --git a/preprocessing/arc/face_utils.py b/preprocessing/arc/face_utils.py new file mode 100644 index 000000000..41a454546 --- /dev/null +++ b/preprocessing/arc/face_utils.py @@ -0,0 +1,63 @@ +# Copyright (c) 2022 Insightface Team +# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates + +# This file has been modified by Bytedance Ltd. and/or its affiliates on September 15, 2025. +# SPDX-License-Identifier: Apache-2.0 + +# Original file (insightface) was released under MIT License: + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
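The ArcFace encoder above (preprocessing/arc/face_encoder.py) is meant to be driven by the landmark detector defined in the same module: detect the five keypoints, then let the encoder align and embed the face. Below is a minimal usage sketch under that assumption; the image path is hypothetical and model downloads (insightface, facexlib) are left to their defaults.

```python
# Hedged sketch: landmark detection followed by ArcFace embedding, as suggested
# by the module above. "face.jpg" is a hypothetical input path.
from PIL import Image
from preprocessing.arc.face_encoder import FaceEncoderArcFace, get_landmarks_from_image

image = Image.open("face.jpg").convert("RGB")     # RGB input, as the docstring expects
kps = get_landmarks_from_image(image)             # 5 x 2 keypoints of the largest detected face

encoder = FaceEncoderArcFace()
encoder.init_encoder_model(device="cuda")         # loads the facexlib ArcFace model in eval mode
embedding = encoder(image, need_proc=True, landmarks=kps)   # [1, 512] face embedding
```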
+ +import numpy as np +import cv2 +from PIL import Image + +def estimate_norm(lmk, image_size=112, arcface_dst=None): + from skimage import transform as trans + assert lmk.shape == (5, 2) + assert image_size%112==0 or image_size%128==0 + if image_size%112==0: + ratio = float(image_size)/112.0 + diff_x = 0 + else: + ratio = float(image_size)/128.0 + diff_x = 8.0*ratio + dst = arcface_dst * ratio + dst[:,0] += diff_x + tform = trans.SimilarityTransform() + tform.estimate(lmk, dst) + M = tform.params[0:2, :] + return M + +def get_arcface_dst(extend_face_crop=False, extend_ratio=0.8): + arcface_dst = np.array( + [[38.2946, 51.6963], [73.5318, 51.5014], [56.0252, 71.7366], + [41.5493, 92.3655], [70.7299, 92.2041]], dtype=np.float32) + if extend_face_crop: + arcface_dst[:,1] = arcface_dst[:,1] + 10 + arcface_dst = (arcface_dst - 112/2) * extend_ratio + 112/2 + return arcface_dst + +def align_face(image_pil, face_kpts, extend_face_crop=False, extend_ratio=0.8, face_size=112): + arcface_dst = get_arcface_dst(extend_face_crop, extend_ratio) + M = estimate_norm(face_kpts, face_size, arcface_dst) + image_cv2 = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) + face_image_cv2 = cv2.warpAffine(image_cv2, M, (face_size, face_size), borderValue=0.0) + face_image = Image.fromarray(cv2.cvtColor(face_image_cv2, cv2.COLOR_BGR2RGB)) + return face_image \ No newline at end of file diff --git a/preprocessing/extract_vocals.py b/preprocessing/extract_vocals.py new file mode 100644 index 000000000..656402644 --- /dev/null +++ b/preprocessing/extract_vocals.py @@ -0,0 +1,69 @@ +from pathlib import Path +import os, tempfile +import numpy as np +import soundfile as sf +import librosa +import torch +import gc + +from audio_separator.separator import Separator + +def get_vocals(src_path: str, dst_path: str, min_seconds: float = 8) -> str: + """ + If the source audio is shorter than `min_seconds`, pad with trailing silence + in a temporary file, then run separation and save only the vocals to dst_path. + Returns the full path to the vocals file. 
+ """ + + default_device = torch.get_default_device() + torch.set_default_device('cpu') + + dst = Path(dst_path) + dst.parent.mkdir(parents=True, exist_ok=True) + + # Quick duration check + duration = librosa.get_duration(path=src_path) + + use_path = src_path + temp_path = None + try: + if duration < min_seconds: + # Load (resample) and pad in memory + y, sr = librosa.load(src_path, sr=None, mono=False) + if y.ndim == 1: # ensure shape (channels, samples) + y = y[np.newaxis, :] + target_len = int(min_seconds * sr) + pad = max(0, target_len - y.shape[1]) + if pad: + y = np.pad(y, ((0, 0), (0, pad)), mode="constant") + + # Write a temp WAV for the separator + fd, temp_path = tempfile.mkstemp(suffix=".wav") + os.close(fd) + sf.write(temp_path, y.T, sr) # soundfile expects (frames, channels) + use_path = temp_path + + # Run separation: emit only the vocals, with your exact filename + sep = Separator( + output_dir=str(dst.parent), + output_format=(dst.suffix.lstrip(".") or "wav"), + output_single_stem="Vocals", + model_file_dir="ckpts/roformer/" #model_bs_roformer_ep_317_sdr_12.9755.ckpt" + ) + sep.load_model() + out_files = sep.separate(use_path, {"Vocals": dst.stem}) + + out = Path(out_files[0]) + return str(out if out.is_absolute() else (dst.parent / out)) + finally: + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + torch.cuda.empty_cache() + gc.collect() + torch.set_default_device(default_device) + +# Example: +# final = extract_vocals("in/clip.mp3", "out/vocals.wav") +# print(final) + diff --git a/preprocessing/matanyone/app.py b/preprocessing/matanyone/app.py index 151d9bef4..a5d557099 100644 --- a/preprocessing/matanyone/app.py +++ b/preprocessing/matanyone/app.py @@ -7,7 +7,6 @@ # import ffmpeg import imageio from PIL import Image - import cv2 import torch import torch.nn.functional as F @@ -22,6 +21,7 @@ from .matanyone.inference.inference_core import InferenceCore from .matanyone_wrapper import matanyone from shared.utils.audio_video import save_video, save_image +from mmgp import offload arg_device = "cuda" arg_sam_model_type="vit_h" @@ -33,6 +33,8 @@ matanyone_in_GPU = False bfloat16_supported = False # SAM generator +import copy + class MaskGenerator(): def __init__(self, sam_checkpoint, device): global args_device @@ -89,6 +91,7 @@ def get_frames_from_image(image_input, image_state): "last_frame_numer": 0, "fps": None } + image_info = "Image Name: N/A,\nFPS: N/A,\nTotal Frames: {},\nImage Size:{}".format(len(frames), image_size) set_image_encoder_patch() select_SAM() @@ -537,7 +540,7 @@ def video_matting(video_state,video_input, end_slider, matting_type, interactive file_name = ".".join(file_name.split(".")[:-1]) from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, cleanup_temp_audio_files - source_audio_tracks, audio_metadata = extract_audio_tracks(video_input) + source_audio_tracks, audio_metadata = extract_audio_tracks(video_input, verbose= offload.default_verboseLevel ) output_fg_path = f"./mask_outputs/{file_name}_fg.mp4" output_fg_temp_path = f"./mask_outputs/{file_name}_fg_tmp.mp4" if len(source_audio_tracks) == 0: @@ -677,7 +680,6 @@ def load_unload_models(selected): } # os.path.join('.') - from mmgp import offload # sam_checkpoint = load_file_from_url(sam_checkpoint_url_dict[arg_sam_model_type], ".") sam_checkpoint = None @@ -695,7 +697,8 @@ def load_unload_models(selected): model.samcontroler.sam_controler.model.to("cpu").to(torch.bfloat16).to(arg_device) model_in_GPU = True from .matanyone.model.matanyone 
import MatAnyone - matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone") + # matanyone_model = MatAnyone.from_pretrained("PeiqingYang/MatAnyone") + matanyone_model = MatAnyone.from_pretrained("ckpts/mask") # pipe ={"mat" : matanyone_model, "sam" :model.samcontroler.sam_controler.model } # offload.profile(pipe) matanyone_model = matanyone_model.to("cpu").eval() @@ -717,27 +720,33 @@ def load_unload_models(selected): def get_vmc_event_handler(): return load_unload_models -def export_to_vace_video_input(foreground_video_output): - gr.Info("Masked Video Input transferred to Vace For Inpainting") - return "V#" + str(time.time()), foreground_video_output - -def export_image(image_refs, image_output): - gr.Info("Masked Image transferred to Current Video") +def export_image(state, image_output): + ui_settings = get_current_model_settings(state) + image_refs = ui_settings["image_refs"] if image_refs == None: image_refs =[] image_refs.append( image_output) - return image_refs + ui_settings["image_refs"] = image_refs + gr.Info("Masked Image transferred to Current Image Generator") + return time.time() + +def export_image_mask(state, image_input, image_mask): + ui_settings = get_current_model_settings(state) + ui_settings["image_guide"] = Image.fromarray(image_input) + ui_settings["image_mask"] = image_mask + + gr.Info("Input Image & Mask transferred to Current Image Generator") + return time.time() -def export_image_mask(image_input, image_mask): - gr.Info("Input Image & Mask transferred to Current Video") - return Image.fromarray(image_input), image_mask +def export_to_current_video_engine(state, foreground_video_output, alpha_video_output): + ui_settings = get_current_model_settings(state) + ui_settings["video_guide"] = foreground_video_output + ui_settings["video_mask"] = alpha_video_output -def export_to_current_video_engine( foreground_video_output, alpha_video_output): gr.Info("Original Video and Full Mask have been transferred") - # return "MV#" + str(time.time()), foreground_video_output, alpha_video_output - return foreground_video_output, alpha_video_output + return time.time() def teleport_to_video_tab(tab_state): @@ -746,15 +755,29 @@ def teleport_to_video_tab(tab_state): return gr.Tabs(selected="video_gen") -def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs): +def display(tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings_fn): #, vace_video_input, vace_image_input, vace_video_mask, vace_image_mask, vace_image_refs): # my_tab.select(fn=load_unload_models, inputs=[], outputs=[]) - global image_output_codec, video_output_codec + global image_output_codec, video_output_codec, get_current_model_settings + get_current_model_settings = get_current_model_settings_fn image_output_codec = server_config.get("image_output_codec", None) video_output_codec = server_config.get("video_output_codec", None) media_url = "https://github.com/pq-yang/MatAnyone/releases/download/media/" + click_brush_js = """ + () => { + setTimeout(() => { + const brushButton = document.querySelector('button[aria-label="Brush"]'); + if (brushButton) { + brushButton.click(); + console.log('Brush button clicked'); + } else { + console.log('Brush button not found'); + } + }, 1000); + } """ + # download assets gr.Markdown("Mast Edition is provided by MatAnyone and VRAM optimized by DeepBeepMeep") @@ -871,7 +894,7 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, 
template_frame = gr.Image(label="Start Frame", type="pil",interactive=True, elem_id="template_frame", visible=False, elem_classes="image") with gr.Row(): clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False, min_width=100) - add_mask_button = gr.Button(value="Set Mask", interactive=True, visible=False, min_width=100) + add_mask_button = gr.Button(value="Add Mask", interactive=True, visible=False, min_width=100) remove_mask_button = gr.Button(value="Remove Mask", interactive=True, visible=False, min_width=100) # no use matting_button = gr.Button(value="Generate Video Matting", interactive=True, visible=False, min_width=100) with gr.Row(): @@ -892,7 +915,7 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, with gr.Row(visible= True): export_to_current_video_engine_btn = gr.Button("Export to Control Video Input and Video Mask Input", visible= False) - export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [foreground_video_output, alpha_video_output], outputs= [vace_video_input, vace_video_mask]).then( #video_prompt_video_guide_trigger, + export_to_current_video_engine_btn.click( fn=export_to_current_video_engine, inputs= [state, foreground_video_output, alpha_video_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger, fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]) @@ -1089,10 +1112,10 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, # with gr.Column(scale=2, visible= True): export_image_mask_btn = gr.Button(value="Set to Control Image & Mask", visible=False, elem_classes="new_button") - export_image_btn.click( fn=export_image, inputs= [vace_image_refs, foreground_image_output], outputs= [vace_image_refs]).then( #video_prompt_video_guide_trigger, - fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]) - export_image_mask_btn.click( fn=export_image_mask, inputs= [image_input, alpha_image_output], outputs= [vace_image_input, vace_image_mask]).then( #video_prompt_video_guide_trigger, + export_image_btn.click( fn=export_image, inputs= [state, foreground_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger, fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]) + export_image_mask_btn.click( fn=export_image_mask, inputs= [state, image_input, alpha_image_output], outputs= [refresh_form_trigger]).then( #video_prompt_video_guide_trigger, + fn=teleport_to_video_tab, inputs= [tab_state], outputs= [tabs]).then(fn=None, inputs=None, outputs=None, js=click_brush_js) # first step: get the image information extract_frames_button.click( @@ -1148,5 +1171,21 @@ def display(tabs, tab_state, server_config, vace_video_input, vace_image_input, outputs=[foreground_image_output, alpha_image_output,foreground_image_output, alpha_image_output,bbox_info, export_image_btn, export_image_mask_btn] ) - - + nada = gr.State({}) + # clear input + gr.on( + triggers=[image_input.clear], #image_input.change, + fn=restart, + inputs=[], + outputs=[ + image_state, + interactive_state, + click_state, + foreground_image_output, alpha_image_output, + template_frame, + image_selection_slider, image_selection_slider, track_pause_number_slider,point_prompt, export_image_btn, export_image_mask_btn, bbox_info, clear_button_click, + add_mask_button, matting_button, template_frame, foreground_image_output, alpha_image_output, remove_mask_button, export_image_btn, export_image_mask_btn, mask_dropdown, nada, step2_title 
+ ], + queue=False, + show_progress=False) + diff --git a/preprocessing/matanyone/matanyone/model/utils/memory_utils.py b/preprocessing/matanyone/matanyone/model/utils/memory_utils.py index 8c857ea9c..8166a9ffe 100644 --- a/preprocessing/matanyone/matanyone/model/utils/memory_utils.py +++ b/preprocessing/matanyone/matanyone/model/utils/memory_utils.py @@ -2,7 +2,6 @@ import torch from typing import Optional, Union, Tuple - # @torch.jit.script def get_similarity(mk: torch.Tensor, ms: torch.Tensor, @@ -59,6 +58,7 @@ def get_similarity(mk: torch.Tensor, del two_ab # similarity = (-a_sq + two_ab) + similarity =similarity.float() if ms is not None: similarity *= ms similarity /= math.sqrt(CK) diff --git a/preprocessing/matanyone/matanyone_wrapper.py b/preprocessing/matanyone/matanyone_wrapper.py index 292465a8c..7a729abee 100644 --- a/preprocessing/matanyone/matanyone_wrapper.py +++ b/preprocessing/matanyone/matanyone_wrapper.py @@ -73,5 +73,5 @@ def matanyone(processor, frames_np, mask, r_erode=0, r_dilate=0, n_warmup=10): if ti > (n_warmup-1): frames.append((com_np*255).astype(np.uint8)) phas.append((pha*255).astype(np.uint8)) - + # phas.append(np.clip(pha * 255, 0, 255).astype(np.uint8)) return frames, phas \ No newline at end of file diff --git a/preprocessing/speakers_separator.py b/preprocessing/speakers_separator.py index 79cde9b1e..d5f256335 100644 --- a/preprocessing/speakers_separator.py +++ b/preprocessing/speakers_separator.py @@ -100,7 +100,7 @@ def __init__(self, hf_token: str = None, local_model_path: str = None, self.hf_token = hf_token self._overlap_pipeline = None - def separate_audio(self, audio_path: str, output1, output2 ) -> Dict[str, str]: + def separate_audio(self, audio_path: str, output1, output2, audio_original_path: str = None ) -> Dict[str, str]: """Optimized main separation function with memory management.""" xprint("Starting optimized audio separation...") self._current_audio_path = os.path.abspath(audio_path) @@ -128,7 +128,11 @@ def separate_audio(self, audio_path: str, output1, output2 ) -> Dict[str, str]: gc.collect() # Save outputs efficiently - output_paths = self._save_outputs_optimized(waveform, final_masks, sample_rate, audio_path, output1, output2) + if audio_original_path is None: + waveform_original = waveform + else: + waveform_original, sample_rate = self.load_audio(audio_original_path) + output_paths = self._save_outputs_optimized(waveform_original, final_masks, sample_rate, audio_path, output1, output2) return output_paths @@ -835,7 +839,7 @@ def print_summary(self, audio_path: str): for turn, _, speaker in diarization.itertracks(yield_label=True): xprint(f"{speaker}: {turn.start:.1f}s - {turn.end:.1f}s") -def extract_dual_audio(audio, output1, output2, verbose = False): +def extract_dual_audio(audio, output1, output2, verbose = False, audio_original = None): global verbose_output verbose_output = verbose separator = OptimizedPyannote31SpeakerSeparator( @@ -848,7 +852,7 @@ def extract_dual_audio(audio, output1, output2, verbose = False): import time start_time = time.time() - outputs = separator.separate_audio(audio, output1, output2) + outputs = separator.separate_audio(audio, output1, output2, audio_original) elapsed_time = time.time() - start_time xprint(f"\n=== SUCCESS (completed in {elapsed_time:.2f}s) ===") diff --git a/requirements.txt b/requirements.txt index a2e1fde85..9b2c6eeb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,19 +20,24 @@ soundfile mutagen pyloudnorm librosa==0.11.0 +speechbrain==1.0.3 +audio-separator==0.36.1 
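The audio packages added above (speechbrain, audio-separator) back the new audio preprocessing helpers in this PR: preprocessing/extract_vocals.py and the audio_original parameter added to preprocessing/speakers_separator.py. A minimal sketch of how they might be chained, assuming hypothetical file paths and that pyannote model access is already configured:

```python
# Hedged sketch: isolate vocals first, then split the two speakers while applying
# the separation masks to the original mix via the new audio_original argument.
from preprocessing.extract_vocals import get_vocals
from preprocessing.speakers_separator import extract_dual_audio

vocals_path = get_vocals("in/clip.mp3", "out/vocals.wav")   # pads clips shorter than 8s, returns the vocals file
extract_dual_audio(
    vocals_path, "out/speaker1.wav", "out/speaker2.wav",
    audio_original="in/clip.mp3",   # masks computed on the vocals are applied to the original audio
)
```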
# UI & interaction -gradio==5.23.0 +gradio==5.29.0 dashscope loguru # Vision & segmentation -opencv-python>=4.9.0.80 +opencv-python>=4.12.0.88 segment-anything rembg[gpu]==2.0.65 -onnxruntime-gpu +onnxruntime-gpu==1.22 decord timm +insightface @ https://github.com/deepbeepmeep/insightface/raw/refs/heads/master/wheels/insightface-0.7.3-cp310-cp310-win_amd64.whl ; sys_platform == "win32" and python_version == "3.10" +insightface==0.7.3 ; sys_platform == "linux" +facexlib==0.3.0 # Config & orchestration omegaconf @@ -43,14 +48,14 @@ pydantic==2.10.6 # Math & modeling torchdiffeq>=0.2.5 tensordict>=0.6.1 -mmgp==3.5.10 +mmgp==3.6.0 peft==0.15.0 matplotlib # Utilities ftfy piexif -pynvml +nvidia-ml-py misaki # Optional / commented out diff --git a/run-docker-cuda-deb.sh b/run-docker-cuda-deb.sh new file mode 100755 index 000000000..b35e9cc6b --- /dev/null +++ b/run-docker-cuda-deb.sh @@ -0,0 +1,210 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ───────────────────────── helpers ───────────────────────── + +install_nvidia_smi_if_missing() { + if command -v nvidia-smi &>/dev/null; then + return + fi + + echo "⚠️ nvidia-smi not found. Installing nvidia-utils…" + if [ "$EUID" -ne 0 ]; then + SUDO='sudo' + else + SUDO='' + fi + + $SUDO apt-get update + $SUDO apt-get install -y nvidia-utils-535 || $SUDO apt-get install -y nvidia-utils + + if ! command -v nvidia-smi &>/dev/null; then + echo "❌ Failed to install nvidia-smi. Cannot detect GPU architecture." + exit 1 + fi + echo "✅ nvidia-smi installed successfully." +} + +detect_gpu_name() { + install_nvidia_smi_if_missing + nvidia-smi --query-gpu=name --format=csv,noheader,nounits | head -1 +} + +map_gpu_to_arch() { + local name="$1" + case "$name" in + *"RTX 50"* | *"5090"* | *"5080"* | *"5070"*) echo "12.0" ;; + *"H100"* | *"H800"*) echo "9.0" ;; + *"RTX 40"* | *"4090"* | *"4080"* | *"4070"* | *"4060"*) echo "8.9" ;; + *"RTX 30"* | *"3090"* | *"3080"* | *"3070"* | *"3060"*) echo "8.6" ;; + *"A100"* | *"A800"* | *"A40"*) echo "8.0" ;; + *"Tesla V100"*) echo "7.0" ;; + *"RTX 20"* | *"2080"* | *"2070"* | *"2060"* | *"Titan RTX"*) echo "7.5" ;; + *"GTX 16"* | *"1660"* | *"1650"*) echo "7.5" ;; + *"GTX 10"* | *"1080"* | *"1070"* | *"1060"* | *"Tesla P100"*) echo "6.1" ;; + *"Tesla K80"* | *"Tesla K40"*) echo "3.7" ;; + *) + echo "❌ Unknown GPU model: $name" + echo "Please update the map_gpu_to_arch function for this model." 
+ exit 1 + ;; + esac +} + +get_gpu_vram() { + install_nvidia_smi_if_missing + # Get VRAM in MB, convert to GB + local vram_mb=$(nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits | head -1) + echo $((vram_mb / 1024)) +} + +map_gpu_to_profile() { + local name="$1" + local vram_gb="$2" + + # WanGP Profile descriptions from the actual UI: + # Profile 1: HighRAM_HighVRAM - 48GB+ RAM, 24GB+ VRAM (fastest for short videos, RTX 3090/4090) + # Profile 2: HighRAM_LowVRAM - 48GB+ RAM, 12GB+ VRAM (recommended, most versatile) + # Profile 3: LowRAM_HighVRAM - 32GB+ RAM, 24GB+ VRAM (RTX 3090/4090 with limited RAM) + # Profile 4: LowRAM_LowVRAM - 32GB+ RAM, 12GB+ VRAM (default, little VRAM or longer videos) + # Profile 5: VerylowRAM_LowVRAM - 16GB+ RAM, 10GB+ VRAM (fail safe, slow but works) + + case "$name" in + # High-end data center GPUs with 24GB+ VRAM - Profile 1 (HighRAM_HighVRAM) + *"RTX 50"* | *"5090"* | *"A100"* | *"A800"* | *"H100"* | *"H800"*) + if [ "$vram_gb" -ge 24 ]; then + echo "1" # HighRAM_HighVRAM - fastest for short videos + else + echo "2" # HighRAM_LowVRAM - most versatile + fi + ;; + # High-end consumer GPUs (RTX 3090/4090) - Profile 1 or 3 + *"RTX 40"* | *"4090"* | *"RTX 30"* | *"3090"*) + if [ "$vram_gb" -ge 24 ]; then + echo "3" # LowRAM_HighVRAM - good for limited RAM systems + else + echo "2" # HighRAM_LowVRAM - most versatile + fi + ;; + # Mid-range GPUs (RTX 3070/3080/4070/4080) - Profile 2 recommended + *"4080"* | *"4070"* | *"3080"* | *"3070"* | *"RTX 20"* | *"2080"* | *"2070"*) + if [ "$vram_gb" -ge 12 ]; then + echo "2" # HighRAM_LowVRAM - recommended for these GPUs + else + echo "4" # LowRAM_LowVRAM - default for little VRAM + fi + ;; + # Lower-end GPUs with 6-12GB VRAM - Profile 4 or 5 + *"4060"* | *"3060"* | *"2060"* | *"GTX 16"* | *"1660"* | *"1650"*) + if [ "$vram_gb" -ge 10 ]; then + echo "4" # LowRAM_LowVRAM - default + else + echo "5" # VerylowRAM_LowVRAM - fail safe + fi + ;; + # Older/lower VRAM GPUs - Profile 5 (fail safe) + *"GTX 10"* | *"1080"* | *"1070"* | *"1060"* | *"Tesla"*) + echo "5" # VerylowRAM_LowVRAM - fail safe + ;; + *) + echo "4" # LowRAM_LowVRAM - default fallback + ;; + esac +} + +# ───────────────────────── main ──────────────────────────── + +echo "🔧 NVIDIA CUDA Setup Check:" + +# NVIDIA driver check +if command -v nvidia-smi &>/dev/null; then + DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader,nounits | head -1) + echo "✅ NVIDIA Driver: $DRIVER_VERSION" + + # Quick CUDA 12.4 compatibility check + if [[ "$DRIVER_VERSION" =~ ^([0-9]+) ]]; then + MAJOR=${BASH_REMATCH[1]} + if [ "$MAJOR" -lt 520 ]; then + echo "⚠️ Driver $DRIVER_VERSION may not support CUDA 12.4 (need 520+)" + fi + fi +else + echo "❌ nvidia-smi not found - no NVIDIA drivers" + exit 1 +fi + +GPU_NAME=$(detect_gpu_name) +echo "🔍 Detected GPU: $GPU_NAME" + +VRAM_GB=$(get_gpu_vram) +echo "🧠 Detected VRAM: ${VRAM_GB}GB" + +CUDA_ARCH=$(map_gpu_to_arch "$GPU_NAME") +echo "🚀 Using CUDA architecture: $CUDA_ARCH" + +PROFILE=$(map_gpu_to_profile "$GPU_NAME" "$VRAM_GB") +echo "⚙️ Selected profile: $PROFILE" + +docker build --build-arg CUDA_ARCHITECTURES="$CUDA_ARCH" -t deepbeepmeep/wan2gp . + +# sudo helper for later commands +if [ "$EUID" -ne 0 ]; then + SUDO='sudo' +else + SUDO='' +fi + +# Ensure NVIDIA runtime is available +if ! docker info 2>/dev/null | grep -q 'Runtimes:.*nvidia'; then + echo "⚠️ NVIDIA Docker runtime not found. 
Installing nvidia-docker2…" + $SUDO apt-get update + $SUDO apt-get install -y curl ca-certificates gnupg + curl -s -L https://nvidia.github.io/nvidia-docker/gpgkey | $SUDO apt-key add - + distribution=$( + . /etc/os-release + echo $ID$VERSION_ID + ) + curl -s -L https://nvidia.github.io/nvidia-docker/$distribution/nvidia-docker.list | + $SUDO tee /etc/apt/sources.list.d/nvidia-docker.list + $SUDO apt-get update + $SUDO apt-get install -y nvidia-docker2 + echo "🔄 Restarting Docker service…" + $SUDO systemctl restart docker + echo "✅ NVIDIA Docker runtime installed." +else + echo "✅ NVIDIA Docker runtime found." +fi + +# Quick NVIDIA runtime test +echo "🧪 Testing NVIDIA runtime..." +if timeout 15s docker run --rm --gpus all --runtime=nvidia nvidia/cuda:12.4-runtime-ubuntu22.04 nvidia-smi >/dev/null 2>&1; then + echo "✅ NVIDIA runtime working" +else + echo "❌ NVIDIA runtime test failed - check driver/runtime compatibility" +fi + +# Prepare cache dirs & volume mounts +cache_dirs=(numba matplotlib huggingface torch) +cache_mounts=() +for d in "${cache_dirs[@]}"; do + mkdir -p "$HOME/.cache/$d" + chmod 700 "$HOME/.cache/$d" + cache_mounts+=(-v "$HOME/.cache/$d:/home/user/.cache/$d") +done + +echo "🔧 Optimization settings:" +echo " Profile: $PROFILE" + +# Run the container +docker run --rm -it \ + --name wan2gp \ + --gpus all \ + --runtime=nvidia \ + -p 7860:7860 \ + -v "$(pwd):/workspace" \ + "${cache_mounts[@]}" \ + deepbeepmeep/wan2gp \ + --profile "$PROFILE" \ + --attention sage \ + --compile \ + --perc-reserved-mem-max 1 diff --git a/shared/RGB_factors.py b/shared/RGB_factors.py index 6e865fa7b..8a870b4b6 100644 --- a/shared/RGB_factors.py +++ b/shared/RGB_factors.py @@ -1,6 +1,6 @@ # thanks Comfyui for the rgb factors (https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py) def get_rgb_factors(model_family, model_type = None): - if model_family == "wan": + if model_family in ["wan", "qwen"]: if model_type =="ti2v_2_2": latent_channels = 48 latent_dimensions = 3 @@ -261,7 +261,7 @@ def get_rgb_factors(model_family, model_type = None): [ 0.0249, -0.0469, -0.1703] ] - latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] + latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] else: latent_rgb_factors_bias = latent_rgb_factors = None return latent_rgb_factors, latent_rgb_factors_bias \ No newline at end of file diff --git a/shared/attention.py b/shared/attention.py index a95332dc5..c31087c24 100644 --- a/shared/attention.py +++ b/shared/attention.py @@ -3,6 +3,7 @@ from importlib.metadata import version from mmgp import offload import torch.nn.functional as F +import warnings major, minor = torch.cuda.get_device_capability(None) bfloat16_supported = major >= 8 @@ -41,35 +42,57 @@ def sageattn_varlen_wrapper( except ImportError: sageattn_varlen_wrapper = None +try: + from spas_sage_attn import block_sparse_sage2_attn_cuda +except ImportError: + block_sparse_sage2_attn_cuda = None -import warnings try: - from sageattention import sageattn - from .sage2_core import sageattn as alt_sageattn, is_sage2_supported + from .sage2_core import sageattn as sageattn2, is_sage2_supported sage2_supported = is_sage2_supported() except ImportError: - sageattn = None - alt_sageattn = None + sageattn2 = None sage2_supported = False -# @torch.compiler.disable() -def sageattn_wrapper( +@torch.compiler.disable() +def sageattn2_wrapper( qkv_list, attention_length ): q,k, v = qkv_list - if True: - qkv_list = [q,k,v] - del q, k ,v - o = alt_sageattn(qkv_list, tensor_layout="NHD") - else: - o = 
sageattn(q, k, v, tensor_layout="NHD") - del q, k ,v + qkv_list = [q,k,v] + del q, k ,v + o = sageattn2(qkv_list, tensor_layout="NHD") + qkv_list.clear() + + return o + +try: + from sageattn import sageattn_blackwell as sageattn3 +except ImportError: + sageattn3 = None +@torch.compiler.disable() +def sageattn3_wrapper( + qkv_list, + attention_length + ): + q,k, v = qkv_list + # qkv_list = [q,k,v] + # del q, k ,v + # o = sageattn3(qkv_list, tensor_layout="NHD") + q = q.transpose(1,2) + k = k.transpose(1,2) + v = v.transpose(1,2) + o = sageattn3(q, k, v) + o = o.transpose(1,2) qkv_list.clear() return o + + + # try: # if True: # from .sage2_core import sageattn_qk_int8_pv_fp8_window_cuda @@ -94,7 +117,7 @@ def sageattn_wrapper( # return o # except ImportError: -# sageattn = sageattn_qk_int8_pv_fp8_window_cuda +# sageattn2 = sageattn_qk_int8_pv_fp8_window_cuda @torch.compiler.disable() def sdpa_wrapper( @@ -124,21 +147,33 @@ def get_attention_modes(): ret.append("xformers") if sageattn_varlen_wrapper != None: ret.append("sage") - if sageattn != None and version("sageattention").startswith("2") : + if sageattn2 != None and version("sageattention").startswith("2") : ret.append("sage2") + if block_sparse_sage2_attn_cuda != None and version("sageattention").startswith("2") : + ret.append("radial") + + if sageattn3 != None: # and version("sageattention").startswith("3") : + ret.append("sage3") return ret def get_supported_attention_modes(): ret = get_attention_modes() + major, minor = torch.cuda.get_device_capability() + if major < 10: + if "sage3" in ret: + ret.remove("sage3") + if not sage2_supported: if "sage2" in ret: ret.remove("sage2") + if "radial" in ret: + ret.remove("radial") - major, minor = torch.cuda.get_device_capability() if major < 7: if "sage" in ret: ret.remove("sage") + return ret __all__ = [ @@ -200,8 +235,9 @@ def pay_attention( if attn == "chipmunk": from src.chipmunk.modules import SparseDiffMlp, SparseDiffAttn from src.chipmunk.util import LayerCounter, GLOBAL_CONFIG + if attn == "radial": attn ="sage2" - if b > 1 and k_lens != None and attn in ("sage2", "sdpa"): + if b > 1 and k_lens != None and attn in ("sage2", "sage3", "sdpa"): assert attention_mask == None # Poor's man var k len attention assert q_lens == None @@ -234,7 +270,7 @@ def pay_attention( q_chunks, k_chunks, v_chunks = None, None, None o = torch.cat(o, dim = 0) return o - elif (q_lens != None or k_lens != None) and attn in ("sage2", "sdpa"): + elif (q_lens != None or k_lens != None) and attn in ("sage2", "sage3", "sdpa"): assert b == 1 szq = q_lens[0].item() if q_lens != None else lq szk = k_lens[0].item() if k_lens != None else lk @@ -284,13 +320,19 @@ def pay_attention( max_seqlen_q=lq, max_seqlen_kv=lk, ).unflatten(0, (b, lq)) + elif attn=="sage3": + import math + if cross_attn or True: + qkv_list = [q,k,v] + del q,k,v + x = sageattn3_wrapper(qkv_list, lq) elif attn=="sage2": import math if cross_attn or True: qkv_list = [q,k,v] del q,k,v - x = sageattn_wrapper(qkv_list, lq) #.unsqueeze(0) + x = sageattn2_wrapper(qkv_list, lq) #.unsqueeze(0) # else: # layer = offload.shared_state["layer"] # embed_sizes = offload.shared_state["embed_sizes"] diff --git a/shared/convert/convert_diffusers_to_flux.py b/shared/convert/convert_diffusers_to_flux.py new file mode 100644 index 000000000..608b1765e --- /dev/null +++ b/shared/convert/convert_diffusers_to_flux.py @@ -0,0 +1,342 @@ +#!/usr/bin/env python3 +""" +Convert a Flux model from Diffusers (folder or single-file) into the original +single-file Flux transformer 
checkpoint used by Black Forest Labs / ComfyUI. + +Input : /path/to/diffusers (root or .../transformer) OR /path/to/*.safetensors (single file) +Output : /path/to/flux1-your-model.safetensors (transformer only) + +Usage: + python diffusers_to_flux_transformer.py /path/to/diffusers /out/flux1-dev.safetensors + python diffusers_to_flux_transformer.py /path/to/diffusion_pytorch_model.safetensors /out/flux1-dev.safetensors + # optional quantization: + # --fp8 (float8_e4m3fn, simple) + # --fp8-scaled (scaled float8 for 2D weights; adds .scale_weight tensors) +""" + +import argparse +import json +from pathlib import Path +from collections import OrderedDict + +import torch +from safetensors import safe_open +import safetensors.torch +from tqdm import tqdm + + +def parse_args(): + ap = argparse.ArgumentParser() + ap.add_argument("diffusers_path", type=str, + help="Path to Diffusers checkpoint folder OR a single .safetensors file.") + ap.add_argument("output_path", type=str, + help="Output .safetensors path for the Flux transformer.") + ap.add_argument("--fp8", action="store_true", + help="Experimental: write weights as float8_e4m3fn via stochastic rounding (transformer only).") + ap.add_argument("--fp8-scaled", action="store_true", + help="Experimental: scaled float8_e4m3fn for 2D weight tensors; adds .scale_weight tensors.") + return ap.parse_args() + + +# Mapping from original Flux keys -> list of Diffusers keys (per block where applicable). +DIFFUSERS_MAP = { + # global embeds + "time_in.in_layer.weight": ["time_text_embed.timestep_embedder.linear_1.weight"], + "time_in.in_layer.bias": ["time_text_embed.timestep_embedder.linear_1.bias"], + "time_in.out_layer.weight": ["time_text_embed.timestep_embedder.linear_2.weight"], + "time_in.out_layer.bias": ["time_text_embed.timestep_embedder.linear_2.bias"], + + "vector_in.in_layer.weight": ["time_text_embed.text_embedder.linear_1.weight"], + "vector_in.in_layer.bias": ["time_text_embed.text_embedder.linear_1.bias"], + "vector_in.out_layer.weight": ["time_text_embed.text_embedder.linear_2.weight"], + "vector_in.out_layer.bias": ["time_text_embed.text_embedder.linear_2.bias"], + + "guidance_in.in_layer.weight": ["time_text_embed.guidance_embedder.linear_1.weight"], + "guidance_in.in_layer.bias": ["time_text_embed.guidance_embedder.linear_1.bias"], + "guidance_in.out_layer.weight": ["time_text_embed.guidance_embedder.linear_2.weight"], + "guidance_in.out_layer.bias": ["time_text_embed.guidance_embedder.linear_2.bias"], + + "txt_in.weight": ["context_embedder.weight"], + "txt_in.bias": ["context_embedder.bias"], + "img_in.weight": ["x_embedder.weight"], + "img_in.bias": ["x_embedder.bias"], + + # dual-stream (image/text) blocks + "double_blocks.().img_mod.lin.weight": ["norm1.linear.weight"], + "double_blocks.().img_mod.lin.bias": ["norm1.linear.bias"], + "double_blocks.().txt_mod.lin.weight": ["norm1_context.linear.weight"], + "double_blocks.().txt_mod.lin.bias": ["norm1_context.linear.bias"], + + "double_blocks.().img_attn.qkv.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight"], + "double_blocks.().img_attn.qkv.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias"], + "double_blocks.().txt_attn.qkv.weight": ["attn.add_q_proj.weight", "attn.add_k_proj.weight", "attn.add_v_proj.weight"], + "double_blocks.().txt_attn.qkv.bias": ["attn.add_q_proj.bias", "attn.add_k_proj.bias", "attn.add_v_proj.bias"], + + "double_blocks.().img_attn.norm.query_norm.scale": ["attn.norm_q.weight"], + "double_blocks.().img_attn.norm.key_norm.scale": 
["attn.norm_k.weight"], + "double_blocks.().txt_attn.norm.query_norm.scale": ["attn.norm_added_q.weight"], + "double_blocks.().txt_attn.norm.key_norm.scale": ["attn.norm_added_k.weight"], + + "double_blocks.().img_mlp.0.weight": ["ff.net.0.proj.weight"], + "double_blocks.().img_mlp.0.bias": ["ff.net.0.proj.bias"], + "double_blocks.().img_mlp.2.weight": ["ff.net.2.weight"], + "double_blocks.().img_mlp.2.bias": ["ff.net.2.bias"], + + "double_blocks.().txt_mlp.0.weight": ["ff_context.net.0.proj.weight"], + "double_blocks.().txt_mlp.0.bias": ["ff_context.net.0.proj.bias"], + "double_blocks.().txt_mlp.2.weight": ["ff_context.net.2.weight"], + "double_blocks.().txt_mlp.2.bias": ["ff_context.net.2.bias"], + + "double_blocks.().img_attn.proj.weight": ["attn.to_out.0.weight"], + "double_blocks.().img_attn.proj.bias": ["attn.to_out.0.bias"], + "double_blocks.().txt_attn.proj.weight": ["attn.to_add_out.weight"], + "double_blocks.().txt_attn.proj.bias": ["attn.to_add_out.bias"], + + # single-stream blocks + "single_blocks.().modulation.lin.weight": ["norm.linear.weight"], + "single_blocks.().modulation.lin.bias": ["norm.linear.bias"], + "single_blocks.().linear1.weight": ["attn.to_q.weight", "attn.to_k.weight", "attn.to_v.weight", "proj_mlp.weight"], + "single_blocks.().linear1.bias": ["attn.to_q.bias", "attn.to_k.bias", "attn.to_v.bias", "proj_mlp.bias"], + "single_blocks.().norm.query_norm.scale": ["attn.norm_q.weight"], + "single_blocks.().norm.key_norm.scale": ["attn.norm_k.weight"], + "single_blocks.().linear2.weight": ["proj_out.weight"], + "single_blocks.().linear2.bias": ["proj_out.bias"], + + # final + "final_layer.linear.weight": ["proj_out.weight"], + "final_layer.linear.bias": ["proj_out.bias"], + # these two are built from norm_out.linear.{weight,bias} by swapping [shift,scale] -> [scale,shift] + "final_layer.adaLN_modulation.1.weight": ["norm_out.linear.weight"], + "final_layer.adaLN_modulation.1.bias": ["norm_out.linear.bias"], +} + + +class DiffusersSource: + """ + Uniform interface over: + 1) Folder with index JSON + shards + 2) Folder with exactly one .safetensors (no index) + 3) Single .safetensors file + Provides .has(key), .get(key)->Tensor, .base_keys (keys with 'model.' stripped for scanning) + """ + + POSSIBLE_PREFIXES = ["", "model."] # try in this order + + def __init__(self, path: Path): + p = Path(path) + if p.is_dir(): + # use 'transformer' subfolder if present + if (p / "transformer").is_dir(): + p = p / "transformer" + self._init_from_dir(p) + elif p.is_file() and p.suffix == ".safetensors": + self._init_from_single_file(p) + else: + raise FileNotFoundError(f"Invalid path: {p}") + + # ---------- common helpers ---------- + + @staticmethod + def _strip_prefix(k: str) -> str: + return k[6:] if k.startswith("model.") else k + + def _resolve(self, want: str): + """ + Return the actual stored key matching `want` by trying known prefixes. + """ + for pref in self.POSSIBLE_PREFIXES: + k = pref + want + if k in self._all_keys: + return k + return None + + def has(self, want: str) -> bool: + return self._resolve(want) is not None + + def get(self, want: str) -> torch.Tensor: + real_key = self._resolve(want) + if real_key is None: + raise KeyError(f"Missing key: {want}") + return self._get_by_real_key(real_key).to("cpu") + + @property + def base_keys(self): + # keys without 'model.' 
prefix for scanning + return [self._strip_prefix(k) for k in self._all_keys] + + # ---------- modes ---------- + + def _init_from_single_file(self, file_path: Path): + self._mode = "single" + self._file = file_path + self._handle = safe_open(file_path, framework="pt", device="cpu") + self._all_keys = list(self._handle.keys()) + + def _get_by_real_key(real_key: str): + return self._handle.get_tensor(real_key) + + self._get_by_real_key = _get_by_real_key + + def _init_from_dir(self, dpath: Path): + index_json = dpath / "diffusion_pytorch_model.safetensors.index.json" + if index_json.exists(): + with open(index_json, "r", encoding="utf-8") as f: + index = json.load(f) + weight_map = index["weight_map"] # full mapping + self._mode = "sharded" + self._dpath = dpath + self._weight_map = {k: dpath / v for k, v in weight_map.items()} + self._all_keys = list(self._weight_map.keys()) + self._open_handles = {} + + def _get_by_real_key(real_key: str): + fpath = self._weight_map[real_key] + h = self._open_handles.get(fpath) + if h is None: + h = safe_open(fpath, framework="pt", device="cpu") + self._open_handles[fpath] = h + return h.get_tensor(real_key) + + self._get_by_real_key = _get_by_real_key + return + + # no index: try exactly one safetensors in folder + files = sorted(dpath.glob("*.safetensors")) + if len(files) != 1: + raise FileNotFoundError( + f"No index found and {dpath} does not contain exactly one .safetensors file." + ) + self._init_from_single_file(files[0]) + + +def main(): + args = parse_args() + src = DiffusersSource(Path(args.diffusers_path)) + + # Count blocks by scanning base keys (with any 'model.' prefix removed) + num_dual = 0 + num_single = 0 + for k in src.base_keys: + if k.startswith("transformer_blocks."): + try: + i = int(k.split(".")[1]) + num_dual = max(num_dual, i + 1) + except Exception: + pass + elif k.startswith("single_transformer_blocks."): + try: + i = int(k.split(".")[1]) + num_single = max(num_single, i + 1) + except Exception: + pass + print(f"Found {num_dual} dual-stream blocks, {num_single} single-stream blocks") + + # Swap [shift, scale] -> [scale, shift] (weights are concatenated along dim=0) + def swap_scale_shift(vec: torch.Tensor) -> torch.Tensor: + shift, scale = vec.chunk(2, dim=0) + return torch.cat([scale, shift], dim=0) + + orig = {} + + # Per-block (dual) + for b in range(num_dual): + prefix = f"transformer_blocks.{b}." + for okey, dvals in DIFFUSERS_MAP.items(): + if not okey.startswith("double_blocks."): + continue + dkeys = [prefix + v for v in dvals] + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey.replace("()", str(b))] = src.get(dkeys[0]) + else: + orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Per-block (single) + for b in range(num_single): + prefix = f"single_transformer_blocks.{b}." 
+ for okey, dvals in DIFFUSERS_MAP.items(): + if not okey.startswith("single_blocks."): + continue + dkeys = [prefix + v for v in dvals] + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey.replace("()", str(b))] = src.get(dkeys[0]) + else: + orig[okey.replace("()", str(b))] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Globals (non-block) + for okey, dvals in DIFFUSERS_MAP.items(): + if okey.startswith(("double_blocks.", "single_blocks.")): + continue + dkeys = dvals + if not all(src.has(k) for k in dkeys): + continue + if len(dkeys) == 1: + orig[okey] = src.get(dkeys[0]) + else: + orig[okey] = torch.cat([src.get(k) for k in dkeys], dim=0) + + # Fix final_layer.adaLN_modulation.1.{weight,bias} by swapping scale/shift halves + if "final_layer.adaLN_modulation.1.weight" in orig: + orig["final_layer.adaLN_modulation.1.weight"] = swap_scale_shift( + orig["final_layer.adaLN_modulation.1.weight"] + ) + if "final_layer.adaLN_modulation.1.bias" in orig: + orig["final_layer.adaLN_modulation.1.bias"] = swap_scale_shift( + orig["final_layer.adaLN_modulation.1.bias"] + ) + + # Optional FP8 variants (experimental; not required for ComfyUI/BFL) + if args.fp8 or args.fp8_scaled: + dtype = torch.float8_e4m3fn # noqa + minv, maxv = torch.finfo(dtype).min, torch.finfo(dtype).max + + def stochastic_round_to(t): + t = t.float().clamp(minv, maxv) + lower = torch.floor(t * 256) / 256 + upper = torch.ceil(t * 256) / 256 + prob = torch.where(upper != lower, (t - lower) / (upper - lower), torch.zeros_like(t)) + rnd = torch.rand_like(t) + out = torch.where(rnd < prob, upper, lower) + return out.to(dtype) + + def scale_to_8bit(weight, target_max=416.0): + absmax = weight.abs().max() + scale = absmax / target_max if absmax > 0 else torch.tensor(1.0) + scaled = (weight / scale).clamp(minv, maxv).to(dtype) + return scaled, scale + + scales = {} + for k in tqdm(list(orig.keys()), desc="Quantizing to fp8"): + t = orig[k] + if args.fp8: + orig[k] = stochastic_round_to(t) + else: + if k.endswith(".weight") and t.dim() == 2: + qt, s = scale_to_8bit(t) + orig[k] = qt + scales[k[:-len(".weight")] + ".scale_weight"] = s + else: + orig[k] = t.clamp(minv, maxv).to(dtype) + if args.fp8_scaled: + orig.update(scales) + orig["scaled_fp8"] = torch.tensor([], dtype=dtype) + else: + # Default: save in bfloat16 + for k in list(orig.keys()): + orig[k] = orig[k].to(torch.bfloat16).cpu() + + out_path = Path(args.output_path) + out_path.parent.mkdir(parents=True, exist_ok=True) + meta = OrderedDict() + meta["format"] = "pt" + meta["modelspec.date"] = __import__("datetime").date.today().strftime("%Y-%m-%d") + print(f"Saving transformer to: {out_path}") + safetensors.torch.save_file(orig, str(out_path), metadata=meta) + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/shared/gradio/gallery.py b/shared/gradio/gallery.py new file mode 100644 index 000000000..757da3422 --- /dev/null +++ b/shared/gradio/gallery.py @@ -0,0 +1,532 @@ +from __future__ import annotations +import os, io, tempfile, mimetypes +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Literal + +import gradio as gr +import PIL +import time +from PIL import Image as PILImage + +FilePath = str +ImageLike = Union["PIL.Image.Image", Any] + +IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"} +VIDEO_EXTS = {".mp4", ".mov", ".avi", ".mkv", ".webm", ".m4v", ".mpeg", ".mpg", ".ogv"} + +def get_state(state): + return state if isinstance(state, dict) else 
state.value + +def get_list( objs): + if objs is None: + return [] + return [ obj[0] if isinstance(obj, tuple) else obj for obj in objs] + +def record_last_action(st, last_action): + st["last_action"] = last_action + st["last_time"] = time.time() +class AdvancedMediaGallery: + def __init__( + self, + label: str = "Media", + *, + media_mode: Literal["image", "video"] = "image", + height = None, + columns: Union[int, Tuple[int, ...]] = 6, + show_label: bool = True, + initial: Optional[Sequence[Union[FilePath, ImageLike]]] = None, + elem_id: Optional[str] = None, + elem_classes: Optional[Sequence[str]] = ("adv-media-gallery",), + accept_filter: bool = True, # restrict Add-button dialog to allowed extensions + single_image_mode: bool = False, # start in single-image mode (Add replaces) + ): + assert media_mode in ("image", "video") + self.label = label + self.media_mode = media_mode + self.height = height + self.columns = columns + self.show_label = show_label + self.elem_id = elem_id + self.elem_classes = list(elem_classes) if elem_classes else None + self.accept_filter = accept_filter + + items = self._normalize_initial(initial or [], media_mode) + + # Components (filled on mount) + self.container: Optional[gr.Column] = None + self.gallery: Optional[gr.Gallery] = None + self.upload_btn: Optional[gr.UploadButton] = None + self.btn_remove: Optional[gr.Button] = None + self.btn_left: Optional[gr.Button] = None + self.btn_right: Optional[gr.Button] = None + self.btn_clear: Optional[gr.Button] = None + + # Single dict state + self.state: Optional[gr.State] = None + self._initial_state: Dict[str, Any] = { + "items": items, + "selected": (len(items) - 1) if items else 0, # None, + "single": bool(single_image_mode), + "mode": self.media_mode, + "last_action": "", + } + + # ---------------- helpers ---------------- + + def _normalize_initial(self, items: Sequence[Union[FilePath, ImageLike]], mode: str) -> List[Any]: + out: List[Any] = [] + if mode == "image": + for it in items: + p = self._ensure_image_item(it) + if p is not None: + out.append(p) + else: + for it in items: + if isinstance(item, tuple): item = item[0] + if isinstance(it, str) and self._is_video_path(it): + out.append(os.path.abspath(it)) + return out + + def _ensure_image_item(self, item: Union[FilePath, ImageLike]) -> Optional[Any]: + # Accept a path to an image, or a PIL.Image/np.ndarray -> save temp PNG and return its path + if isinstance(item, tuple): item = item[0] + if isinstance(item, str): + return os.path.abspath(item) if self._is_image_path(item) else None + if PILImage is None: + return None + try: + if isinstance(item, PILImage.Image): + img = item + else: + import numpy as np # type: ignore + if isinstance(item, np.ndarray): + img = PILImage.fromarray(item) + elif hasattr(item, "read"): + data = item.read() + img = PILImage.open(io.BytesIO(data)).convert("RGBA") + else: + return None + tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + img.save(tmp.name) + return tmp.name + except Exception: + return None + + @staticmethod + def _extract_path(obj: Any) -> Optional[str]: + # Try to get a filesystem path (for mode filtering); otherwise None. 
+ if isinstance(obj, str): + return obj + try: + import pathlib + if isinstance(obj, pathlib.Path): # type: ignore + return str(obj) + except Exception: + pass + if isinstance(obj, dict): + return obj.get("path") or obj.get("name") + for attr in ("path", "name"): + if hasattr(obj, attr): + try: + val = getattr(obj, attr) + if isinstance(val, str): + return val + except Exception: + pass + return None + + @staticmethod + def _is_image_path(p: str) -> bool: + ext = os.path.splitext(p)[1].lower() + if ext in IMAGE_EXTS: + return True + mt, _ = mimetypes.guess_type(p) + return bool(mt and mt.startswith("image/")) + + @staticmethod + def _is_video_path(p: str) -> bool: + ext = os.path.splitext(p)[1].lower() + if ext in VIDEO_EXTS: + return True + mt, _ = mimetypes.guess_type(p) + return bool(mt and mt.startswith("video/")) + + def _filter_items_by_mode(self, items: List[Any]) -> List[Any]: + # Enforce image-only or video-only collection regardless of how files were added. + out: List[Any] = [] + if self.media_mode == "image": + for it in items: + p = self._extract_path(it) + if p is None: + # No path: likely an image object added programmatically => keep + out.append(it) + elif self._is_image_path(p): + out.append(os.path.abspath(p)) + else: + for it in items: + p = self._extract_path(it) + if p is not None and self._is_video_path(p): + out.append(os.path.abspath(p)) + return out + + @staticmethod + def _concat_and_optionally_dedupe(cur: List[Any], add: List[Any]) -> List[Any]: + # Keep it simple: dedupe by path when available, else allow duplicates. + seen_paths = set() + def key(x: Any) -> Optional[str]: + if isinstance(x, str): return os.path.abspath(x) + try: + import pathlib + if isinstance(x, pathlib.Path): # type: ignore + return os.path.abspath(str(x)) + except Exception: + pass + if isinstance(x, dict): + p = x.get("path") or x.get("name") + return os.path.abspath(p) if isinstance(p, str) else None + for attr in ("path", "name"): + if hasattr(x, attr): + try: + v = getattr(x, attr) + return os.path.abspath(v) if isinstance(v, str) else None + except Exception: + pass + return None + + out: List[Any] = [] + for lst in (cur, add): + for it in lst: + k = key(it) + if k is None or k not in seen_paths: + out.append(it) + if k is not None: + seen_paths.add(k) + return out + + @staticmethod + def _paths_from_payload(payload: Any) -> List[Any]: + # Return as raw objects (paths/dicts/UploadedFile) to feed Gallery directly. 
+ if payload is None: + return [] + if isinstance(payload, (list, tuple, set)): + return list(payload) + return [payload] + + # ---------------- event handlers ---------------- + + def _on_select(self, state: Dict[str, Any], gallery, evt: gr.SelectData) : + # Mirror the selected index into state and the gallery (server-side selected_index) + + st = get_state(state) + last_time = st.get("last_time", None) + if last_time is not None and abs(time.time()- last_time)< 0.5: # crappy trick to detect if onselect is unwanted (buggy gallery) + # print(f"ignored:{time.time()}, real {st['selected']}") + return gr.update(selected_index=st["selected"]), st + + idx = None + if evt is not None and hasattr(evt, "index"): + ix = evt.index + if isinstance(ix, int): + idx = ix + elif isinstance(ix, (tuple, list)) and ix and isinstance(ix[0], int): + if isinstance(self.columns, int) and len(ix) >= 2: + idx = ix[0] * max(1, int(self.columns)) + ix[1] + else: + idx = ix[0] + n = len(get_list(gallery)) + sel = idx if (idx is not None and 0 <= idx < n) else None + # print(f"image selected evt index:{sel}/{evt.selected}") + st["selected"] = sel + return gr.update(), st + + def _on_upload(self, value: List[Any], state: Dict[str, Any]) : + # Fires when users upload via the Gallery itself. + # items_filtered = self._filter_items_by_mode(list(value or [])) + items_filtered = list(value or []) + st = get_state(state) + new_items = self._paths_from_payload(items_filtered) + st["items"] = new_items + new_sel = len(new_items) - 1 + st["selected"] = new_sel + record_last_action(st,"add") + return gr.update(selected_index=new_sel), st + + def _on_gallery_change(self, value: List[Any], state: Dict[str, Any]) : + # Fires when users add/drag/drop/delete via the Gallery itself. + # items_filtered = self._filter_items_by_mode(list(value or [])) + items_filtered = list(value or []) + st = get_state(state) + st["items"] = items_filtered + # Keep selection if still valid, else default to last + old_sel = st.get("selected", None) + if old_sel is None or not (0 <= old_sel < len(items_filtered)): + new_sel = (len(items_filtered) - 1) if items_filtered else None + else: + new_sel = old_sel + st["selected"] = new_sel + st["last_action"] ="gallery_change" + # print(f"gallery change: set sel {new_sel}") + return gr.update(selected_index=new_sel), st + + def _on_add(self, files_payload: Any, state: Dict[str, Any], gallery): + """ + Insert added items right AFTER the currently selected index. + Keeps the same ordering as chosen in the file picker, dedupes by path, + and re-selects the last inserted item. 
+ """ + # New items (respect image/video mode) + # new_items = self._filter_items_by_mode(self._paths_from_payload(files_payload)) + new_items = self._paths_from_payload(files_payload) + + st = get_state(state) + cur: List[Any] = get_list(gallery) + sel = st.get("selected", None) + if sel is None: + sel = (len(cur) -1) if len(cur)>0 else 0 + single = bool(st.get("single", False)) + + # Nothing to add: keep as-is + if not new_items: + return gr.update(value=cur, selected_index=st.get("selected")), st + + # Single-image mode: replace + if single: + st["items"] = [new_items[-1]] + st["selected"] = 0 + return gr.update(value=st["items"], selected_index=0), st + + # ---------- helpers ---------- + def key_of(it: Any) -> Optional[str]: + # Prefer class helper if present + if hasattr(self, "_extract_path"): + p = self._extract_path(it) # type: ignore + else: + p = it if isinstance(it, str) else None + if p is None and isinstance(it, dict): + p = it.get("path") or it.get("name") + if p is None and hasattr(it, "path"): + try: p = getattr(it, "path") + except Exception: p = None + if p is None and hasattr(it, "name"): + try: p = getattr(it, "name") + except Exception: p = None + return os.path.abspath(p) if isinstance(p, str) else None + + # Dedupe the incoming batch by path, preserve order + seen_new = set() + incoming: List[Any] = [] + for it in new_items: + k = key_of(it) + if k is None or k not in seen_new: + incoming.append(it) + if k is not None: + seen_new.add(k) + + insert_pos = min(sel, len(cur) -1) + cur_clean = cur + # Build final list and selection + merged = cur_clean[:insert_pos+1] + incoming + cur_clean[insert_pos+1:] + new_sel = insert_pos + len(incoming) # select the last inserted item + + st["items"] = merged + st["selected"] = new_sel + record_last_action(st,"add") + # print(f"gallery add: set sel {new_sel}") + return gr.update(value=merged, selected_index=new_sel), st + + def _on_remove(self, state: Dict[str, Any], gallery) : + st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None) + if sel is None or not (0 <= sel < len(items)): + return gr.update(value=items, selected_index=st.get("selected")), st + items.pop(sel) + if not items: + st["items"] = []; st["selected"] = None + return gr.update(value=[], selected_index=None), st + new_sel = min(sel, len(items) - 1) + st["items"] = items; st["selected"] = new_sel + record_last_action(st,"remove") + # print(f"gallery del: new sel {new_sel}") + return gr.update(value=items, selected_index=new_sel), st + + def _on_move(self, delta: int, state: Dict[str, Any], gallery) : + st = get_state(state); items: List[Any] = get_list(gallery); sel = st.get("selected", None) + if sel is None or not (0 <= sel < len(items)): + return gr.update(value=items, selected_index=sel), st + j = sel + delta + if j < 0 or j >= len(items): + return gr.update(value=items, selected_index=sel), st + items[sel], items[j] = items[j], items[sel] + st["items"] = items; st["selected"] = j + record_last_action(st,"move") + # print(f"gallery move: set sel {j}") + return gr.update(value=items, selected_index=j), st + + def _on_clear(self, state: Dict[str, Any]) : + st = {"items": [], "selected": None, "single": get_state(state).get("single", False), "mode": self.media_mode} + record_last_action(st,"clear") + # print(f"Clear all") + return gr.update(value=[], selected_index=None), st + + def _on_toggle_single(self, to_single: bool, state: Dict[str, Any]) : + st = get_state(state); st["single"] = bool(to_single) + items: List[Any] = 
list(st["items"]); sel = st.get("selected", None) + if st["single"]: + keep = items[sel] if (sel is not None and 0 <= sel < len(items)) else (items[-1] if items else None) + items = [keep] if keep is not None else [] + sel = 0 if items else None + st["items"] = items; st["selected"] = sel + + upload_update = gr.update(file_count=("single" if st["single"] else "multiple")) + left_update = gr.update(visible=not st["single"]) + right_update = gr.update(visible=not st["single"]) + clear_update = gr.update(visible=not st["single"]) + gallery_update= gr.update(value=items, selected_index=sel) + + return upload_update, left_update, right_update, clear_update, gallery_update, st + + # ---------------- build & wire ---------------- + + def mount(self, parent: Optional[gr.Blocks | gr.Group | gr.Row | gr.Column] = None, update_form = False): + if parent is not None: + with parent: + col = self._build_ui(update_form) + else: + col = self._build_ui(update_form) + if not update_form: + self._wire_events() + return col + + def _build_ui(self, update = False) -> gr.Column: + with gr.Column(elem_id=self.elem_id, elem_classes=self.elem_classes) as col: + self.container = col + + self.state = gr.State(dict(self._initial_state)) + + if update: + self.gallery = gr.update( + value=self._initial_state["items"], + selected_index=self._initial_state["selected"], # server-side selection + label=self.label, + show_label=self.show_label, + ) + else: + self.gallery = gr.Gallery( + value=self._initial_state["items"], + label=self.label, + height=self.height, + columns=self.columns, + show_label=self.show_label, + preview= True, + # type="pil", # very slow + file_types= list(IMAGE_EXTS) if self.media_mode == "image" else list(VIDEO_EXTS), + selected_index=self._initial_state["selected"], # server-side selection + ) + + # One-line controls + exts = sorted(IMAGE_EXTS if self.media_mode == "image" else VIDEO_EXTS) if self.accept_filter else None + with gr.Row(equal_height=True, elem_classes=["amg-controls"]): + self.upload_btn = gr.UploadButton( + "Set" if self._initial_state["single"] else "Add", + file_types=exts, + file_count=("single" if self._initial_state["single"] else "multiple"), + variant="primary", + size="sm", + min_width=1, + ) + self.btn_remove = gr.Button(" Remove ", size="sm", min_width=1) + self.btn_left = gr.Button("◀ Left", size="sm", visible=not self._initial_state["single"], min_width=1) + self.btn_right = gr.Button("Right ▶", size="sm", visible=not self._initial_state["single"], min_width=1) + self.btn_clear = gr.Button(" Clear ", variant="secondary", size="sm", visible=not self._initial_state["single"], min_width=1) + + return col + + def _wire_events(self): + # Selection: mirror into state and keep gallery.selected_index in sync + self.gallery.select( + self._on_select, + inputs=[self.state, self.gallery], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.) + self.gallery.upload( + self._on_upload, + inputs=[self.gallery, self.state], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Gallery value changed by user actions (click-to-add, drag-drop, internal remove, etc.) 
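+        # _on_gallery_change re-validates the stored selection against the new item list:
+        # it is kept while still in range, otherwise the last item becomes the selection.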
+ self.gallery.upload( + self._on_gallery_change, + inputs=[self.gallery, self.state], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Add via UploadButton + self.upload_btn.upload( + self._on_add, + inputs=[self.upload_btn, self.state, self.gallery], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Remove selected + self.btn_remove.click( + self._on_remove, + inputs=[self.state, self.gallery], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Reorder using selected index, keep same item selected + self.btn_left.click( + lambda st, gallery: self._on_move(-1, st, gallery), + inputs=[self.state, self.gallery], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + self.btn_right.click( + lambda st, gallery: self._on_move(+1, st, gallery), + inputs=[self.state, self.gallery], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # Clear all + self.btn_clear.click( + self._on_clear, + inputs=[self.state], + outputs=[self.gallery, self.state], + trigger_mode="always_last", + ) + + # ---------------- public API ---------------- + + def set_one_image_mode(self, enabled: bool = True): + """Toggle single-image mode at runtime.""" + return ( + self._on_toggle_single, + [gr.State(enabled), self.state], + [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state], + ) + + def get_toggable_elements(self): + return [self.upload_btn, self.btn_left, self.btn_right, self.btn_clear, self.gallery, self.state] + +# import gradio as gr + +# with gr.Blocks() as demo: +# amg = AdvancedMediaGallery(media_mode="image", height=190, columns=8) +# amg.mount() +# g = amg.gallery +# # buttons to switch modes live (optional) +# def process(g): +# pass +# with gr.Row(): +# gr.Button("toto").click(process, g) +# gr.Button("ONE image").click(*amg.set_one_image_mode(True)) +# gr.Button("MULTI image").click(*amg.set_one_image_mode(False)) + +# demo.launch() diff --git a/shared/inpainting/__init__.py b/shared/inpainting/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/shared/inpainting/lanpaint.py b/shared/inpainting/lanpaint.py new file mode 100644 index 000000000..3165e7b01 --- /dev/null +++ b/shared/inpainting/lanpaint.py @@ -0,0 +1,240 @@ +import torch +from .utils import * +from functools import partial + +# Many thanks to the LanPaint team for this implementation (https://github.com/scraed/LanPaint/) + +def _pack_latents(latents): + batch_size, num_channels_latents, _, height, width = latents.shape + + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + +def _unpack_latents(latents, height, width, vae_scale_factor=8): + batch_size, num_patches, channels = latents.shape + + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width) + + return latents + +class LanPaint(): + def __init__(self, NSteps = 5, Friction = 15, Lambda = 8, Beta = 1, StepSize = 0.15, IS_FLUX = True, IS_FLOW = False): + self.n_steps = NSteps + self.chara_lamb = Lambda + self.IS_FLUX = IS_FLUX + self.IS_FLOW = IS_FLOW + 
self.step_size = StepSize + self.friction = Friction + self.chara_beta = Beta + self.img_dim_size = None + def add_none_dims(self, array): + # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times + index = (slice(None),) + (None,) * (self.img_dim_size-1) + return array[index] + def remove_none_dims(self, array): + # Create a tuple with ':' for the first dimension and 'None' repeated num_nones times + index = (slice(None),) + (0,) * (self.img_dim_size-1) + return array[index] + def __call__(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, latent_image, noise, sigma, latent_mask, n_steps=None, height =720, width = 1280, vae_scale_factor = 8): + latent_image = _unpack_latents(latent_image, height=height, width=width, vae_scale_factor=vae_scale_factor) + noise = _unpack_latents(noise, height=height, width=width, vae_scale_factor=vae_scale_factor) + x = _unpack_latents(x, height=height, width=width, vae_scale_factor=vae_scale_factor) + latent_mask = _unpack_latents(latent_mask, height=height, width=width, vae_scale_factor=vae_scale_factor) + self.height = height + self.width = width + self.vae_scale_factor = vae_scale_factor + self.img_dim_size = len(x.shape) + self.latent_image = latent_image + self.noise = noise + if n_steps is None: + n_steps = self.n_steps + out = self.LanPaint(denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, self.IS_FLUX, self.IS_FLOW) + out = _pack_latents(out) + return out + def LanPaint(self, denoise, cfg_predictions, true_cfg_scale, cfg_BIG, x, sigma, latent_mask, n_steps, IS_FLUX, IS_FLOW): + if IS_FLUX: + cfg_BIG = 1.0 + + def double_denoise(latents, t): + latents = _pack_latents(latents) + noise_pred, neg_noise_pred = denoise(latents, true_cfg_scale) + if noise_pred == None: return None, None + predict_std = cfg_predictions(noise_pred, neg_noise_pred, true_cfg_scale, t) + predict_std = _unpack_latents(predict_std, self.height, self.width, self.vae_scale_factor) + if true_cfg_scale == cfg_BIG: + predict_big = predict_std + else: + predict_big = cfg_predictions(noise_pred, neg_noise_pred, cfg_BIG, t) + predict_big = _unpack_latents(predict_big, self.height, self.width, self.vae_scale_factor) + return predict_std, predict_big + + if len(sigma.shape) == 0: + sigma = torch.tensor([sigma.item()]) + latent_mask = 1 - latent_mask + if IS_FLUX or IS_FLOW: + Flow_t = sigma + abt = (1 - Flow_t)**2 / ((1 - Flow_t)**2 + Flow_t**2 ) + VE_Sigma = Flow_t / (1 - Flow_t) + #print("t", torch.mean( sigma ).item(), "VE_Sigma", torch.mean( VE_Sigma ).item()) + else: + VE_Sigma = sigma + abt = 1/( 1+VE_Sigma**2 ) + Flow_t = (1-abt)**0.5 / ( (1-abt)**0.5 + abt**0.5 ) + # VE_Sigma, abt, Flow_t = current_times + current_times = (VE_Sigma, abt, Flow_t) + + step_size = self.step_size * (1 - abt) + step_size = self.add_none_dims(step_size) + # self.inner_model.inner_model.scale_latent_inpaint returns variance exploding x_t values + # This is the replace step + # x = x * (1 - latent_mask) + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image)* latent_mask + + noisy_image = self.latent_image * (1.0 - sigma) + self.noise * sigma + x = x * (1 - latent_mask) + noisy_image * latent_mask + + if IS_FLUX or IS_FLOW: + x_t = x * ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 ) + else: + x_t = x / ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values + + ############ LanPaint Iterations Start ############### + # 
after noise_scaling, noise = latent_image + noise * sigma, which is x_t in the variance exploding diffusion model notation for the known region. + args = None + for i in range(n_steps): + score_func = partial( self.score_model, y = self.latent_image, mask = latent_mask, abt = self.add_none_dims(abt), sigma = self.add_none_dims(VE_Sigma), tflow = self.add_none_dims(Flow_t), denoise_func = double_denoise ) + if score_func is None: return None + x_t, args = self.langevin_dynamics(x_t, score_func , latent_mask, step_size , current_times, sigma_x = self.add_none_dims(self.sigma_x(abt)), sigma_y = self.add_none_dims(self.sigma_y(abt)), args = args) + if IS_FLUX or IS_FLOW: + x = x_t / ( self.add_none_dims(abt)**0.5 + (1-self.add_none_dims(abt))**0.5 ) + else: + x = x_t * ( 1+self.add_none_dims(VE_Sigma)**2 )**0.5 # switch to variance perserving x_t values + ############ LanPaint Iterations End ############### + # out is x_0 + # out, _ = self.inner_model(x, sigma, model_options=model_options, seed=seed) + # out = out * (1-latent_mask) + self.latent_image * latent_mask + # return out + return x + + def score_model(self, x_t, y, mask, abt, sigma, tflow, denoise_func): + + lamb = self.chara_lamb + if self.IS_FLUX or self.IS_FLOW: + # compute t for flow model, with a small epsilon compensating for numerical error. + x = x_t / ( abt**0.5 + (1-abt)**0.5 ) # switch to Gaussian flow matching + x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(tflow)) + if x_0 is None: return None + else: + x = x_t * ( 1+sigma**2 )**0.5 # switch to variance exploding + x_0, x_0_BIG = denoise_func(x, self.remove_none_dims(sigma)) + if x_0 is None: return None + + score_x = -(x_t - x_0) + score_y = - (1 + lamb) * ( x_t - y ) + lamb * (x_t - x_0_BIG) + return score_x * (1 - mask) + score_y * mask + def sigma_x(self, abt): + # the time scale for the x_t update + return abt**0 + def sigma_y(self, abt): + beta = self.chara_beta * abt ** 0 + return beta + + def langevin_dynamics(self, x_t, score, mask, step_size, current_times, sigma_x=1, sigma_y=0, args=None): + # prepare the step size and time parameters + with torch.autocast(device_type=x_t.device.type, dtype=torch.float32): + step_sizes = self.prepare_step_size(current_times, step_size, sigma_x, sigma_y) + sigma, abt, dtx, dty, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y = step_sizes + # print('mask',mask.device) + if torch.mean(dtx) <= 0.: + return x_t, args + # ------------------------------------------------------------------------- + # Compute the Langevin dynamics update in variance perserving notation + # ------------------------------------------------------------------------- + #x0 = self.x0_evalutation(x_t, score, sigma, args) + #C = abt**0.5 * x0 / (1-abt) + A = A_x * (1-mask) + A_y * mask + D = D_x * (1-mask) + D_y * mask + dt = dtx * (1-mask) + dty * mask + Gamma = Gamma_x * (1-mask) + Gamma_y * mask + + + def Coef_C(x_t): + x0 = self.x0_evalutation(x_t, score, sigma, args) + C = (abt**0.5 * x0 - x_t )/ (1-abt) + A * x_t + return C + def advance_time(x_t, v, dt, Gamma, A, C, D): + dtype = x_t.dtype + with torch.autocast(device_type=x_t.device.type, dtype=torch.float32): + osc = StochasticHarmonicOscillator(Gamma, A, C, D ) + x_t, v = osc.dynamics(x_t, v, dt ) + x_t = x_t.to(dtype) + v = v.to(dtype) + return x_t, v + if args is None: + #v = torch.zeros_like(x_t) + v = None + C = Coef_C(x_t) + #print(torch.squeeze(dtx), torch.squeeze(dty)) + x_t, v = advance_time(x_t, v, dt, Gamma, A, C, D) + else: + v, C = args + + x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D) + 
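+        # Recompute the constant force term C from the score at the updated x_t and apply
+        # the matching velocity kick before the second half-step
+        # (a half-step / kick / half-step splitting of the Langevin update).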
+ C_new = Coef_C(x_t) + v = v + Gamma**0.5 * ( C_new - C) *dt + + x_t, v = advance_time(x_t, v, dt/2, Gamma, A, C, D) + + C = C_new + + return x_t, (v, C) + + def prepare_step_size(self, current_times, step_size, sigma_x, sigma_y): + # ------------------------------------------------------------------------- + # Unpack current times parameters (sigma and abt) + sigma, abt, flow_t = current_times + sigma = self.add_none_dims(sigma) + abt = self.add_none_dims(abt) + # Compute time step (dtx, dty) for x and y branches. + dtx = 2 * step_size * sigma_x + dty = 2 * step_size * sigma_y + + # ------------------------------------------------------------------------- + # Define friction parameter Gamma_hat for each branch. + # Using dtx**0 provides a tensor of the proper device/dtype. + + Gamma_hat_x = self.friction **2 * self.step_size * sigma_x / 0.1 * sigma**0 + Gamma_hat_y = self.friction **2 * self.step_size * sigma_y / 0.1 * sigma**0 + #print("Gamma_hat_x", torch.mean(Gamma_hat_x).item(), "Gamma_hat_y", torch.mean(Gamma_hat_y).item()) + # adjust dt to match denoise-addnoise steps sizes + Gamma_hat_x /= 2. + Gamma_hat_y /= 2. + A_t_x = (1) / ( 1 - abt ) * dtx / 2 + A_t_y = (1+self.chara_lamb) / ( 1 - abt ) * dty / 2 + + + A_x = A_t_x / (dtx/2) + A_y = A_t_y / (dty/2) + Gamma_x = Gamma_hat_x / (dtx/2) + Gamma_y = Gamma_hat_y / (dty/2) + + #D_x = (2 * (1 + sigma**2) )**0.5 + #D_y = (2 * (1 + sigma**2) )**0.5 + D_x = (2 * abt**0 )**0.5 + D_y = (2 * abt**0 )**0.5 + return sigma, abt, dtx/2, dty/2, Gamma_x, Gamma_y, A_x, A_y, D_x, D_y + + + + def x0_evalutation(self, x_t, score, sigma, args): + x0 = x_t + score(x_t) + return x0 \ No newline at end of file diff --git a/shared/inpainting/utils.py b/shared/inpainting/utils.py new file mode 100644 index 000000000..c017ab0a8 --- /dev/null +++ b/shared/inpainting/utils.py @@ -0,0 +1,301 @@ +import torch +def epxm1_x(x): + # Compute the (exp(x) - 1) / x term with a small value to avoid division by zero. + result = torch.special.expm1(x) / x + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x) < 1e-2 + result = torch.where(mask, 1 + x/2. + x**2 / 6., result) + return result +def epxm1mx_x2(x): + # Compute the (exp(x) - 1 - x) / x**2 term with a small value to avoid division by zero. + result = (torch.special.expm1(x) - x) / x**2 + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x**2) < 1e-2 + result = torch.where(mask, 1/2. + x/6 + x**2 / 24 + x**3 / 120, result) + return result + +def expm1mxmhx2_x3(x): + # Compute the (exp(x) - 1 - x - x**2 / 2) / x**3 term with a small value to avoid division by zero. 
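+    # As in the helpers above, the direct ratio is replaced by its Taylor expansion
+    # near x = 0, where the floating-point evaluation becomes unstable.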
+ result = (torch.special.expm1(x) - x - x**2 / 2) / x**3 + # replace NaN or inf values with 0 + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + mask = torch.abs(x**3) < 1e-2 + result = torch.where(mask, 1/6 + x/24 + x**2 / 120 + x**3 / 720 + x**4 / 5040, result) + return result + +def exp_1mcosh_GD(gamma_t, delta): + """ + Compute e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + # Main computation + is_positive = delta > 0 + sqrt_abs_delta = torch.sqrt(torch.abs(delta)) + gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta + numerator_pos = torch.exp(-gamma_t) - (torch.exp(gamma_t * (sqrt_abs_delta - 1)) + torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2 + numerator_neg = torch.exp(-gamma_t) * ( 1 - torch.cos(gamma_t * sqrt_abs_delta ) ) + numerator = torch.where(is_positive, numerator_pos, numerator_neg) + result = numerator / (delta * gamma_t**2 ) + # Handle NaN/inf cases + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + # Handle numerical instability for small delta + mask = torch.abs(gamma_t_sqrt_delta**2) < 5e-2 + taylor = ( -0.5 - gamma_t**2 / 24 * delta - gamma_t**4 / 720 * delta**2 ) * torch.exp(-gamma_t) + result = torch.where(mask, taylor, result) + return result + +def exp_sinh_GsqrtD(gamma_t, delta): + """ + Compute e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + # Main computation + is_positive = delta > 0 + sqrt_abs_delta = torch.sqrt(torch.abs(delta)) + gamma_t_sqrt_delta = gamma_t * sqrt_abs_delta + numerator_pos = (torch.exp(gamma_t * (sqrt_abs_delta - 1)) - torch.exp(gamma_t * (-sqrt_abs_delta - 1))) / 2 + denominator_pos = gamma_t_sqrt_delta + result_pos = numerator_pos / gamma_t_sqrt_delta + result_pos = torch.where(torch.isfinite(result_pos), result_pos, torch.zeros_like(result_pos)) + + # Taylor expansion for small gamma_t_sqrt_delta + mask = torch.abs(gamma_t_sqrt_delta) < 1e-2 + taylor = ( 1 + gamma_t**2 / 6 * delta + gamma_t**4 / 120 * delta**2 ) * torch.exp(-gamma_t) + result_pos = torch.where(mask, taylor, result_pos) + + # Handle negative delta + result_neg = torch.exp(-gamma_t) * torch.special.sinc(gamma_t_sqrt_delta/torch.pi) + result = torch.where(is_positive, result_pos, result_neg) + return result + +def exp_cosh(gamma_t, delta): + """ + Compute e^(-Γt) * cosh(Γt√Δ) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + exp_1mcosh_GD_result = exp_1mcosh_GD(gamma_t, delta) # e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + result = torch.exp(-gamma_t) - gamma_t**2 * delta * exp_1mcosh_GD_result + return result +def exp_sinh_sqrtD(gamma_t, delta): + """ + Compute e^(-Γt) * sinh(Γt√Δ) / √Δ + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + Returns: + Result of the computation with numerical stability handling + """ + exp_sinh_GsqrtD_result = exp_sinh_GsqrtD(gamma_t, delta) # e^(-Γt) * sinh(Γt√Δ) / (Γt√Δ) + result = gamma_t * exp_sinh_GsqrtD_result + return result + + + +def zeta1(gamma_t, delta): + # Compute hyperbolic terms and exponential + 
half_gamma_t = gamma_t / 2 + exp_cosh_term = exp_cosh(half_gamma_t, delta) + exp_sinh_term = exp_sinh_sqrtD(half_gamma_t, delta) + + + # Main computation + numerator = 1 - (exp_cosh_term + exp_sinh_term) + denominator = gamma_t * (1 - delta) / 4 + result = 1 - numerator / denominator + + # Handle numerical instability + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + + # Taylor expansion for small x (similar to your epxm1Dx approach) + mask = torch.abs(denominator) < 5e-3 + term1 = epxm1_x(-gamma_t) + term2 = epxm1mx_x2(-gamma_t) + term3 = expm1mxmhx2_x3(-gamma_t) + taylor = term1 + (1/2.+ term1-3*term2)*denominator + (-1/6. + term1/2 - 4 * term2 + 10 * term3) * denominator**2 + result = torch.where(mask, taylor, result) + + return result + +def exp_cosh_minus_terms(gamma_t, delta): + """ + Compute E^(-tΓ) * (Cosh[tΓ] - 1 - (Cosh[tΓ√Δ] - 1)/Δ) / (tΓ(1 - Δ)) + + Parameters: + gamma_t: Γ*t term (could be a scalar or tensor) + delta: Δ term (could be a scalar or tensor) + + Returns: + Result of the computation with numerical stability handling + """ + exp_term = torch.exp(-gamma_t) + # Compute individual terms + exp_cosh_term = exp_cosh(gamma_t, gamma_t**0) - exp_term # E^(-tΓ) (Cosh[tΓ] - 1) term + exp_cosh_delta_term = - gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) # E^(-tΓ) (Cosh[tΓ√Δ] - 1)/Δ term + + #exp_1mcosh_GD e^(-Γt) * (1 - cosh(Γt√Δ))/ ( (Γt)**2 Δ ) + # Main computation + numerator = exp_cosh_term - exp_cosh_delta_term + denominator = gamma_t * (1 - delta) + + result = numerator / denominator + + # Handle numerical instability + result = torch.where(torch.isfinite(result), result, torch.zeros_like(result)) + + # Taylor expansion for small gamma_t and delta near 1 + mask = (torch.abs(denominator) < 1e-1) + exp_1mcosh_GD_term = exp_1mcosh_GD(gamma_t, delta**0) + taylor = ( + gamma_t*exp_1mcosh_GD_term + 0.5 * gamma_t * exp_sinh_GsqrtD(gamma_t, delta**0) + - denominator / 4 * ( 0.5 * exp_cosh(gamma_t, delta**0) - 4 * exp_1mcosh_GD_term - 5 /2 * exp_sinh_GsqrtD(gamma_t, delta**0) ) + ) + result = torch.where(mask, taylor, result) + + return result + + +def zeta2(gamma_t, delta): + half_gamma_t = gamma_t / 2 + return exp_sinh_GsqrtD(half_gamma_t, delta) + +def sig11(gamma_t, delta): + return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta) + + +def Zcoefs(gamma_t, delta): + Zeta1 = zeta1(gamma_t, delta) + Zeta2 = zeta2(gamma_t, delta) + + sq_total = 1 - Zeta1 + gamma_t * (delta - 1) * (Zeta1 - 1)**2 / 8 + amplitude = torch.sqrt(sq_total) + Zcoef1 = ( gamma_t**0.5 * Zeta2 / 2 **0.5 ) / amplitude + Zcoef2 = Zcoef1 * gamma_t *( - 2 * exp_1mcosh_GD(gamma_t, delta) / sig11(gamma_t, delta) ) ** 0.5 + #cterm = exp_cosh_minus_terms(gamma_t, delta) + #sterm = exp_sinh_sqrtD(gamma_t, delta**0) + exp_sinh_sqrtD(gamma_t, delta) + #Zcoef3 = 2 * torch.sqrt( cterm / ( gamma_t * (1 - delta) * cterm + sterm ) ) + Zcoef3 = torch.sqrt( torch.maximum(1 - Zcoef1**2 - Zcoef2**2, sq_total.new_zeros(sq_total.shape)) ) + + return Zcoef1 * amplitude, Zcoef2 * amplitude, Zcoef3 * amplitude, amplitude + +def Zcoefs_asymp(gamma_t, delta): + A_t = (gamma_t * (1 - delta) )/4 + return epxm1_x(- 2 * A_t) + +class StochasticHarmonicOscillator: + """ + Simulates a stochastic harmonic oscillator governed by the equations: + dy(t) = q(t) dt + dq(t) = -Γ A y(t) dt + Γ C dt + Γ D dw(t) - Γ q(t) dt + + Also define v(t) = q(t) / √Γ, which is numerically more stable. 
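+    The dynamics() method advances (y, v) over a finite step in closed form, sampling
+    the pair from a Gaussian transition instead of taking many small Euler steps.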
+ + Where: + y(t) - Position variable + q(t) - Velocity variable + Γ - Damping coefficient + A - Harmonic potential strength + C - Constant force term + D - Noise amplitude + dw(t) - Wiener process (Brownian motion) + """ + def __init__(self, Gamma, A, C, D): + self.Gamma = Gamma + self.A = A + self.C = C + self.D = D + self.Delta = 1 - 4 * A / Gamma + def sig11(self, gamma_t, delta): + return 1 - torch.exp(-gamma_t) + gamma_t**2 * exp_1mcosh_GD(gamma_t, delta) + exp_sinh_sqrtD(gamma_t, delta) + def sig22(self, gamma_t, delta): + return 1- zeta1(2*gamma_t, delta) + 2 * gamma_t * exp_1mcosh_GD(gamma_t, delta) + def dynamics(self, y0, v0, t): + """ + Calculates the position and velocity variables at time t. + + Parameters: + y0 (float): Initial position + v0 (float): Initial velocity v(0) = q(0) / √Γ + t (float): Time at which to evaluate the dynamics + Returns: + tuple: (y(t), v(t)) + """ + + dummyzero = y0.new_zeros(1) # convert scalar to tensor with same device and dtype as y0 + Delta = self.Delta + dummyzero + Gamma_hat = self.Gamma * t + dummyzero + A = self.A + dummyzero + C = self.C + dummyzero + D = self.D + dummyzero + Gamma = self.Gamma + dummyzero + zeta_1 = zeta1( Gamma_hat, Delta) + zeta_2 = zeta2( Gamma_hat, Delta) + EE = 1 - Gamma_hat * zeta_2 + + if v0 is None: + v0 = torch.randn_like(y0) * D / 2 ** 0.5 + #v0 = (C - A * y0)/Gamma**0.5 + + # Calculate mean position and velocity + term1 = (1 - zeta_1) * (C * t - A * t * y0) + zeta_2 * (Gamma ** 0.5) * v0 * t + y_mean = term1 + y0 + v_mean = (1 - EE)*(C - A * y0) / (Gamma ** 0.5) + (EE - A * t * (1 - zeta_1)) * v0 + + cov_yy = D**2 * t * self.sig22(Gamma_hat, Delta) + cov_vv = D**2 * self.sig11(Gamma_hat, Delta) / 2 + cov_yv = (zeta2(Gamma_hat, Delta) * Gamma_hat * D ) **2 / 2 / (Gamma ** 0.5) + + # sample new position and velocity with multivariate normal distribution + + batch_shape = y0.shape + cov_matrix = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype) + cov_matrix[..., 0, 0] = cov_yy + cov_matrix[..., 0, 1] = cov_yv + cov_matrix[..., 1, 0] = cov_yv # symmetric + cov_matrix[..., 1, 1] = cov_vv + + + + # Compute the Cholesky decomposition to get scale_tril + #scale_tril = torch.linalg.cholesky(cov_matrix) + scale_tril = torch.zeros(*batch_shape, 2, 2, device=y0.device, dtype=y0.dtype) + tol = 1e-8 + cov_yy = torch.clamp( cov_yy, min = tol ) + sd_yy = torch.sqrt( cov_yy ) + inv_sd_yy = 1/(sd_yy) + + scale_tril[..., 0, 0] = sd_yy + scale_tril[..., 0, 1] = 0. + scale_tril[..., 1, 0] = cov_yv * inv_sd_yy + scale_tril[..., 1, 1] = torch.clamp( cov_vv - cov_yv**2 / cov_yy, min = tol ) ** 0.5 + # check if it matches torch.linalg. 
+ #assert torch.allclose(torch.linalg.cholesky(cov_matrix), scale_tril, atol = 1e-4, rtol = 1e-4 ) + # Sample correlated noise from multivariate normal + mean = torch.zeros(*batch_shape, 2, device=y0.device, dtype=y0.dtype) + mean[..., 0] = y_mean + mean[..., 1] = v_mean + new_yv = torch.distributions.MultivariateNormal( + loc=mean, + scale_tril=scale_tril + ).sample() + + return new_yv[...,0], new_yv[...,1] \ No newline at end of file diff --git a/shared/radial_attention/attention.py b/shared/radial_attention/attention.py new file mode 100644 index 000000000..ab4a5a293 --- /dev/null +++ b/shared/radial_attention/attention.py @@ -0,0 +1,49 @@ +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from einops import rearrange +from .attn_mask import RadialAttention, MaskMap + +def fill_radial_cache(radial_cache, nb_layers, lat_t, lat_h, lat_w): + MaskMap._log_mask = None + + for i in range(nb_layers): + radial_cache[i] = WanSparseAttnProcessor2_0(i, lat_t, lat_h, lat_w) + +class WanSparseAttnProcessor2_0: + mask_map = None + dense_timestep = 0 + dense_block = 0 + decay_factor = 1.0 + sparse_type = "radial" # default to radial attention, can be changed to "dense" for dense attention + use_sage_attention = True + + def __init__(self, layer_idx, lat_t, lat_h, lat_w): + self.layer_idx = layer_idx + self.mask_map = MaskMap(video_token_num=lat_t * lat_h * lat_w // 4 , num_frame=lat_t) + def __call__( + self, + qkv_list, + timestep_no = 0, + ) -> torch.Tensor: + query, key, value = qkv_list + + batch_size = query.shape[0] + # transform (batch_size, seq_len, num_heads, head_dim) to (seq_len * batch_size, num_heads, head_dim) + query = rearrange(query, "b s h d -> (b s) h d") + key = rearrange(key, "b s h d -> (b s) h d") + value = rearrange(value, "b s h d -> (b s) h d") + if timestep_no < self.dense_timestep or self.layer_idx < self.dense_block or self.sparse_type == "dense": + hidden_states = RadialAttention( + query=query, key=key, value=value, mask_map=self.mask_map, sparsity_type="dense", block_size=128, decay_factor=self.decay_factor, model_type="wan", pre_defined_mask=None, use_sage_attention=self.use_sage_attention + ) + else: + # apply radial attention + hidden_states = RadialAttention( + query=query, key=key, value=value, mask_map=self.mask_map, sparsity_type="radial", block_size=128, decay_factor=self.decay_factor, model_type="wan", pre_defined_mask=None, use_sage_attention=self.use_sage_attention + ) + # transform back to (batch_size, num_heads, seq_len, head_dim) + hidden_states = rearrange(hidden_states, "(b s) h d -> b s h d", b=batch_size) + + return hidden_states diff --git a/shared/radial_attention/attn_mask.py b/shared/radial_attention/attn_mask.py new file mode 100644 index 000000000..4c29f1c32 --- /dev/null +++ b/shared/radial_attention/attn_mask.py @@ -0,0 +1,379 @@ +import torch +# import flashinfer +import matplotlib.pyplot as plt +# from sparse_sageattn import sparse_sageattn +from einops import rearrange, repeat +from sageattention import sageattn +from spas_sage_attn import block_sparse_sage2_attn_cuda + +def get_cuda_arch_versions(): + cuda_archs = [] + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") + return cuda_archs + +from spas_sage_attn import block_sparse_sage2_attn_cuda + +def sparge_mask_convert(mask: torch.Tensor, block_size: int = 128, arch="sm") -> torch.Tensor: + assert block_size in [128, 64], "Radial Attention only supports block size of 
128 or 64" + assert mask.shape[0] == mask.shape[1], "Input mask must be square." + + if block_size == 128: + if arch == "sm90": + new_mask = torch.repeat_interleave(mask, 2, dim=0) + else: + new_mask = torch.repeat_interleave(mask, 2, dim=1) + + elif block_size == 64: + if arch == "sm90": + num_row, num_col = mask.shape + reshaped_mask = mask.view(num_row, num_col // 2, 2) + new_mask = torch.max(reshaped_mask, dim=2).values + else: + num_row, num_col = mask.shape + reshaped_mask = mask.view(num_row // 2, 2, num_col) + new_mask = torch.max(reshaped_mask, dim=1).values + + return new_mask + +def get_indptr_from_mask(mask, query): + # query shows the device of the indptr + # indptr (torch.Tensor) - the block index pointer of the block-sparse matrix on row dimension, + # shape `(MB + 1,)`, where `MB` is the number of blocks in the row dimension. + # The first element is always 0, and the last element is the number of blocks in the row dimension. + # The rest of the elements are the number of blocks in each row. + # the mask is already a block sparse mask + indptr = torch.zeros(mask.shape[0] + 1, device=query.device, dtype=torch.int32) + indptr[0] = 0 + row_counts = mask.sum(dim=1).flatten() # Ensure 1D output [num_blocks_row] + indptr[1:] = torch.cumsum(row_counts, dim=0) + return indptr + +def get_indices_from_mask(mask, query): + # indices (torch.Tensor) - the block indices of the block-sparse matrix on column dimension, + # shape `(nnz,),` where `nnz` is the number of non-zero blocks. + # The elements in `indices` array should be less than `NB`: the number of blocks in the column dimension. + nonzero_indices = torch.nonzero(mask) + indices = nonzero_indices[:, 1].to(dtype=torch.int32, device=query.device) + return indices + +def shrinkMaskStrict(mask, block_size=128): + seqlen = mask.shape[0] + block_num = seqlen // block_size + mask = mask[:block_num * block_size, :block_num * block_size].view(block_num, block_size, block_num, block_size) + col_densities = mask.sum(dim = 1) / block_size + # we want the minimum non-zero column density in the block + non_zero_densities = col_densities > 0 + high_density_cols = col_densities > 1/3 + frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9) + block_mask = frac_high_density_cols > 0.6 + block_mask[0:0] = True + block_mask[-1:-1] = True + return block_mask + +def pad_qkv(input_tensor, block_size=128): + """ + Pad the input tensor to be a multiple of the block size. 
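+    The original tokens occupy the first `seqlen` rows; the appended rows are zero-filled.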
+ input shape: (seqlen, num_heads, hidden_dim) + """ + seqlen, num_heads, hidden_dim = input_tensor.shape + # Calculate the necessary padding + padding_length = (block_size - (seqlen % block_size)) % block_size + # Create a padded tensor with zeros + padded_tensor = torch.zeros((seqlen + padding_length, num_heads, hidden_dim), device=input_tensor.device, dtype=input_tensor.dtype) + # Copy the original tensor into the padded tensor + padded_tensor[:seqlen, :, :] = input_tensor + + return padded_tensor + +def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query): + assert(sparse_type in ["radial"]) + dist = abs(i - j) + group = dist.bit_length() + threshold = 128 # hardcoded threshold for now, which is equal to block-size + decay_length = 2 ** token_per_frame.bit_length() / 2 ** group + if decay_length >= threshold: + return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + + split_factor = int(threshold / decay_length) + modular = dist % split_factor + if modular == 0: + return torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + else: + return torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + +def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None): + assert(sparse_type in ["radial"]) + dist = abs(i - j) + if model_type == "wan": + if dist < 1: + return token_per_frame + if dist == 1: + return token_per_frame // 2 + elif model_type == "hunyuan": + if dist <= 1: + return token_per_frame + else: + raise ValueError(f"Unknown model type: {model_type}") + group = dist.bit_length() + decay_length = 2 ** token_per_frame.bit_length() / 2 ** group * decay_factor + threshold = block_size + if decay_length >= threshold: + return decay_length + else: + return threshold + +def gen_log_mask_shrinked(query, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None): + """ + A more memory friendly version, we generate the attention mask of each frame pair at a time, + shrinks it, and stores it into the final result + """ + final_log_mask = torch.zeros((s // block_size, s // block_size), device=query.device, dtype=torch.bool) + token_per_frame = video_token_num // num_frame + video_text_border = video_token_num // block_size + + col_indices = torch.arange(0, token_per_frame, device=query.device).view(1, -1) + row_indices = torch.arange(0, token_per_frame, device=query.device).view(-1, 1) + final_log_mask[video_text_border:] = True + final_log_mask[:, video_text_border:] = True + for i in range(num_frame): + for j in range(num_frame): + local_mask = torch.zeros((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + if j == 0 and model_type == "wan": # this is attention sink + local_mask = torch.ones((token_per_frame, token_per_frame), device=query.device, dtype=torch.bool) + else: + window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type) + local_mask = torch.abs(col_indices - row_indices) <= window_width + split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, query) + local_mask = torch.logical_and(local_mask, split_mask) + + remainder_row = (i * token_per_frame) % block_size + remainder_col = (j * token_per_frame) % block_size + # get the padded size + all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size + 
all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size + padded_local_mask = torch.zeros((all_length_row, all_length_col), device=query.device, dtype=torch.bool) + padded_local_mask[remainder_row:remainder_row + token_per_frame, remainder_col:remainder_col + token_per_frame] = local_mask + # shrink the mask + block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size) + # set the block mask to the final log mask + block_row_start = (i * token_per_frame) // block_size + block_col_start = (j * token_per_frame) // block_size + block_row_end = block_row_start + block_mask.shape[0] + block_col_end = block_col_start + block_mask.shape[1] + final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or( + final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask) + print(f"mask sparsity: {1 - final_log_mask.sum() / final_log_mask.numel()}") + return final_log_mask + +class MaskMap: + _log_mask = None + + def __init__(self, video_token_num=25440, num_frame=16): + self.video_token_num = video_token_num + self.num_frame = num_frame + + def queryLogMask(self, query, sparse_type, block_size=128, decay_factor=0.5, model_type=None): + if MaskMap._log_mask is None: + MaskMap._log_mask = torch.ones((query.shape[0] // block_size, query.shape[0] // block_size), device=query.device, dtype=torch.bool) + MaskMap._log_mask = gen_log_mask_shrinked(query, query.shape[0], self.video_token_num, self.num_frame, sparse_type=sparse_type, decay_factor=decay_factor, model_type=model_type, block_size=block_size) + return MaskMap._log_mask + +def SpargeSageAttnBackend(query, key, value, mask_map=None, video_mask=None, pre_defined_mask=None, block_size=128): + if video_mask.all(): + # dense case + kv_border = pre_defined_mask[0].sum() if pre_defined_mask is not None else key.shape[0] + output_video = sageattn( + query[:mask_map.video_token_num, :, :].unsqueeze(0), + key[:kv_border, :, :].unsqueeze(0), + value[:kv_border, :, :].unsqueeze(0), + tensor_layout="NHD", + )[0] + + if pre_defined_mask is not None: + output_text = flashinfer.single_prefill_with_kv_cache( + q=query[mask_map.video_token_num:, :, :], + k=key[:pre_defined_mask[0].sum(), :, :], + v=value[:pre_defined_mask[0].sum(), :, :], + causal=False, + return_lse=False, + ) + return torch.cat([output_video, output_text], dim=0) + else: + return output_video + + # sparse-sageattention only supports (b, h, s, d) layout, need rearrange first + query_hnd = rearrange(query.unsqueeze(0), "b s h d -> b h s d") + key_hnd = rearrange(key.unsqueeze(0), "b s h d -> b h s d") + value_hnd = rearrange(value.unsqueeze(0), "b s h d -> b h s d") + arch = get_cuda_arch_versions()[query.device.index] + converted_mask = repeat(sparge_mask_convert(mask=video_mask, block_size=block_size, arch=arch), "s t -> b h s t", b=query_hnd.shape[0], h=query_hnd.shape[1]) + + converted_mask = converted_mask.to(torch.int8) + if pre_defined_mask is None: + # wan case + output = block_sparse_sage2_attn_cuda( + query_hnd[:, :, :mask_map.video_token_num, :], + key_hnd[:, :, :mask_map.video_token_num, :], + value_hnd[:, :, :mask_map.video_token_num, :], + mask_id=converted_mask, + tensor_layout="HND", + ) + + # rearrange back to (s, h, d), we know that b = 1 + output = rearrange(output, "b h s d -> s (b h) d", b=1) + return output + + query_video = query_hnd[:, :, :mask_map.video_token_num, :] + key_video = key_hnd + value_video = value_hnd + kv_border = (pre_defined_mask[0].sum() + 63) // 64 + 
converted_mask[:, :, :, kv_border:] = False + output_video = block_sparse_sage2_attn_cuda( + query_video, + key_video, + value_video, + mask_id=converted_mask[:, :, :mask_map.video_token_num // block_size, :].contiguous(), + tensor_layout="HND", + ) + + # rearrange back to (s, h, d), we know that b = 1 + output_video = rearrange(output_video, "b h s d -> s (b h) d", b=1) + + # gt = sparse_sageattn( + # query_video, + # key_video, + # value_video, + # mask_id=None, + # is_causal=False, + # tensor_layout="HND", + # )[0] + + + + # import pdb; pdb.set_trace() + + output_text = flashinfer.single_prefill_with_kv_cache( + q=query[mask_map.video_token_num:, :, :], + k=key[:pre_defined_mask[0].sum(), :, :], + v=value[:pre_defined_mask[0].sum(), :, :], + causal=False, + return_lse=False, + ) + + return torch.cat([output_video, output_text], dim=0) + + +def FlashInferBackend(query, key, value, mask_map=None, pre_defined_mask=None, bsr_wrapper=None): + if pre_defined_mask is not None: + video_video_o, video_video_o_lse = bsr_wrapper.run( + query[:mask_map.video_token_num, :, :], + key[:mask_map.video_token_num, :, :], + value[:mask_map.video_token_num, :, :], + return_lse=True + ) + # perform non-causal flashinfer on the text tokens + video_text_o, video_text_o_lse = flashinfer.single_prefill_with_kv_cache( + q=query[:mask_map.video_token_num, :, :], + k=key[mask_map.video_token_num:, :, :], + v=value[mask_map.video_token_num:, :, :], + causal=False, + return_lse=True, + custom_mask=pre_defined_mask[:mask_map.video_token_num, mask_map.video_token_num:] + ) + + # merge the two results + o_video, _ = flashinfer.merge_state(v_a=video_video_o, s_a=video_video_o_lse, v_b=video_text_o, s_b=video_text_o_lse) + + o_text = flashinfer.single_prefill_with_kv_cache( + q=query[mask_map.video_token_num:, :, :], + k=key, + v=value, + causal=False, + return_lse=False, + custom_mask=pre_defined_mask[mask_map.video_token_num:, :] + ) + + return torch.cat([o_video, o_text], dim=0) + else: + o = bsr_wrapper.run( + query[:mask_map.video_token_num, :, :], + key[:mask_map.video_token_num, :, :], + value[:mask_map.video_token_num, :, :] + ) + return o + +def RadialAttention(query, key, value, mask_map=None, sparsity_type="radial", block_size=128, decay_factor=1, model_type=None, pre_defined_mask=None, use_sage_attention=False): + orig_seqlen, num_head, hidden_dim = query.shape + + if sparsity_type == "dense": + video_mask = torch.ones((mask_map.video_token_num // block_size, mask_map.video_token_num // block_size), device=query.device, dtype=torch.bool) + else: + video_mask = mask_map.queryLogMask(query, sparsity_type, block_size=block_size, decay_factor=decay_factor, model_type=model_type) if mask_map else None + + backend = "sparse_sageattn" if use_sage_attention else "flashinfer" + + if backend == "flashinfer": + video_mask = video_mask[:mask_map.video_token_num // block_size, :mask_map.video_token_num // block_size] + # perform block-sparse attention on the video tokens + workspace_buffer = torch.empty(128 * 1024 * 1024, device=query.device, dtype=torch.uint8) + bsr_wrapper = flashinfer.BlockSparseAttentionWrapper( + workspace_buffer, + backend="fa2", + ) + + indptr = get_indptr_from_mask(video_mask, query) + indices = get_indices_from_mask(video_mask, query) + + bsr_wrapper.plan( + indptr=indptr, + indices=indices, + M=mask_map.video_token_num, + N=mask_map.video_token_num, + R=block_size, + C=block_size, + num_qo_heads=num_head, + num_kv_heads=num_head, + head_dim=hidden_dim, + q_data_type=query.dtype, + 
kv_data_type=key.dtype, + o_data_type=query.dtype, + ) + + return FlashInferBackend(query, key, value, mask_map, pre_defined_mask, bsr_wrapper) + elif backend == "sparse_sageattn": + return SpargeSageAttnBackend(query, key, value, mask_map, video_mask, pre_defined_mask, block_size=block_size) + +if __name__ == "__main__": + query = torch.randn(1, 2, 4, 64).cuda() + # mask = torch.tensor([ + # [True, False, True, False], + # [False, True, False, True], + # [True, False, False, True], + # [False, True, True, False] + # ], dtype=torch.bool) + # indices = get_indices_from_mask(mask, query) + # indptr = get_indptr_from_mask(mask, query) + # print("Indices: ", indices) + # print("Indptr: ", indptr) + video_token_num = 3840 * 30 + num_frame = 30 + token_per_frame = video_token_num / num_frame + padded_video_token_num = ((video_token_num + 1) // 128 + 1) * 128 + print("padded: ", padded_video_token_num) + temporal_mask = gen_log_mask_shrinked(query, padded_video_token_num, video_token_num, num_frame, sparse_type="radial", decay_factor=1, model_type="hunyuan") + plt.figure(figsize=(10, 8), dpi=500) + + plt.imshow(temporal_mask.cpu().numpy()[:, :], cmap='hot') + plt.colorbar() + plt.title("Temporal Mask") + + plt.savefig("temporal_mask.png", + dpi=300, + bbox_inches='tight', + pad_inches=0.1) + + plt.close() + # save the mask tensor + torch.save(temporal_mask, "temporal_mask.pt") diff --git a/shared/radial_attention/inference.py b/shared/radial_attention/inference.py new file mode 100644 index 000000000..04ef186ab --- /dev/null +++ b/shared/radial_attention/inference.py @@ -0,0 +1,41 @@ +import torch +from diffusers.models.attention_processor import Attention +from diffusers.models.attention import AttentionModuleMixin +from .attention import WanSparseAttnProcessor +from .attn_mask import MaskMap + +def setup_radial_attention( + pipe, + height, + width, + num_frames, + dense_layers=0, + dense_timesteps=0, + decay_factor=1.0, + sparsity_type="radial", + use_sage_attention=False, +): + + num_frames = 1 + num_frames // (pipe.vae_scale_factor_temporal * pipe.transformer.config.patch_size[0]) + mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1] + frame_size = int(height // mod_value) * int(width // mod_value) + + AttnModule = WanSparseAttnProcessor + AttnModule.dense_block = dense_layers + AttnModule.dense_timestep = dense_timesteps + AttnModule.mask_map = MaskMap(video_token_num=frame_size * num_frames, num_frame=num_frames) + AttnModule.decay_factor = decay_factor + AttnModule.sparse_type = sparsity_type + AttnModule.use_sage_attention = use_sage_attention + + print(f"Replacing Wan attention with {sparsity_type} attention") + print(f"video token num: {AttnModule.mask_map.video_token_num}, num frames: {num_frames}") + print(f"dense layers: {dense_layers}, dense timesteps: {dense_timesteps}, decay factor: {decay_factor}") + + for layer_idx, m in enumerate(pipe.transformer.blocks): + m.attn1.processor.layer_idx = layer_idx + + for _, m in pipe.transformer.named_modules(): + if isinstance(m, AttentionModuleMixin) and hasattr(m.processor, 'layer_idx'): + layer_idx = m.processor.layer_idx + m.set_processor(AttnModule(layer_idx)) \ No newline at end of file diff --git a/shared/radial_attention/sparse_transformer.py b/shared/radial_attention/sparse_transformer.py new file mode 100644 index 000000000..96820d0f3 --- /dev/null +++ b/shared/radial_attention/sparse_transformer.py @@ -0,0 +1,424 @@ +# borrowed from svg-project/Sparse-VideoGen + +from typing import Any, Dict, Optional, 
Tuple, Union + +import torch + +from diffusers.models.transformers.transformer_wan import WanTransformerBlock, WanTransformer3DModel +from diffusers import WanPipeline +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +logger = logging.get_logger(__name__) +from typing import Any, Callable, Dict, List, Optional, Union +from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback +from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput +import torch.distributed as dist + +try: + from xfuser.core.distributed import ( + get_ulysses_parallel_world_size, + get_ulysses_parallel_rank, + get_sp_group + ) +except: + pass + +class WanTransformerBlock_Sparse(WanTransformerBlock): + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + rotary_emb: torch.Tensor, + numeral_timestep: Optional[int] = None, + ) -> torch.Tensor: + if temb.ndim == 4: + # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table.unsqueeze(0) + temb.float() + ).chunk(6, dim=2) + # batch_size, seq_len, 1, inner_dim + shift_msa = shift_msa.squeeze(2) + scale_msa = scale_msa.squeeze(2) + gate_msa = gate_msa.squeeze(2) + c_shift_msa = c_shift_msa.squeeze(2) + c_scale_msa = c_scale_msa.squeeze(2) + c_gate_msa = c_gate_msa.squeeze(2) + else: + # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B) + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = ( + self.scale_shift_table + temb.float() + ).chunk(6, dim=1) + + # 1. Self-attention + norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states) + attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb, numerical_timestep=numeral_timestep) + hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states).contiguous() + + # 2. Cross-attention + norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states) + attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = hidden_states + attn_output + + # 3. Feed-forward + norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as( + hidden_states + ) + ff_output = self.ffn(norm_hidden_states) + hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states) + + return hidden_states + +class WanTransformer3DModel_Sparse(WanTransformer3DModel): + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + numeral_timestep: Optional[int] = None, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
+ ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + rotary_emb = self.rope(hidden_states) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v) + if timestep.ndim == 2: + ts_seq_len = timestep.shape[1] + timestep = timestep.flatten() # batch_size * seq_len + else: + ts_seq_len = None + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len + ) + + if ts_seq_len is not None: + # batch_size, seq_len, 6, inner_dim + timestep_proj = timestep_proj.unflatten(2, (6, -1)) + else: + # batch_size, 6, inner_dim + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + # split video latents on dim TS + hidden_states = torch.chunk(hidden_states, get_ulysses_parallel_world_size(), dim=-2)[get_ulysses_parallel_rank()] + rotary_emb = ( + torch.chunk(rotary_emb[0], get_ulysses_parallel_world_size(), dim=1)[get_ulysses_parallel_rank()], + torch.chunk(rotary_emb[1], get_ulysses_parallel_world_size(), dim=1)[get_ulysses_parallel_rank()], + ) + + # 4. Transformer blocks + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, numeral_timestep=numeral_timestep + ) + else: + for block in self.blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states, + timestep_proj, + rotary_emb, + numeral_timestep=numeral_timestep, + ) + + # 5. Output norm, projection & unpatchify + if temb.ndim == 3: + # batch_size, seq_len, inner_dim (wan 2.2 ti2v) + shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2) + shift = shift.squeeze(2) + scale = scale.squeeze(2) + else: + # batch_size, inner_dim + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. 
+ shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + + if dist.is_initialized() and get_ulysses_parallel_world_size() > 1: + hidden_states = get_sp_group().all_gather(hidden_states, dim=-2) + + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + +class WanPipeline_Sparse(WanPipeline): + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "np", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, defaults to `480`): + The height in pixels of the generated image. + width (`int`, defaults to `832`): + The width in pixels of the generated image. + num_frames (`int`, defaults to `81`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. 
+ prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): + A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of + each denoising step during the inference. with the following arguments: `callback_on_step_end(self: + DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a + list of all tensors as specified by `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`): + The dtype to use for the torch.amp.autocast. + + Examples: + + Returns: + [`~WanPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._current_timestep = None + self._interrupt = False + + device = self._execution_device + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, + device=device, + ) + + transformer_dtype = self.transformer.dtype + prompt_embeds = prompt_embeds.to(transformer_dtype) + if negative_prompt_embeds is not None: + negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) + + # 4. 
Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + torch.float32, + device, + generator, + latents, + ) + + # 6. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + latent_model_input = latents.to(transformer_dtype) + timestep = t.expand(latents.shape[0]) + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + numeral_timestep=i, + )[0] + + if self.do_classifier_free_guidance: + noise_uncond = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + numeral_timestep=i, + )[0] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + + self._current_timestep = None + + if not output_type == "latent": + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + video = self.vae.decode(latents, return_dict=False)[0] + video = self.video_processor.postprocess_video(video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return WanPipelineOutput(frames=video) + +def replace_sparse_forward(): + WanTransformerBlock.forward = WanTransformerBlock_Sparse.forward + WanTransformer3DModel.forward = WanTransformer3DModel_Sparse.forward + WanPipeline.__call__ = WanPipeline_Sparse.__call__ \ No newline at end of file diff --git a/shared/radial_attention/utils.py b/shared/radial_attention/utils.py new file mode 100644 index 000000000..92b3c3e13 --- /dev/null +++ b/shared/radial_attention/utils.py @@ -0,0 +1,16 @@ +import os +import random +import numpy as np +import torch + +def set_seed(seed): + """ + Set the random seed for reproducibility. 
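+
+    Args:
+        seed (int): Seed applied to Python's `random`, `PYTHONHASHSEED`, NumPy and
+            PyTorch (CPU and CUDA) RNGs; cuDNN is also switched to deterministic mode.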
+ """ + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False \ No newline at end of file diff --git a/shared/sage2_core.py b/shared/sage2_core.py index 04b77f8bb..33cd98cbf 100644 --- a/shared/sage2_core.py +++ b/shared/sage2_core.py @@ -28,18 +28,24 @@ try: from sageattention import _qattn_sm80 + if not hasattr(_qattn_sm80, "qk_int8_sv_f16_accum_f32_attn"): + _qattn_sm80 = torch.ops.sageattention_qattn_sm80 SM80_ENABLED = True except: SM80_ENABLED = False try: from sageattention import _qattn_sm89 + if not hasattr(_qattn_sm89, "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"): + _qattn_sm89 = torch.ops.sageattention_qattn_sm89 SM89_ENABLED = True except: SM89_ENABLED = False try: from sageattention import _qattn_sm90 + if not hasattr(_qattn_sm90, "qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf"): + _qattn_sm90 = torch.ops.sageattention_qattn_sm90 SM90_ENABLED = True except: SM90_ENABLED = False diff --git a/shared/utils/audio_video.py b/shared/utils/audio_video.py index b24530d5c..224cf3412 100644 --- a/shared/utils/audio_video.py +++ b/shared/utils/audio_video.py @@ -232,6 +232,9 @@ def save_video(tensor, retry=5): """Save tensor as video with configurable codec and container options.""" + if torch.is_tensor(tensor) and len(tensor.shape) == 4: + tensor = tensor.unsqueeze(0) + suffix = f'.{container}' cache_file = osp.join('/tmp', rand_name(suffix=suffix)) if save_file is None else save_file if not cache_file.endswith(suffix): diff --git a/shared/utils/download.py b/shared/utils/download.py new file mode 100644 index 000000000..ed035c0b7 --- /dev/null +++ b/shared/utils/download.py @@ -0,0 +1,110 @@ +import sys, time + +# Global variables to track download progress +_start_time = None +_last_time = None +_last_downloaded = 0 +_speed_history = [] +_update_interval = 0.5 # Update speed every 0.5 seconds + +def progress_hook(block_num, block_size, total_size, filename=None): + """ + Simple progress bar hook for urlretrieve + + Args: + block_num: Number of blocks downloaded so far + block_size: Size of each block in bytes + total_size: Total size of the file in bytes + filename: Name of the file being downloaded (optional) + """ + global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval + + current_time = time.time() + downloaded = block_num * block_size + + # Initialize timing on first call + if _start_time is None or block_num == 0: + _start_time = current_time + _last_time = current_time + _last_downloaded = 0 + _speed_history = [] + + # Calculate download speed only at specified intervals + speed = 0 + if current_time - _last_time >= _update_interval: + if _last_time > 0: + current_speed = (downloaded - _last_downloaded) / (current_time - _last_time) + _speed_history.append(current_speed) + # Keep only last 5 speed measurements for smoothing + if len(_speed_history) > 5: + _speed_history.pop(0) + # Average the recent speeds for smoother display + speed = sum(_speed_history) / len(_speed_history) + + _last_time = current_time + _last_downloaded = downloaded + elif _speed_history: + # Use the last calculated average speed + speed = sum(_speed_history) / len(_speed_history) + # Format file sizes and speed + def format_bytes(bytes_val): + for unit in ['B', 'KB', 'MB', 'GB']: + if bytes_val < 1024: + return f"{bytes_val:.1f}{unit}" + bytes_val /= 1024 + return f"{bytes_val:.1f}TB" + + 
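+    # Show the file name (when provided) in front of the progress information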
file_display = filename if filename else "Unknown file" + + if total_size <= 0: + # If total size is unknown, show downloaded bytes + speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" + line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}" + # Clear any trailing characters by padding with spaces + sys.stdout.write(line.ljust(80)) + sys.stdout.flush() + return + + downloaded = block_num * block_size + percent = min(100, (downloaded / total_size) * 100) + + # Create progress bar (40 characters wide to leave room for other info) + bar_length = 40 + filled = int(bar_length * percent / 100) + bar = '█' * filled + '░' * (bar_length - filled) + + # Format file sizes and speed + def format_bytes(bytes_val): + for unit in ['B', 'KB', 'MB', 'GB']: + if bytes_val < 1024: + return f"{bytes_val:.1f}{unit}" + bytes_val /= 1024 + return f"{bytes_val:.1f}TB" + + speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" + + # Display progress with filename first + line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}" + # Clear any trailing characters by padding with spaces + sys.stdout.write(line.ljust(100)) + sys.stdout.flush() + + # Print newline when complete + if percent >= 100: + print() + +# Wrapper function to include filename in progress hook +def create_progress_hook(filename): + """Creates a progress hook with the filename included""" + global _start_time, _last_time, _last_downloaded, _speed_history + # Reset timing variables for new download + _start_time = None + _last_time = None + _last_downloaded = 0 + _speed_history = [] + + def hook(block_num, block_size, total_size): + return progress_hook(block_num, block_size, total_size, filename) + return hook + + diff --git a/shared/utils/loras_mutipliers.py b/shared/utils/loras_mutipliers.py index 58cc9a993..e0cec6a45 100644 --- a/shared/utils/loras_mutipliers.py +++ b/shared/utils/loras_mutipliers.py @@ -99,7 +99,7 @@ def is_float(element: any) -> bool: return loras_list_mult_choices_nums, slists_dict, "" -def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None ): +def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_step = None, phase_switch_step2 = None): from mmgp import offload sz = len(slists_dict["phase1"]) slists = [ expand_slist(slists_dict, i, num_inference_steps, phase_switch_step, phase_switch_step2 ) for i in range(sz) ] @@ -108,7 +108,8 @@ def update_loras_slists(trans, slists_dict, num_inference_steps, phase_switch_st -def get_model_switch_steps(timesteps, total_num_steps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ): +def get_model_switch_steps(timesteps, guide_phases, model_switch_phase, switch_threshold, switch2_threshold ): + total_num_steps = len(timesteps) model_switch_step = model_switch_step2 = None for i, t in enumerate(timesteps): if guide_phases >=2 and model_switch_step is None and t <= switch_threshold: model_switch_step = i diff --git a/shared/utils/qwen_vl_utils.py b/shared/utils/qwen_vl_utils.py index 3c682e6ad..3fe02cc1a 100644 --- a/shared/utils/qwen_vl_utils.py +++ b/shared/utils/qwen_vl_utils.py @@ -9,7 +9,6 @@ import sys import time import warnings -from functools import lru_cache from io import BytesIO import requests @@ -257,7 +256,6 @@ def _read_video_decord(ele: dict,) -> torch.Tensor: FORCE_QWENVL_VIDEO_READER = os.getenv("FORCE_QWENVL_VIDEO_READER", None) -@lru_cache(maxsize=1) def 
get_video_reader_backend() -> str: if FORCE_QWENVL_VIDEO_READER is not None: video_reader_backend = FORCE_QWENVL_VIDEO_READER diff --git a/shared/utils/sliding_window_cleanup.py b/shared/utils/sliding_window_cleanup.py new file mode 100644 index 000000000..0864a3338 --- /dev/null +++ b/shared/utils/sliding_window_cleanup.py @@ -0,0 +1,77 @@ +""" +Sliding Window Video Cleanup Utility + +This module provides functionality to automatically delete shorter intermediate videos +during sliding window generation to save disk space. +""" + +import os + + +def cleanup_previous_video(previous_video_path, file_list, file_settings_list, lock, + sliding_window=True, sliding_window_keep_only_longest=True): + """ + Delete a previous video file and remove it from UI file lists. + + Args: + previous_video_path: Path to the video file to delete + file_list: List of file paths in UI preview + file_settings_list: List of file settings in UI preview + lock: Threading lock for safe list manipulation + sliding_window: Whether sliding window is enabled (default: True) + sliding_window_keep_only_longest: Whether cleanup is enabled (default: True) + + Returns: + bool: True if cleanup was successful, False otherwise + """ + # Check if cleanup should be performed + if not sliding_window or not sliding_window_keep_only_longest or previous_video_path is None: + return False + + try: + if os.path.isfile(previous_video_path): + os.remove(previous_video_path) + + # Also remove from UI file list to prevent dead links in preview + with lock: + if previous_video_path in file_list: + index = file_list.index(previous_video_path) + file_list.pop(index) + if index < len(file_settings_list): + file_settings_list.pop(index) + + return True + except Exception as e: + # Silently handle cleanup errors to not interrupt generation + pass + + return False + + +def should_cleanup_video(sliding_window, sliding_window_keep_only_longest, is_image=False): + """ + Determine if video cleanup should be performed. + + Args: + sliding_window: Whether sliding window is enabled + sliding_window_keep_only_longest: Whether cleanup is enabled + is_image: Whether the output is an image (cleanup not needed) + + Returns: + bool: True if cleanup should be performed + """ + return sliding_window and sliding_window_keep_only_longest and not is_image + + +def get_cleanup_status_text(server_config): + """ + Get the cleanup status text for info messages. + + Args: + server_config: Server configuration dictionary + + Returns: + str: "Enabled" or "Disabled" + """ + return "Enabled" if server_config.get("sliding_window_keep_only_longest", False) else "Disabled" + diff --git a/shared/utils/utils.py b/shared/utils/utils.py index a55807a17..a03927e62 100644 --- a/shared/utils/utils.py +++ b/shared/utils/utils.py @@ -1,4 +1,3 @@ -# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. 
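+# Shared image/video helper utilities for WanGP (resizing, canvas fitting,
+# outpainting geometry and guide/mask preparation).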
import argparse import os import os.path as osp @@ -18,11 +17,12 @@ import tempfile import subprocess import json - +from functools import lru_cache +os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") from PIL import Image - +video_info_cache = [] def seed_everything(seed: int): random.seed(seed) np.random.seed(seed) @@ -32,6 +32,14 @@ def seed_everything(seed: int): if torch.backends.mps.is_available(): torch.mps.manual_seed(seed) +def has_video_file_extension(filename): + extension = os.path.splitext(filename)[-1].lower() + return extension in [".mp4"] + +def has_image_file_extension(filename): + extension = os.path.splitext(filename)[-1].lower() + return extension in [".png", ".jpg", ".jpeg", ".bmp", ".gif", ".webp", ".tif", ".tiff", ".jfif", ".pjpeg"] + def resample(video_fps, video_frames_count, max_target_frames_count, target_fps, start_target_frame ): import math @@ -77,7 +85,9 @@ def truncate_for_filesystem(s, max_bytes=255): else: r = m - 1 return s[:l] +@lru_cache(maxsize=100) def get_video_info(video_path): + global video_info_cache import cv2 cap = cv2.VideoCapture(video_path) @@ -92,7 +102,7 @@ def get_video_info(video_path): return fps, width, height, frame_count -def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, return_PIL = True) -> torch.Tensor: +def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool = False, target_fps = None, return_PIL = True) -> torch.Tensor: """Extract nth frame from video as PyTorch tensor normalized to [-1, 1].""" cap = cv2.VideoCapture(file_name) @@ -100,7 +110,10 @@ def get_video_frame(file_name: str, frame_no: int, return_last_if_missing: bool raise ValueError(f"Cannot open video: {file_name}") total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - + fps = round(cap.get(cv2.CAP_PROP_FPS)) + if target_fps is not None: + frame_no = round(target_fps * frame_no /fps) + # Handle out of bounds if frame_no >= total_frames or frame_no < 0: if return_last_if_missing: @@ -173,9 +186,15 @@ def remove_background(img, session=None): def convert_image_to_tensor(image): return torch.from_numpy(np.array(image).astype(np.float32)).div_(127.5).sub_(1.).movedim(-1, 0) -def convert_tensor_to_image(t, frame_no = -1): - t = t[:, frame_no] if frame_no >= 0 else t - return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) +def convert_tensor_to_image(t, frame_no = 0, mask_levels = False): + if len(t.shape) == 4: + t = t[:, frame_no] + if t.shape[0]== 1: + t = t.expand(3,-1,-1) + if mask_levels: + return Image.fromarray(t.clone().mul_(255).permute(1,2,0).to(torch.uint8).cpu().numpy()) + else: + return Image.fromarray(t.clone().add_(1.).mul_(127.5).permute(1,2,0).to(torch.uint8).cpu().numpy()) def save_image(tensor_image, name, frame_no = -1): convert_tensor_to_image(tensor_image, frame_no).save(name) @@ -186,6 +205,14 @@ def get_outpainting_full_area_dimensions(frame_height,frame_width, outpainting_d frame_width = int(frame_width * (100 + outpainting_left + outpainting_right) / 100) return frame_height, frame_width +def rgb_bw_to_rgba_mask(img, thresh=127): + a = img.convert('L').point(lambda p: 255 if p > thresh else 0) # alpha + out = Image.new('RGBA', img.size, (255, 255, 255, 0)) # white, transparent + out.putalpha(a) # white where alpha=255 + return out + + + def get_outpainting_frame_location(final_height, final_width, outpainting_dims, block_size = 8): outpainting_top, outpainting_bottom, outpainting_left, outpainting_right= 
outpainting_dims raw_height = int(final_height / ((100 + outpainting_top + outpainting_bottom) / 100)) @@ -205,30 +232,64 @@ def get_outpainting_frame_location(final_height, final_width, outpainting_dims if (margin_left + width) > final_width or outpainting_right == 0: margin_left = final_width - width return height, width, margin_top, margin_left -def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16): - if fit_into_canvas == None: +def rescale_and_crop(img, w, h): + ow, oh = img.size + target_ratio = w / h + orig_ratio = ow / oh + + if orig_ratio > target_ratio: + # Crop width first + nw = int(oh * target_ratio) + img = img.crop(((ow - nw) // 2, 0, (ow + nw) // 2, oh)) + else: + # Crop height first + nh = int(ow / target_ratio) + img = img.crop((0, (oh - nh) // 2, ow, (oh + nh) // 2)) + + return img.resize((w, h), Image.LANCZOS) + +def calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = 16): + if fit_into_canvas == None or fit_into_canvas == 2: # return image_height, image_width return canvas_height, canvas_width - if fit_into_canvas: + if fit_into_canvas == 1: scale1 = min(canvas_height / image_height, canvas_width / image_width) scale2 = min(canvas_width / image_height, canvas_height / image_width) scale = max(scale1, scale2) - else: + else: #0 or #2 (crop) scale = (canvas_height * canvas_width / (image_height * image_width))**(1/2) new_height = round( image_height * scale / block_size) * block_size new_width = round( image_width * scale / block_size) * block_size return new_height, new_width -def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, ignore_first, fit_into_canvas = False ): +def calculate_dimensions_and_resize_image(image, canvas_height, canvas_width, fit_into_canvas, fit_crop, block_size = 16): + if fit_crop: + image = rescale_and_crop(image, canvas_width, canvas_height) + new_width, new_height = image.size + else: + image_width, image_height = image.size + new_height, new_width = calculate_new_dimensions(canvas_height, canvas_width, image_height, image_width, fit_into_canvas, block_size = block_size ) + image = image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + return image, new_height, new_width + +def resize_and_remove_background(img_list, budget_width, budget_height, rm_background, any_background_ref, fit_into_canvas = 0, block_size= 16, outpainting_dims = None, background_ref_outpainted = True, inpaint_color = 127.5, return_tensor = False ): if rm_background: session = new_session() output_list =[] + output_mask_list =[] for i, img in enumerate(img_list): width, height = img.size - - if fit_into_canvas: + resized_mask = None + if any_background_ref == 1 and i==0 or any_background_ref == 2: + if outpainting_dims is not None and background_ref_outpainted: + resized_image, resized_mask = fit_image_into_canvas(img, (budget_height, budget_width), inpaint_color, full_frame = True, outpainting_dims = outpainting_dims, return_mask= True, return_image= True) + elif img.size != (budget_width, budget_height): + resized_image= img.resize((budget_width, budget_height), resample=Image.Resampling.LANCZOS) + else: + resized_image =img + elif fit_into_canvas == 1: white_canvas = np.ones((budget_height, budget_width, 3), dtype=np.uint8) * 255 scale = min(budget_height / height, budget_width / width) new_height = int(height * scale) @@ -240,152 +301,112 @@ def resize_and_remove_background(img_list, 
budget_width, budget_height, rm_backg resized_image = Image.fromarray(white_canvas) else: scale = (budget_height * budget_width / (height * width))**(1/2) - new_height = int( round(height * scale / 16) * 16) - new_width = int( round(width * scale / 16) * 16) + new_height = int( round(height * scale / block_size) * block_size) + new_width = int( round(width * scale / block_size) * block_size) resized_image= img.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) - if rm_background and not (ignore_first and i == 0) : + if rm_background and not (any_background_ref and i==0 or any_background_ref == 2) : # resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1,alpha_matting_background_threshold = 70, alpha_foreground_background_threshold = 100, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') resized_image = remove(resized_image, session=session, alpha_matting_erode_size = 1, alpha_matting = True, bgcolor=[255, 255, 255, 0]).convert('RGB') - output_list.append(resized_image) #alpha_matting_background_threshold = 30, alpha_foreground_background_threshold = 200, - return output_list - - - - -def str2bool(v): - """ - Convert a string to a boolean. - - Supported true values: 'yes', 'true', 't', 'y', '1' - Supported false values: 'no', 'false', 'f', 'n', '0' - - Args: - v (str): String to convert. - - Returns: - bool: Converted boolean value. - - Raises: - argparse.ArgumentTypeError: If the value cannot be converted to boolean. - """ - if isinstance(v, bool): - return v - v_lower = v.lower() - if v_lower in ('yes', 'true', 't', 'y', '1'): - return True - elif v_lower in ('no', 'false', 'f', 'n', '0'): - return False + if return_tensor: + output_list.append(convert_image_to_tensor(resized_image).unsqueeze(1)) + else: + output_list.append(resized_image) + output_mask_list.append(resized_mask) + return output_list, output_mask_list + +def fit_image_into_canvas(ref_img, image_size, canvas_tf_bg =127.5, device ="cpu", full_frame = False, outpainting_dims = None, return_mask = False, return_image = False): + from shared.utils.utils import save_image + inpaint_color = canvas_tf_bg / 127.5 - 1 + + ref_width, ref_height = ref_img.size + if (ref_height, ref_width) == image_size and outpainting_dims == None: + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + canvas = torch.zeros_like(ref_img[:1]) if return_mask else None else: - raise argparse.ArgumentTypeError('Boolean value expected (True/False)') - - -import sys, time - -# Global variables to track download progress -_start_time = None -_last_time = None -_last_downloaded = 0 -_speed_history = [] -_update_interval = 0.5 # Update speed every 0.5 seconds - -def progress_hook(block_num, block_size, total_size, filename=None): - """ - Simple progress bar hook for urlretrieve - - Args: - block_num: Number of blocks downloaded so far - block_size: Size of each block in bytes - total_size: Total size of the file in bytes - filename: Name of the file being downloaded (optional) - """ - global _start_time, _last_time, _last_downloaded, _speed_history, _update_interval - - current_time = time.time() - downloaded = block_num * block_size - - # Initialize timing on first call - if _start_time is None or block_num == 0: - _start_time = current_time - _last_time = current_time - _last_downloaded = 0 - _speed_history = [] - - # Calculate download speed only at specified intervals - speed = 0 - if current_time - _last_time >= _update_interval: - if _last_time > 0: - current_speed = (downloaded - 
_last_downloaded) / (current_time - _last_time) - _speed_history.append(current_speed) - # Keep only last 5 speed measurements for smoothing - if len(_speed_history) > 5: - _speed_history.pop(0) - # Average the recent speeds for smoother display - speed = sum(_speed_history) / len(_speed_history) - - _last_time = current_time - _last_downloaded = downloaded - elif _speed_history: - # Use the last calculated average speed - speed = sum(_speed_history) / len(_speed_history) - # Format file sizes and speed - def format_bytes(bytes_val): - for unit in ['B', 'KB', 'MB', 'GB']: - if bytes_val < 1024: - return f"{bytes_val:.1f}{unit}" - bytes_val /= 1024 - return f"{bytes_val:.1f}TB" - - file_display = filename if filename else "Unknown file" - - if total_size <= 0: - # If total size is unknown, show downloaded bytes - speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" - line = f"\r{file_display}: {format_bytes(downloaded)}{speed_str}" - # Clear any trailing characters by padding with spaces - sys.stdout.write(line.ljust(80)) - sys.stdout.flush() - return - - downloaded = block_num * block_size - percent = min(100, (downloaded / total_size) * 100) - - # Create progress bar (40 characters wide to leave room for other info) - bar_length = 40 - filled = int(bar_length * percent / 100) - bar = '█' * filled + '░' * (bar_length - filled) - - # Format file sizes and speed - def format_bytes(bytes_val): - for unit in ['B', 'KB', 'MB', 'GB']: - if bytes_val < 1024: - return f"{bytes_val:.1f}{unit}" - bytes_val /= 1024 - return f"{bytes_val:.1f}TB" - - speed_str = f" @ {format_bytes(speed)}/s" if speed > 0 else "" - - # Display progress with filename first - line = f"\r{file_display}: [{bar}] {percent:.1f}% ({format_bytes(downloaded)}/{format_bytes(total_size)}){speed_str}" - # Clear any trailing characters by padding with spaces - sys.stdout.write(line.ljust(100)) - sys.stdout.flush() - - # Print newline when complete - if percent >= 100: - print() - -# Wrapper function to include filename in progress hook -def create_progress_hook(filename): - """Creates a progress hook with the filename included""" - global _start_time, _last_time, _last_downloaded, _speed_history - # Reset timing variables for new download - _start_time = None - _last_time = None - _last_downloaded = 0 - _speed_history = [] - - def hook(block_num, block_size, total_size): - return progress_hook(block_num, block_size, total_size, filename) - return hook + if outpainting_dims != None: + final_height, final_width = image_size + canvas_height, canvas_width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) + else: + canvas_height, canvas_width = image_size + if full_frame: + new_height = canvas_height + new_width = canvas_width + top = left = 0 + else: + # if fill_max and (canvas_height - new_height) < 16: + # new_height = canvas_height + # if fill_max and (canvas_width - new_width) < 16: + # new_width = canvas_width + scale = min(canvas_height / ref_height, canvas_width / ref_width) + new_height = int(ref_height * scale) + new_width = int(ref_width * scale) + top = (canvas_height - new_height) // 2 + left = (canvas_width - new_width) // 2 + ref_img = ref_img.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1) + if outpainting_dims != None: + canvas = torch.full((3, 1, final_height, final_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, margin_top + 
top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = ref_img + else: + canvas = torch.full((3, 1, canvas_height, canvas_width), inpaint_color, dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, top:top + new_height, left:left + new_width] = ref_img + ref_img = canvas + canvas = None + if return_mask: + if outpainting_dims != None: + canvas = torch.ones((1, 1, final_height, final_width), dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, margin_top + top:margin_top + top + new_height, margin_left + left:margin_left + left + new_width] = 0 + else: + canvas = torch.ones((1, 1, canvas_height, canvas_width), dtype= torch.float, device=device) # [-1, 1] + canvas[:, :, top:top + new_height, left:left + new_width] = 0 + canvas = canvas.to(device) + if return_image: + return convert_tensor_to_image(ref_img), canvas + + return ref_img.to(device), canvas + +def prepare_video_guide_and_mask( video_guides, video_masks, pre_video_guide, image_size, current_video_length = 81, latent_size = 4, any_mask = False, any_guide_padding = False, guide_inpaint_color = 127.5, keep_video_guide_frames = [], inject_frames = [], outpainting_dims = None, device ="cpu"): + src_videos, src_masks = [], [] + inpaint_color_compressed = guide_inpaint_color/127.5 - 1 + prepend_count = pre_video_guide.shape[1] if pre_video_guide is not None else 0 + for guide_no, (cur_video_guide, cur_video_mask) in enumerate(zip(video_guides, video_masks)): + src_video, src_mask = cur_video_guide, cur_video_mask + if pre_video_guide is not None: + src_video = pre_video_guide if src_video is None else torch.cat( [pre_video_guide, src_video], dim=1) + if any_mask: + src_mask = torch.zeros_like(pre_video_guide[:1]) if src_mask is None else torch.cat( [torch.zeros_like(pre_video_guide[:1]), src_mask], dim=1) + + if any_guide_padding: + if src_video is None: + src_video = torch.full( (3, current_video_length, *image_size ), inpaint_color_compressed, dtype = torch.float, device= device) + elif src_video.shape[1] < current_video_length: + src_video = torch.cat([src_video, torch.full( (3, current_video_length - src_video.shape[1], *src_video.shape[-2:] ), inpaint_color_compressed, dtype = src_video.dtype, device= src_video.device) ], dim=1) + elif src_video is not None: + new_num_frames = (src_video.shape[1] - 1) // latent_size * latent_size + 1 + src_video = src_video[:, :new_num_frames] + + if any_mask and src_video is not None: + if src_mask is None: + src_mask = torch.ones_like(src_video[:1]) + elif src_mask.shape[1] < src_video.shape[1]: + src_mask = torch.cat([src_mask, torch.full( (1, src_video.shape[1]- src_mask.shape[1], *src_mask.shape[-2:] ), 1, dtype = src_video.dtype, device= src_video.device) ], dim=1) + else: + src_mask = src_mask[:, :src_video.shape[1]] + + if src_video is not None : + for k, keep in enumerate(keep_video_guide_frames): + if not keep: + pos = prepend_count + k + src_video[:, pos:pos+1] = inpaint_color_compressed + if any_mask: src_mask[:, pos:pos+1] = 1 + + for k, frame in enumerate(inject_frames): + if frame != None: + pos = prepend_count + k + src_video[:, pos:pos+1], msk = fit_image_into_canvas(frame, image_size, guide_inpaint_color, device, True, outpainting_dims, return_mask= any_mask) + if any_mask: src_mask[:, pos:pos+1] = msk + src_videos.append(src_video) + src_masks.append(src_mask) + return src_videos, src_masks diff --git a/wgp.py b/wgp.py index a151252ba..f0b0d6a3f 100644 --- a/wgp.py +++ b/wgp.py @@ -1,4 +1,5 @@ import os +os.environ["GRADIO_LANG"] = 
"en" # # os.environ.pop("TORCH_LOGS", None) # make sure no env var is suppressing/overriding # os.environ["TORCH_LOGS"]= "recompiles" import torch._logging as tlog @@ -21,7 +22,9 @@ import importlib from shared.utils import notification_sound from shared.utils.loras_mutipliers import preparse_loras_multipliers, parse_loras_multipliers -from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, get_video_frame +from shared.utils.utils import convert_tensor_to_image, save_image, get_video_info, get_file_creation_date, convert_image_to_video, calculate_new_dimensions, convert_image_to_tensor, calculate_dimensions_and_resize_image, rescale_and_crop, get_video_frame, resize_and_remove_background, rgb_bw_to_rgba_mask +from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions +from shared.utils.utils import has_video_file_extension, has_image_file_extension from shared.utils.audio_video import extract_audio_tracks, combine_video_with_audio_tracks, combine_and_concatenate_video_with_audio_tracks, cleanup_temp_audio_files, save_video, save_image from shared.utils.audio_video import save_image_metadata, read_image_metadata from shared.match_archi import match_nvidia_architecture @@ -35,6 +38,7 @@ import asyncio import inspect from shared.utils import prompt_parser +from shared.utils.sliding_window_cleanup import cleanup_previous_video, should_cleanup_video, get_cleanup_status_text import base64 import io from PIL import Image @@ -49,6 +53,7 @@ from preprocessing.matanyone import app as matanyone_app from tqdm import tqdm import requests +from shared.gradio.gallery import AdvancedMediaGallery # import torch._dynamo as dynamo # dynamo.config.recompile_limit = 2000 # default is 256 @@ -58,9 +63,9 @@ AUTOSAVE_FILENAME = "queue.zip" PROMPT_VARS_MAX = 10 -target_mmgp_version = "3.5.10" -WanGP_version = "8.21" -settings_version = 2.27 +target_mmgp_version = "3.6.0" +WanGP_version = "8.9" +settings_version = 2.38 max_source_video_frames = 3000 prompt_enhancer_image_caption_model, prompt_enhancer_image_caption_processor, prompt_enhancer_llm_model, prompt_enhancer_llm_tokenizer = None, None, None, None @@ -185,10 +190,23 @@ def compute_sliding_window_no(current_video_length, sliding_window_size, discard return 1 + math.ceil(left_after_first_window / (sliding_window_size - discard_last_frames - reuse_frames)) +def clean_image_list(gradio_list): + if not isinstance(gradio_list, list): gradio_list = [gradio_list] + gradio_list = [ tup[0] if isinstance(tup, tuple) else tup for tup in gradio_list ] + + if any( not isinstance(image, (Image.Image, str)) for image in gradio_list): return None + if any( isinstance(image, str) and not has_image_file_extension(image) for image in gradio_list): return None + gradio_list = [ convert_image( Image.open(img) if isinstance(img, str) else img ) for img in gradio_list ] + return gradio_list + + + def process_prompt_and_add_tasks(state, model_choice): - + def ret(): + return gr.update(), gr.update() + if state.get("validate_success",0) != 1: - return + ret() state["validate_success"] = 0 model_filename = state["model_filename"] @@ -205,10 +223,10 @@ def process_prompt_and_add_tasks(state, model_choice): if inputs == None: gr.Warning("Internal state error: Could not retrieve inputs for the model.") queue = gen.get("queue", []) - return get_queue_table(queue) + return ret() model_def = 
get_model_def(model_type) model_handler = get_model_handler(model_type) - image_outputs = inputs["image_mode"] == 1 + image_outputs = inputs["image_mode"] > 0 any_steps_skipping = model_def.get("tea_cache", False) or model_def.get("mag_cache", False) model_type = get_base_model_type(model_type) inputs["model_filename"] = model_filename @@ -234,7 +252,7 @@ def process_prompt_and_add_tasks(state, model_choice): if len(temporal_upsampling) >0: prompt += ["Temporal Upsampling"] if has_image_file_extension(edit_video_source) and len(temporal_upsampling) > 0: gr.Info("Temporal Upsampling can not be used with an Image") - return + return ret() film_grain_intensity = inputs.get("film_grain_intensity",0) film_grain_saturation = inputs.get("film_grain_saturation",0.5) # if film_grain_intensity >0: prompt += [f"Film Grain: intensity={film_grain_intensity}, saturation={film_grain_saturation}"] @@ -250,7 +268,7 @@ def process_prompt_and_add_tasks(state, model_choice): else: if audio_source is None: gr.Info("You must provide a custom Audio") - return + return ret() prompt += ["Custom Audio"] repeat_generation == 1 @@ -260,32 +278,32 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("You must choose at least one Remux Method") else: gr.Info("You must choose at least one Post Processing Method") - return + return ret() inputs["prompt"] = ", ".join(prompt) add_video_task(**inputs) gen["prompts_max"] = 1 + gen.get("prompts_max",0) state["validate_success"] = 1 queue= gen.get("queue", []) - return update_queue_data(queue) + return ret() if hasattr(model_handler, "validate_generative_settings"): error = model_handler.validate_generative_settings(model_type, model_def, inputs) if error is not None and len(error) > 0: gr.Info(error) - return + return ret() if inputs.get("cfg_star_switch", 0) != 0 and inputs.get("apg_switch", 0) != 0: gr.Info("Adaptive Progressive Guidance and Classifier Free Guidance Star can not be set at the same time") - return + return ret() prompt = inputs["prompt"] if len(prompt) ==0: gr.Info("Prompt cannot be empty.") gen = get_gen_info(state) queue = gen.get("queue", []) - return get_queue_table(queue) + return ret() prompt, errors = prompt_parser.process_template(prompt) if len(errors) > 0: gr.Info("Error processing prompt template: " + errors) - return + return ret() model_filename = get_model_filename(model_type) prompts = prompt.replace("\r", "").split("\n") prompts = [prompt.strip() for prompt in prompts if len(prompt.strip())>0 and not prompt.startswith("#")] @@ -293,7 +311,7 @@ def process_prompt_and_add_tasks(state, model_choice): gr.Info("Prompt cannot be empty.") gen = get_gen_info(state) queue = gen.get("queue", []) - return get_queue_table(queue) + return ret() resolution = inputs["resolution"] width, height = resolution.split("x") @@ -335,30 +353,36 @@ def process_prompt_and_add_tasks(state, model_choice): model_switch_phase = inputs["model_switch_phase"] switch_threshold = inputs["switch_threshold"] switch_threshold2 = inputs["switch_threshold2"] - + multi_prompts_gen_type = inputs["multi_prompts_gen_type"] + video_guide_outpainting = inputs["video_guide_outpainting"] + + outpainting_dims = get_outpainting_dims(video_guide_outpainting) + + if server_config.get("fit_canvas", 0) == 2 and outpainting_dims is not None and any_letters(video_prompt_type, "VKF"): + gr.Info("Output Resolution Cropping will be not used for this Generation as it is not compatible with Video Outpainting") if len(loras_multipliers) > 0: _, _, errors = 
parse_loras_multipliers(loras_multipliers, len(activated_loras), num_inference_steps, nb_phases= guidance_phases) if len(errors) > 0: gr.Info(f"Error parsing Loras Multipliers: {errors}") - return + return ret() if guidance_phases == 3: if switch_threshold < switch_threshold2: gr.Info(f"Phase 1-2 Switch Noise Level ({switch_threshold}) should be Greater than Phase 2-3 Switch Noise Level ({switch_threshold2}). As a reminder, noise will gradually go down from 1000 to 0.") - return + return ret() else: model_switch_phase = 1 if not any_steps_skipping: skip_steps_cache_type = "" if not model_def.get("lock_inference_steps", False) and model_type in ["ltxv_13B"] and num_inference_steps < 20: gr.Info("The minimum number of steps should be 20") - return + return ret() if skip_steps_cache_type == "mag": if num_inference_steps > 50: gr.Info("Mag Cache maximum number of steps is 50") - return + return ret() - if image_mode == 1: + if image_mode > 0: audio_prompt_type = "" if "B" in audio_prompt_type or "X" in audio_prompt_type: @@ -366,49 +390,53 @@ def process_prompt_and_add_tasks(state, model_choice): speakers_bboxes, error = parse_speakers_locations(speakers_locations) if len(error) > 0: gr.Info(error) - return + return ret() if MMAudio_setting != 0 and server_config.get("mmaudio_enabled", 0) != 0 and video_length <16: #should depend on the architecture gr.Info("MMAudio can generate an Audio track only if the Video is at least 1s long") if "F" in video_prompt_type: if len(frames_positions.strip()) > 0: - positions = frames_positions.split(" ") + positions = frames_positions.replace(","," ").split(" ") for pos_str in positions: - if not is_integer(pos_str): - gr.Info(f"Invalid Frame Position '{pos_str}'") - return - pos = int(pos_str) - if pos <1 or pos > max_source_video_frames: - gr.Info(f"Invalid Frame Position Value'{pos_str}'") - return + if not pos_str in ["L", "l"] and len(pos_str)>0: + if not is_integer(pos_str): + gr.Info(f"Invalid Frame Position '{pos_str}'") + return ret() + pos = int(pos_str) + if pos <1 or pos > max_source_video_frames: + gr.Info(f"Invalid Frame Position Value'{pos_str}'") + return ret() else: frames_positions = None if audio_source is not None and MMAudio_setting != 0: gr.Info("MMAudio and Custom Audio Soundtrack can't not be used at the same time") - return + return ret() if len(filter_letters(image_prompt_type, "VLG")) > 0 and len(keep_frames_video_source) > 0: if not is_integer(keep_frames_video_source) or int(keep_frames_video_source) == 0: gr.Info("The number of frames to keep must be a non null integer") - return + return ret() else: keep_frames_video_source = "" + if image_outputs: + image_prompt_type = image_prompt_type.replace("V", "").replace("L", "") + if "V" in image_prompt_type: if video_source == None: gr.Info("You must provide a Source Video file to continue") - return + return ret() else: video_source = None if "A" in audio_prompt_type: if audio_guide == None: gr.Info("You must provide an Audio Source") - return + return ret() if "B" in audio_prompt_type: if audio_guide2 == None: gr.Info("You must provide a second Audio Source") - return + return ret() else: audio_guide2 = None else: @@ -422,24 +450,23 @@ def process_prompt_and_add_tasks(state, model_choice): if model_def.get("one_image_ref_needed", False): if image_refs == None : gr.Info("You must provide an Image Reference") - return + return ret() if len(image_refs) > 1: gr.Info("Only one Image Reference (a person) is supported for the moment by this model") - return + return ret() if 
model_def.get("at_least_one_image_ref_needed", False): if image_refs == None : gr.Info("You must provide at least one Image Reference") - return + return ret() if "I" in video_prompt_type: if image_refs == None or len(image_refs) == 0: - gr.Info("You must provide at least one Refererence Image") - return - if any(isinstance(image[0], str) for image in image_refs) : + gr.Info("You must provide at least one Reference Image") + return ret() + image_refs = clean_image_list(image_refs) + if image_refs == None : gr.Info("A Reference Image should be an Image") - return - if isinstance(image_refs, list): - image_refs = [ convert_image(tup[0]) for tup in image_refs ] + return ret() else: image_refs = None @@ -447,35 +474,36 @@ def process_prompt_and_add_tasks(state, model_choice): if image_outputs: if image_guide is None: gr.Info("You must provide a Control Image") - return + return ret() else: if video_guide is None: gr.Info("You must provide a Control Video") - return + return ret() if "A" in video_prompt_type and not "U" in video_prompt_type: if image_outputs: if image_mask is None: gr.Info("You must provide a Image Mask") - return + return ret() else: if video_mask is None: gr.Info("You must provide a Video Mask") - return + return ret() else: video_mask = None image_mask = None if "G" in video_prompt_type: - gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start a Step no {int(num_inference_steps * (1. - denoising_strength))} ") + if denoising_strength < 1.: + gr.Info(f"With Denoising Strength {denoising_strength:.1f}, denoising will start at Step no {int(round(num_inference_steps * (1. - denoising_strength),4))} ") else: denoising_strength = 1.0 if len(keep_frames_video_guide) > 0 and model_type in ["ltxv_13B"]: gr.Info("Keep Frames for Control Video is not supported with LTX Video") - return + return ret() _, error = parse_keep_frames_video_guide(keep_frames_video_guide, video_length) if len(error) > 0: gr.Info(f"Invalid Keep Frames property: {error}") - return + return ret() else: video_guide = None image_guide = None @@ -495,69 +523,76 @@ def process_prompt_and_add_tasks(state, model_choice): if "S" in image_prompt_type: if image_start == None or isinstance(image_start, list) and len(image_start) == 0: gr.Info("You must provide a Start Image") - return - if not isinstance(image_start, list): - image_start = [image_start] - if not all( not isinstance(img[0], str) for img in image_start) : + return ret() + image_start = clean_image_list(image_start) + if image_start == None : gr.Info("Start Image should be an Image") - return - image_start = [ convert_image(tup[0]) for tup in image_start ] + return ret() + if multi_prompts_gen_type == 1 and len(image_start) > 1: + gr.Info("Only one Start Image is supported") + return ret() else: image_start = None + if not any_letters(image_prompt_type, "SVL"): + image_prompt_type = image_prompt_type.replace("E", "") if "E" in image_prompt_type: if image_end == None or isinstance(image_end, list) and len(image_end) == 0: gr.Info("You must provide an End Image") - return - if not isinstance(image_end, list): - image_end = [image_end] - if not all( not isinstance(img[0], str) for img in image_end) : + return ret() + image_end = clean_image_list(image_end) + if image_end == None : gr.Info("End Image should be an Image") - return - if len(image_start) != len(image_end): - gr.Info("The number of Start and End Images should be the same ") - return - image_end = [ convert_image(tup[0]) for tup in image_end ] + return ret() + if 
multi_prompts_gen_type == 0: + if video_source is not None: + if len(image_end)> 1: + gr.Info("If a Video is to be continued and the option 'Each Text Prompt Will create a new generated Video' is set, there can be only one End Image") + return ret() + elif len(image_start or []) != len(image_end or []): + gr.Info("The number of Start and End Images should be the same when the option 'Each Text Prompt Will create a new generated Video'") + return ret() else: image_end = None if test_any_sliding_window(model_type) and image_mode == 0: if video_length > sliding_window_size: - full_video_length = video_length if video_source is None else video_length + sliding_window_overlap + if model_type in ["t2v"] and not "G" in video_prompt_type : + gr.Info(f"You have requested to Generate Sliding Windows with a Text to Video model. Unless you use the Video to Video feature this is useless as a t2v model doesn't see past frames and it will generate the same video in each new window.") + return ret() + full_video_length = video_length if video_source is None else video_length + sliding_window_overlap -1 extra = "" if full_video_length == video_length else f" including {sliding_window_overlap} added for Video Continuation" no_windows = compute_sliding_window_no(full_video_length, sliding_window_size, sliding_window_discard_last_frames, sliding_window_overlap) - gr.Info(f"The Number of Frames to generate ({video_length}{extra}) is greater than the Sliding Window Size ({sliding_window_size}), {no_windows} Windows will be generated") - + cleanup_status = get_cleanup_status_text(server_config) + gr.Info(f"The Number of Frames to generate ({video_length}{extra}) is greater than the Sliding Window Size ({sliding_window_size}), {no_windows} Windows will be generated. Video Cleanup: {cleanup_status}") if "recam" in model_filename: - if video_source == None: - gr.Info("You must provide a Source Video") - return - - frames = get_resampled_video(video_source, 0, 81, get_computed_fps(force_fps, model_type , video_guide, video_source )) + if video_guide == None: + gr.Info("You must provide a Control Video") + return ret() + computed_fps = get_computed_fps(force_fps, model_type , video_guide, video_source ) + frames = get_resampled_video(video_guide, 0, 81, computed_fps) if len(frames)<81: - gr.Info("Recammaster source video should be at least 81 frames once the resampling at 16 fps has been done") - return - - + gr.Info(f"Recammaster Control video should be at least 81 frames once the resampling at {computed_fps} fps has been done") + return ret() if "hunyuan_custom_custom_edit" in model_filename: if len(keep_frames_video_guide) > 0: gr.Info("Filtering Frames with this model is not supported") - return + return ret() if inputs["multi_prompts_gen_type"] != 0: if image_start != None and len(image_start) > 1: gr.Info("Only one Start Image must be provided if multiple prompts are used for different windows") - return + return ret() - if image_end != None and len(image_end) > 1: - gr.Info("Only one End Image must be provided if multiple prompts are used for different windows") - return + # if image_end != None and len(image_end) > 1: + # gr.Info("Only one End Image must be provided if multiple prompts are used for different windows") + # return override_inputs = { "image_start": image_start[0] if image_start !=None and len(image_start) > 0 else None, - "image_end": image_end[0] if image_end !=None and len(image_end) > 0 else None, + "image_end": image_end, #[0] if image_end !=None and len(image_end) > 0 else None, 
"image_refs": image_refs, "audio_guide": audio_guide, "audio_guide2": audio_guide2, @@ -597,7 +632,7 @@ def process_prompt_and_add_tasks(state, model_choice): if len(prompts) >= len(image_start): if len(prompts) % len(image_start) != 0: gr.Info("If there are more text prompts than input images the number of text prompts should be dividable by the number of images") - return + return ret() rep = len(prompts) // len(image_start) new_image_start = [] new_image_end = [] @@ -611,7 +646,7 @@ def process_prompt_and_add_tasks(state, model_choice): else: if len(image_start) % len(prompts) !=0: gr.Info("If there are more input images than text prompts the number of images should be dividable by the number of text prompts") - return + return ret() rep = len(image_start) // len(prompts) new_prompts = [] for i, _ in enumerate(image_start): @@ -633,19 +668,23 @@ def process_prompt_and_add_tasks(state, model_choice): override_inputs["prompt"] = single_prompt inputs.update(override_inputs) add_video_task(**inputs) + new_prompts_count = len(prompts) else: + new_prompts_count = 1 override_inputs["prompt"] = "\n".join(prompts) inputs.update(override_inputs) add_video_task(**inputs) - - gen["prompts_max"] = len(prompts) + gen.get("prompts_max",0) + new_prompts_count += gen.get("prompts_max",0) + gen["prompts_max"] = new_prompts_count state["validate_success"] = 1 queue= gen.get("queue", []) - return update_queue_data(queue) + first_time_in_queue = state.get("first_time_in_queue", True) + state["first_time_in_queue"] = True + return update_queue_data(queue, first_time_in_queue), gr.update(open=True) if new_prompts_count > 1 else gr.update() def get_preview_images(inputs): - inputs_to_query = ["image_start", "image_end", "video_source", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ] - labels = ["Start Image", "End Image", "Video Source", "Video Guide", "Image Guide", "Video Mask", "Image Mask", "Image Reference"] + inputs_to_query = ["image_start", "video_source", "image_end", "video_guide", "image_guide", "video_mask", "image_mask", "image_refs" ] + labels = ["Start Image", "Video Source", "End Image", "Video Guide", "Image Guide", "Video Mask", "Image Mask", "Image Reference"] start_image_data = None start_image_labels = [] end_image_data = None @@ -695,7 +734,6 @@ def add_video_task(**inputs): "start_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in start_image_data] if start_image_data != None else None, "end_image_data_base64": [pil_to_base64_uri(img, format="jpeg", quality=70) for img in end_image_data] if end_image_data != None else None }) - return update_queue_data(queue) def update_task_thumbnails(task, inputs): start_image_data, end_image_data, start_labels, end_labels = get_preview_images(inputs) @@ -1312,14 +1350,12 @@ def get_queue_table(queue): "✖" ]) return data -def update_queue_data(queue): +def update_queue_data(queue, first_time_in_queue =False): update_global_queue_ref(queue) data = get_queue_table(queue) - if len(data) == 0: - return gr.DataFrame(visible=False) - else: - return gr.DataFrame(value=data, visible= True) + return gr.DataFrame(value=data) + def create_html_progress_bar(percentage=0.0, text="Idle", is_idle=True): bar_class = "progress-bar-custom idle" if is_idle else "progress-bar-custom" @@ -1354,6 +1390,26 @@ def _parse_args(): help="save proprocessed audio track with extract speakers for debugging or editing" ) + parser.add_argument( + "--debug-gen-form", + action="store_true", + help="View form generation / refresh 
time" + ) + + parser.add_argument( + "--betatest", + action="store_true", + help="test unreleased features" + ) + + parser.add_argument( + "--vram-safety-coefficient", + type=float, + default=0.8, + help="max VRAM (between 0 and 1) that should be allocated to preloaded models" + ) + + parser.add_argument( "--share", action="store_true", @@ -1763,7 +1819,8 @@ def get_lora_dir(model_type): "vae_config": 0, "profile" : profile_type.LowRAM_LowVRAM, "preload_model_policy": [], - "UI_theme": "default" + "UI_theme": "default", + "sliding_window_keep_only_longest": False } with open(server_config_filename, "w", encoding="utf-8") as writer: @@ -1895,7 +1952,8 @@ def get_model_min_frames_and_step(model_type): mode_def = get_model_def(model_type) frames_minimum = mode_def.get("frames_minimum", 5) frames_steps = mode_def.get("frames_steps", 4) - return frames_minimum, frames_steps + latent_size = mode_def.get("latent_size", frames_steps) + return frames_minimum, frames_steps, latent_size def get_model_fps(model_type): mode_def = get_model_def(model_type) @@ -1959,8 +2017,10 @@ def get_model_recursive_prop(model_type, prop = "URLs", sub_prop_name = None, re raise Exception(f"Unknown model type '{model_type}'") -def get_model_filename(model_type, quantization ="int8", dtype_policy = "", module_type = None, submodel_no = 1, stack=[]): - if module_type is not None: +def get_model_filename(model_type, quantization ="int8", dtype_policy = "", module_type = None, submodel_no = 1, URLs = None, stack=[]): + if URLs is not None: + pass + elif module_type is not None: base_model_type = get_base_model_type(model_type) # model_type_handler = model_types_handlers[base_model_type] # modules_files = model_type_handler.query_modules_files() if hasattr(model_type_handler, "query_modules_files") else {} @@ -1987,7 +2047,8 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", modu if len(stack) > 10: raise Exception(f"Circular Reference in Model {key_name} dependencies: {stack}") return get_model_filename(URLs, quantization=quantization, dtype_policy=dtype_policy, submodel_no = submodel_no, stack = stack + [URLs]) - choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ] + # choices = [ ("ckpts/" + os.path.basename(path) if path.startswith("http") else path) for path in URLs ] + choices = URLs if len(quantization) == 0: quantization = "bf16" @@ -1997,13 +2058,13 @@ def get_model_filename(model_type, quantization ="int8", dtype_policy = "", modu raw_filename = choices[0] else: if quantization in ("int8", "fp8"): - sub_choices = [ name for name in choices if quantization in name or quantization.upper() in name] + sub_choices = [ name for name in choices if quantization in os.path.basename(name) or quantization.upper() in os.path.basename(name)] else: - sub_choices = [ name for name in choices if "quanto" not in name] + sub_choices = [ name for name in choices if "quanto" not in os.path.basename(name)] if len(sub_choices) > 0: dtype_str = "fp16" if dtype == torch.float16 else "bf16" - new_sub_choices = [ name for name in sub_choices if dtype_str in name or dtype_str.upper() in name] + new_sub_choices = [ name for name in sub_choices if dtype_str in os.path.basename(name) or dtype_str.upper() in os.path.basename(name)] sub_choices = new_sub_choices if len(new_sub_choices) > 0 else sub_choices raw_filename = sub_choices[0] else: @@ -2045,8 +2106,6 @@ def fix_settings(model_type, ui_defaults): if image_prompt_type != None : if not 
isinstance(image_prompt_type, str): image_prompt_type = "S" if image_prompt_type == 0 else "SE" - # if model_type == "flf2v_720p" and not "E" in image_prompt_type: - # image_prompt_type = "SE" if settings_version <= 2: image_prompt_type = image_prompt_type.replace("G","") ui_defaults["image_prompt_type"] = image_prompt_type @@ -2064,12 +2123,13 @@ def fix_settings(model_type, ui_defaults): audio_prompt_type ="A" ui_defaults["audio_prompt_type"] = audio_prompt_type + if settings_version < 2.35 and any_audio_track(base_model_type): + audio_prompt_type = audio_prompt_type or "" + audio_prompt_type += "V" + ui_defaults["audio_prompt_type"] = audio_prompt_type video_prompt_type = ui_defaults.get("video_prompt_type", "") - any_reference_image = model_def.get("reference_image", False) - if base_model_type in ["hunyuan_custom", "hunyuan_custom_edit", "hunyuan_custom_audio", "hunyuan_avatar", "phantom_14B", "phantom_1.3B"] or any_reference_image: - if not "I" in video_prompt_type: # workaround for settings corruption - video_prompt_type += "I" + if base_model_type in ["hunyuan"]: video_prompt_type = video_prompt_type.replace("I", "") @@ -2108,10 +2168,28 @@ def fix_settings(model_type, ui_defaults): del ui_defaults["tea_cache_start_step_perc"] ui_defaults["skip_steps_start_step_perc"] = tea_cache_start_step_perc + image_prompt_type = ui_defaults.get("image_prompt_type", "") + if len(image_prompt_type) > 0: + image_prompt_types_allowed = model_def.get("image_prompt_types_allowed","") + image_prompt_type = filter_letters(image_prompt_type, image_prompt_types_allowed) + ui_defaults["image_prompt_type"] = image_prompt_type + + video_prompt_type = ui_defaults.get("video_prompt_type", "") + image_ref_choices_list = model_def.get("image_ref_choices", {}).get("choices", []) + if model_def.get("guide_custom_choices", None) is None: + if len(image_ref_choices_list)==0: + video_prompt_type = del_in_sequence(video_prompt_type, "IK") + else: + first_choice = image_ref_choices_list[0][1] + if "I" in first_choice and not "I" in video_prompt_type: video_prompt_type += "I" + if len(image_ref_choices_list)==1 and "K" in first_choice and not "K" in video_prompt_type: video_prompt_type += "K" + ui_defaults["video_prompt_type"] = video_prompt_type + model_handler = get_model_handler(base_model_type) if hasattr(model_handler, "fix_settings"): model_handler.fix_settings(base_model_type, settings_version, model_def, ui_defaults) + def get_default_settings(model_type): def get_default_prompt(i2v): if i2v: @@ -2143,7 +2221,8 @@ def get_default_prompt(i2v): "slg_switch": 0, "slg_layers": [9], "slg_start_perc": 10, - "slg_end_perc": 90 + "slg_end_perc": 90, + "audio_prompt_type": "V", } model_handler = get_model_handler(model_type) model_handler.update_default_settings(base_model_type, model_def, ui_defaults) @@ -2291,7 +2370,7 @@ def init_model_def(model_type, model_def): lock_ui_compile = True -def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_module = False, filter = None, no_fp16_main_model = True ): +def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_module = False, filter = None, no_fp16_main_model = True, module_source_no = 1): model_def = get_model_def(model_type) # To save module and quantized modules # 1) set Transformer Model Quantization Type to 16 bits @@ -2302,10 +2381,10 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod if model_def == None: return if is_module: url_key = "modules" - source_key = "module_source" + source_key = 
"module_source" if module_source_no <=1 else "module_source2" else: url_key = "URLs" if submodel_no <=1 else "URLs" + str(submodel_no) - source_key = "source" + source_key = "source" if submodel_no <=1 else "source2" URLs= model_def.get(url_key, None) if URLs is None: return if isinstance(URLs, str): @@ -2320,6 +2399,9 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod print("Target Module files are missing") return URLs= URLs[0] + if isinstance(URLs, dict): + url_dict_key = "URLs" if module_source_no ==1 else "URLs2" + URLs = URLs[url_dict_key] for url in URLs: if "quanto" not in url and dtypestr in url: model_filename = os.path.basename(url) @@ -2353,8 +2435,12 @@ def save_model(model, model_type, dtype, config_file, submodel_no = 1, is_mod elif not os.path.isfile(quanto_filename): offload.save_model(model, quanto_filename, config_file_path=config_file, do_quantize= True, filter_sd=filter) print(f"New quantized file '{quanto_filename}' had been created for finetune Id '{model_type}'.") - model_def[url_key][0].append(quanto_filename) - saved_finetune_def["model"][url_key][0].append(quanto_filename) + if isinstance(model_def[url_key][0],dict): + model_def[url_key][0][url_dict_key].append(quanto_filename) + saved_finetune_def["model"][url_key][0][url_dict_key].append(quanto_filename) + else: + model_def[url_key][0].append(quanto_filename) + saved_finetune_def["model"][url_key][0].append(quanto_filename) update_model_def = True if update_model_def: with open(finetune_file, "w", encoding="utf-8") as writer: @@ -2409,6 +2495,14 @@ def preprocessor_wrapper(sd): return preprocessor_wrapper +def get_local_model_filename(model_filename): + if model_filename.startswith("http"): + local_model_filename = os.path.join("ckpts", os.path.basename(model_filename)) + else: + local_model_filename = model_filename + return local_model_filename + + def process_files_def(repoId, sourceFolderList, fileList): targetRoot = "ckpts/" @@ -2434,7 +2528,8 @@ def download_mmaudio(): } process_files_def(**enhancer_def) -def download_models(model_filename = None, model_type= None, module_type = None, submodel_no = 1): +download_shared_done = False +def download_models(model_filename = None, model_type= None, module_type = False, submodel_no = 1): def computeList(filename): if filename == None: return [] @@ -2445,15 +2540,16 @@ def computeList(filename): from urllib.request import urlretrieve - from shared.utils.utils import create_progress_hook + from shared.utils.download import create_progress_hook shared_def = { "repoId" : "DeepBeepMeep/Wan2.1", - "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "pyannote", "det_align", "" ], + "sourceFolderList" : [ "pose", "scribble", "flow", "depth", "mask", "wav2vec", "chinese-wav2vec2-base", "roformer", "pyannote", "det_align", "" ], "fileList" : [ ["dw-ll_ucoco_384.onnx", "yolox_l.onnx"],["netG_A_latest.pth"], ["raft-things.pth"], - ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors"], + ["depth_anything_v2_vitl.pth","depth_anything_v2_vitb.pth"], ["sam_vit_h_4b8939_fp16.safetensors", "model.safetensors", "config.json"], ["config.json", "feature_extractor_config.json", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"], ["config.json", "pytorch_model.bin", "preprocessor_config.json"], + ["model_bs_roformer_ep_317_sdr_12.9755.ckpt", "model_bs_roformer_ep_317_sdr_12.9755.yaml", 
"download_checks.json"], ["pyannote_model_wespeaker-voxceleb-resnet34-LM.bin", "pytorch_model_segmentation-3.0.bin"], ["detface.pt"], [ "flownet.pkl" ] ] } process_files_def(**shared_def) @@ -2476,6 +2572,9 @@ def computeList(filename): process_files_def(**enhancer_def) download_mmaudio() + global download_shared_done + download_shared_done = True + if model_filename is None: return def download_file(url,filename): @@ -2501,37 +2600,25 @@ def download_file(url,filename): base_model_type = get_base_model_type(model_type) model_def = get_model_def(model_type) - source = model_def.get("source", None) - module_source = model_def.get("module_source", None) + any_source = ("source2" if submodel_no ==2 else "source") in model_def + any_module_source = ("module_source2" if submodel_no ==2 else "module_source") in model_def model_type_handler = model_types_handlers[base_model_type] - - if source is not None and module_type is None or module_source is not None and module_type is not None: + local_model_filename = get_local_model_filename(model_filename) + + if any_source and not module_type or any_module_source and module_type: model_filename = None else: - if not os.path.isfile(model_filename): - if module_type is not None: - key_name = "modules" - URLs = module_type - if isinstance(module_type, str): - URLs = get_model_recursive_prop(module_type, key_name, sub_prop_name="_list", return_list= False) - else: - key_name = "URLs" if submodel_no <= 1 else f"URLs{submodel_no}" - URLs = get_model_recursive_prop(model_type, key_name, return_list= False) - if isinstance(URLs, str): - raise Exception("Missing model " + URLs) - use_url = model_filename - for url in URLs: - if os.path.basename(model_filename) in url: - use_url = url - break + if not os.path.isfile(local_model_filename): + url = model_filename + if not url.startswith("http"): - raise Exception(f"Model '{model_filename}' in field '{key_name}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.") + raise Exception(f"Model '{model_filename}' was not found locally and no URL was provided to download it. Please add an URL in the model definition file.") try: - download_file(use_url, model_filename) + download_file(url, local_model_filename) except Exception as e: - if os.path.isfile(model_filename): os.remove(model_filename) - raise Exception(f"{key_name} '{use_url}' is invalid for Model '{model_filename}' : {str(e)}'") - + if os.path.isfile(local_model_filename): os.remove(local_model_filename) + raise Exception(f"'{url}' is invalid for Model '{local_model_filename}' : {str(e)}'") + if module_type: return model_filename = None preload_URLs = get_model_recursive_prop(model_type, "preload_URLs", return_list= True) @@ -2557,6 +2644,7 @@ def download_file(url,filename): except Exception as e: if os.path.isfile(filename): os.remove(filename) raise Exception(f"Lora URL '{url}' is invalid: {str(e)}'") + if module_type: return model_files = model_type_handler.query_model_files(computeList, base_model_type, model_filename, text_encoder_quantization) if not isinstance(model_files, list): model_files = [model_files] for one_repo in model_files: @@ -2743,7 +2831,8 @@ def load_models(model_type, override_profile = -1): model_filename2 = get_model_filename(model_type=model_type, quantization= "" if save_quantized else transformer_quantization, dtype_policy = transformer_dtype_policy, submodel_no=2) # !!!! 
else: model_filename2 = None - modules = get_model_recursive_prop(model_type, "modules", return_list= True) + modules = get_model_recursive_prop(model_type, "modules", return_list= True) + modules = [get_model_recursive_prop(module, "modules", sub_prop_name ="_list", return_list= True) if isinstance(module, str) else module for module in modules ] if save_quantized and "quanto" in model_filename: save_quantized = False print("Need to provide a non quantized model to create a quantized model to be saved") @@ -2760,6 +2849,7 @@ def load_models(model_type, override_profile = -1): transformer_dtype = torch.bfloat16 if "bf16" in model_filename or "BF16" in model_filename else transformer_dtype transformer_dtype = torch.float16 if "fp16" in model_filename or"FP16" in model_filename else transformer_dtype perc_reserved_mem_max = args.perc_reserved_mem_max + vram_safety_coefficient = args.vram_safety_coefficient model_file_list = [model_filename] model_type_list = [model_type] module_type_list = [None] @@ -2767,27 +2857,40 @@ def load_models(model_type, override_profile = -1): if model_filename2 != None: model_file_list += [model_filename2] model_type_list += [model_type] - module_type_list += [None] + module_type_list += [False] model_submodel_no_list += [2] for module_type in modules: - model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, module_type= module_type)) - model_type_list.append(model_type) - module_type_list.append(module_type) - model_submodel_no_list.append(0) + if isinstance(module_type,dict): + URLs1 = module_type.get("URLs", None) + if URLs1 is None: raise Exception(f"No URLs defined for Module {module_type}") + model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, URLs = URLs1)) + URLs2 = module_type.get("URLs2", None) + if URLs2 is None: raise Exception(f"No URL2s defined for Module {module_type}") + model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, URLs = URLs2)) + model_type_list += [model_type] * 2 + module_type_list += [True] * 2 + model_submodel_no_list += [1,2] + else: + model_file_list.append(get_model_filename(model_type, transformer_quantization, transformer_dtype, module_type= module_type)) + model_type_list.append(model_type) + module_type_list.append(True) + model_submodel_no_list.append(0) + local_model_file_list= [] for filename, file_model_type, file_module_type, submodel_no in zip(model_file_list, model_type_list, module_type_list, model_submodel_no_list): download_models(filename, file_model_type, file_module_type, submodel_no) + local_model_file_list.append( get_local_model_filename(filename) ) VAE_dtype = torch.float16 if server_config.get("vae_precision","16") == "16" else torch.float mixed_precision_transformer = server_config.get("mixed_precision","0") == "1" transformer_type = None - for submodel_no, filename in zip(model_submodel_no_list, model_file_list): - if submodel_no>=1: + for module_type, filename in zip(module_type_list, local_model_file_list): + if module_type is None: print(f"Loading Model '{filename}' ...") else: print(f"Loading Module '{filename}' ...") wan_model, pipe = model_types_handlers[base_model_type].load_model( - model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, text_encoder_quantization = text_encoder_quantization, - dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = 
save_quantized) + local_model_file_list, model_type, base_model_type, model_def, quantizeTransformer = quantizeTransformer, text_encoder_quantization = text_encoder_quantization, + dtype = transformer_dtype, VAE_dtype = VAE_dtype, mixed_precision_transformer = mixed_precision_transformer, save_quantized = save_quantized, submodel_no_list = model_submodel_no_list, ) kwargs = {} profile = init_pipe(pipe, kwargs, override_profile) @@ -2796,7 +2899,7 @@ def load_models(model_type, override_profile = -1): loras_transformer = ["transformer"] if "transformer2" in pipe: loras_transformer += ["transformer2"] - offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , convertWeightsFloatTo = transformer_dtype, **kwargs) + offloadobj = offload.profile(pipe, profile_no= profile, compile = compile, quantizeTransformer = False, loras = loras_transformer, coTenantsMap= {}, perc_reserved_mem_max = perc_reserved_mem_max , vram_safety_coefficient = vram_safety_coefficient , convertWeightsFloatTo = transformer_dtype, **kwargs) if len(args.gpu) > 0: torch.set_default_device(args.gpu) transformer_type = model_type @@ -2825,7 +2928,7 @@ def generate_header(model_type, compile, attention_mode): description_container = [""] get_model_name(model_type, description_container) - model_filename = get_model_filename(model_type, transformer_quantization, transformer_dtype_policy) or "" + model_filename = os.path.basename(get_model_filename(model_type, transformer_quantization, transformer_dtype_policy)) or "" description = description_container[0] header = f"
{description}
" overridden_attention = get_overridden_attention(model_type) @@ -2891,6 +2994,7 @@ def apply_changes( state, image_output_codec_choice = None, audio_output_codec_choice = None, last_resolution_choice = None, + sliding_window_keep_only_longest_choice = False, ): if args.lock_config: return "
Config Locked
",*[gr.update()]*4 @@ -2928,6 +3032,7 @@ def apply_changes( state, "video_output_codec" : video_output_codec_choice, "image_output_codec" : image_output_codec_choice, "audio_output_codec" : audio_output_codec_choice, + "sliding_window_keep_only_longest" : sliding_window_keep_only_longest_choice, "last_model_type" : state["model_type"], "last_model_per_family": state["last_model_per_family"], "last_advanced_choice": state["advanced"], @@ -3297,7 +3402,8 @@ def check(src, cond): if not all_letters(src, pos): return False if neg is not None and any_letters(src, neg): return False return True - map_video_prompt = {"V" : "Control Video", ("VA", "U") : "Mask Video", "I" : "Reference Images"} + image_outputs = configs.get("image_mode",0) > 0 + map_video_prompt = {"V" : "Control Image" if image_outputs else "Control Video", ("VA", "U") : "Mask Image" if image_outputs else "Mask Video", "I" : "Reference Images"} map_image_prompt = {"V" : "Source Video", "L" : "Last Video", "S" : "Start Image", "E" : "End Image"} map_audio_prompt = {"A" : "Audio Source", "B" : "Audio Source #2"} video_other_prompts = [ v for s,v in map_image_prompt.items() if all_letters(video_image_prompt_type,s)] \ @@ -3348,6 +3454,7 @@ def check(src, cond): if multiple_submodels: video_guidance_scale += f" + Model Switch at {video_switch_threshold if video_model_switch_phase ==1 else video_switch_threshold2}" video_flow_shift = configs.get("flow_shift", None) + if image_outputs: video_flow_shift = None video_video_guide_outpainting = configs.get("video_guide_outpainting", "") video_outpainting = "" if len(video_video_guide_outpainting) > 0 and not video_video_guide_outpainting.startswith("#") \ @@ -3369,7 +3476,7 @@ def check(src, cond): if len(video_other_prompts) >0 : values += [video_other_prompts] labels += ["Other Prompts"] - if len(video_outpainting) >0 and any_letters(video_image_prompt_type, "VFK"): + if len(video_outpainting) >0: values += [video_outpainting] labels += ["Outpainting"] video_sample_solver = configs.get("sample_solver", "") @@ -3390,7 +3497,23 @@ def check(src, cond): if video_apg_switch is not None and video_apg_switch != 0: values += ["on"] labels += ["APG"] - + + control_net_weight = "" + if test_vace_module(video_model_type): + video_control_net_weight = configs.get("control_net_weight", 1) + if len(filter_letters(video_video_prompt_type, video_guide_processes))> 1: + video_control_net_weight2 = configs.get("control_net_weight2", 1) + control_net_weight = f"Vace #1={video_control_net_weight}, Vace #2={video_control_net_weight2}" + else: + control_net_weight = f"Vace={video_control_net_weight}" + control_net_weight_alt_name = model_def.get("control_net_weight_alt_name", "") + if len(control_net_weight_alt_name) >0: + if len(control_net_weight): control_net_weight += ", " + control_net_weight += control_net_weight_alt_name + "=" + str(configs.get("control_net_alt", 1)) + if len(control_net_weight) > 0: + values += [control_net_weight] + labels += ["Control Net Weights"] + video_skip_steps_cache_type = configs.get("skip_steps_cache_type", "") video_skip_steps_multiplier = configs.get("skip_steps_multiplier", 0) video_skip_steps_cache_start_step_perc = configs.get("skip_steps_start_step_perc", 0) @@ -3436,10 +3559,17 @@ def convert_image(image): from PIL import ImageOps from typing import cast + if isinstance(image, str): + image = Image.open(image) image = image.convert('RGB') return cast(Image, ImageOps.exif_transpose(image)) def get_resampled_video(video_in, start_frame, max_frames, target_fps, 
bridge='torch'): + if isinstance(video_in, str) and has_image_file_extension(video_in): + video_in = Image.open(video_in) + if isinstance(video_in, Image.Image): + return torch.from_numpy(np.array(video_in).astype(np.uint8)).unsqueeze(0) + from shared.utils.utils import resample import decord @@ -3455,6 +3585,48 @@ def get_resampled_video(video_in, start_frame, max_frames, target_fps, bridge='t # print(f"frame nos: {frame_nos}") return frames_list +# def get_resampled_video(video_in, start_frame, max_frames, target_fps): +# from torchvision.io import VideoReader +# import torch +# from shared.utils.utils import resample + +# vr = VideoReader(video_in, "video") +# meta = vr.get_metadata()["video"] + +# fps = round(float(meta["fps"][0])) +# duration_s = float(meta["duration"][0]) +# num_src_frames = int(round(duration_s * fps)) # robust length estimate + +# if max_frames < 0: +# max_frames = max(int(num_src_frames / fps * target_fps + max_frames), 0) + +# frame_nos = resample( +# fps, num_src_frames, +# max_target_frames_count=max_frames, +# target_fps=target_fps, +# start_target_frame=start_frame +# ) +# if len(frame_nos) == 0: +# return torch.empty((0,)) # nothing to return + +# target_ts = [i / fps for i in frame_nos] + +# # Read forward once, grabbing frames when we pass each target timestamp +# frames = [] +# vr.seek(target_ts[0]) +# idx = 0 +# tol = 0.5 / fps # half-frame tolerance +# for frame in vr: +# t = float(frame["pts"]) # seconds +# if idx < len(target_ts) and t + tol >= target_ts[idx]: +# frames.append(frame["data"].permute(1,2,0)) # Tensor [H, W, C] +# idx += 1 +# if idx >= len(target_ts): +# break + +# return frames + + def get_preprocessor(process_type, inpaint_color): if process_type=="pose": from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator @@ -3519,19 +3691,22 @@ def get_preprocessor(process_type, inpaint_color): def process_images_multithread(image_processor, items, process_type, wrap_in_list = True, max_workers: int = os.cpu_count()/ 2) : if not items: return [] - max_workers = 11 + import concurrent.futures start_time = time.time() # print(f"Preprocessus:{process_type} started") if process_type in ["prephase", "upsample"]: if wrap_in_list : items = [ [img] for img in items] - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)} - results = [None] * len(items) - for future in concurrent.futures.as_completed(futures): - idx = futures[future] - results[idx] = future.result() + if max_workers == 1: + results = [image_processor(img) for img in items] + else: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(image_processor, img): idx for idx, img in enumerate(items)} + results = [None] * len(items) + for future in concurrent.futures.as_completed(futures): + idx = futures[future] + results[idx] = future.result() if wrap_in_list: results = [ img[0] for img in results] @@ -3541,10 +3716,68 @@ def process_images_multithread(image_processor, items, process_type, wrap_in_lis end_time = time.time() # print(f"duration:{end_time-start_time:.1f}") - return results + return results + +def extract_faces_from_video_with_mask(input_video_path, input_mask_path, max_frames, start_frame, target_fps, size = 512): + if not input_video_path or max_frames <= 0: + return None, None + pad_frames = 0 + if start_frame < 0: + pad_frames= -start_frame + max_frames += start_frame + start_frame = 0 + + 
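+    # a negative start_frame means the requested window begins before the source video:
+    # remember how many frames are missing (pad_frames) so the face tensor can be padded back
+    # to the requested length later, then clamp the extraction range to the actual video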
any_mask = input_mask_path != None + video = get_resampled_video(input_video_path, start_frame, max_frames, target_fps) + if len(video) == 0: return None + frame_height, frame_width, _ = video[0].shape + num_frames = len(video) + if any_mask: + mask_video = get_resampled_video(input_mask_path, start_frame, max_frames, target_fps) + num_frames = min(num_frames, len(mask_video)) + if num_frames == 0: return None + video = video[:num_frames] + if any_mask: + mask_video = mask_video[:num_frames] -def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): - from shared.utils.utils import calculate_new_dimensions, get_outpainting_frame_location, get_outpainting_full_area_dimensions + from preprocessing.face_preprocessor import FaceProcessor + face_processor = FaceProcessor() + + face_list = [] + for frame_idx in range(num_frames): + frame = video[frame_idx].cpu().numpy() + # video[frame_idx] = None + if any_mask: + mask = Image.fromarray(mask_video[frame_idx].cpu().numpy()) + # mask_video[frame_idx] = None + if (frame_width, frame_height) != mask.size: + mask = mask.resize((frame_width, frame_height), resample=Image.Resampling.LANCZOS) + mask = np.array(mask) + alpha_mask = np.zeros((frame_height, frame_width, 3), dtype=np.uint8) + alpha_mask[mask > 127] = 1 + frame = frame * alpha_mask + frame = Image.fromarray(frame) + face = face_processor.process(frame, resize_to=size) + face_list.append(face) + + face_processor = None + gc.collect() + torch.cuda.empty_cache() + + face_tensor= torch.tensor(np.stack(face_list, dtype= np.float32) / 127.5 - 1).permute(-1, 0, 1, 2 ) # t h w c -> c t h w + if pad_frames > 0: + face_tensor= torch.cat([face_tensor[:, -1:].expand(-1, pad_frames, -1, -1), face_tensor ], dim=2) + + if args.save_masks: + from preprocessing.dwpose.pose import save_one_video + saved_faces_frames = [np.array(face) for face in face_list ] + save_one_video(f"faces.mp4", saved_faces_frames, fps=target_fps, quality=8, macro_block_size=None) + return face_tensor + +def get_default_workers(): + return os.cpu_count()/ 2 + +def preprocess_video_with_mask(input_video_path, input_mask_path, height, width, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size= 16, expand_scale = 2, process_type = "inpaint", process_type2 = None, to_bbox = False, RGB_Mask = False, negate_mask = False, process_outside_mask = None, inpaint_color = 127, outpainting_dims = None, proc_no = 1): def mask_to_xyxy_box(mask): rows, cols = np.where(mask == 255) @@ -3559,7 +3792,13 @@ def mask_to_xyxy_box(mask): box = [xmin, ymin, xmax, ymax] box = [int(x) for x in box] return box - + inpaint_color = int(inpaint_color) + pad_frames = 0 + if start_frame < 0: + pad_frames= -start_frame + max_frames += start_frame + start_frame = 0 + if not input_video_path or max_frames <= 0: return None, None any_mask = input_mask_path != None @@ -3585,6 +3824,9 @@ def mask_to_xyxy_box(mask): if len(video) == 0 or any_mask and len(mask_video) == 0: return None, None + if fit_crop and outpainting_dims != None: + fit_crop = False + fit_canvas = 0 if fit_canvas is not None else None frame_height, frame_width, _ = video[0].shape @@ -3599,7 +3841,7 @@ def mask_to_xyxy_box(mask): if outpainting_dims != None: 
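+        # with outpainting, the requested height/width describe the full output canvas;
+        # get_outpainting_frame_location returns the inner region actually covered by the guide
+        # frames together with the top/left margins where that region sits inside the canvas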
final_height, final_width = height, width - height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 8) + height, width, margin_top, margin_left = get_outpainting_frame_location(final_height, final_width, outpainting_dims, 1) if any_mask: num_frames = min(len(video), len(mask_video)) @@ -3616,14 +3858,20 @@ def mask_to_xyxy_box(mask): # for frame_idx in range(num_frames): def prep_prephase(frame_idx): frame = Image.fromarray(video[frame_idx].cpu().numpy()) #.asnumpy() - frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS) + if fit_crop: + frame = rescale_and_crop(frame, width, height) + else: + frame = frame.resize((width, height), resample=Image.Resampling.LANCZOS) frame = np.array(frame) if any_mask: if any_identity_mask: mask = np.full( (height, width, 3), 0, dtype= np.uint8) else: mask = Image.fromarray(mask_video[frame_idx].cpu().numpy()) #.asnumpy() - mask = mask.resize((width, height), resample=Image.Resampling.LANCZOS) + if fit_crop: + mask = rescale_and_crop(mask, width, height) + else: + mask = mask.resize((width, height), resample=Image.Resampling.LANCZOS) mask = np.array(mask) if len(mask.shape) == 3 and mask.shape[2] == 3: @@ -3636,7 +3884,7 @@ def prep_prephase(frame_idx): mask = op_expand(mask, kernel, iterations=3) _, mask = cv2.threshold(mask, 127.5, 255, cv2.THRESH_BINARY) - if to_bbox and np.sum(mask == 255) > 0: + if to_bbox and np.sum(mask == 255) > 0 : #or True x0, y0, x1, y1 = mask_to_xyxy_box(mask) mask = mask * 0 mask[y0:y1, x0:x1] = 255 @@ -3654,8 +3902,8 @@ def prep_prephase(frame_idx): return (target_frame, frame, mask) else: return (target_frame, None, None) - - proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False) + max_workers = get_default_workers() + proc_lists = process_images_multithread(prep_prephase, [frame_idx for frame_idx in range(num_frames)], "prephase", wrap_in_list= False, max_workers=max_workers) proc_list, proc_list_outside, proc_mask = [None] * len(proc_lists), [None] * len(proc_lists), [None] * len(proc_lists) for frame_idx, frame_group in enumerate(proc_lists): proc_list[frame_idx], proc_list_outside[frame_idx], proc_mask[frame_idx] = frame_group @@ -3664,11 +3912,11 @@ def prep_prephase(frame_idx): mask_video = None if preproc2 != None: - proc_list2 = process_images_multithread(preproc2, proc_list, process_type2) + proc_list2 = process_images_multithread(preproc2, proc_list, process_type2, max_workers=max_workers) #### to be finished ...or not - proc_list = process_images_multithread(preproc, proc_list, process_type) + proc_list = process_images_multithread(preproc, proc_list, process_type, max_workers=max_workers) if any_mask: - proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask) + proc_list_outside = process_images_multithread(preproc_outside, proc_list_outside, process_outside_mask, max_workers=max_workers) else: proc_list_outside = proc_mask = len(proc_list) * [None] @@ -3686,7 +3934,7 @@ def prep_prephase(frame_idx): full_frame= torch.full( (final_height, final_width, mask.shape[-1]), 255, dtype= torch.uint8, device= mask.device) full_frame[margin_top:margin_top+height, margin_left:margin_left+width] = mask mask = full_frame - masks.append(mask) + masks.append(mask[:, :, 0:1].clone()) else: masked_frame = processed_img @@ -3706,28 +3954,33 @@ def prep_prephase(frame_idx): proc_list[frame_no] = proc_list_outside[frame_no] = 
proc_mask[frame_no] = None - if args.save_masks: - from preprocessing.dwpose.pose import save_one_video - saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ] - save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) - if any_mask: - saved_masks = [mask.cpu().numpy() for mask in masks ] - save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None) + # if args.save_masks: + # from preprocessing.dwpose.pose import save_one_video + # saved_masked_frames = [mask.cpu().numpy() for mask in masked_frames ] + # save_one_video(f"masked_frames{'' if proc_no==1 else str(proc_no)}.mp4", saved_masked_frames, fps=target_fps, quality=8, macro_block_size=None) + # if any_mask: + # saved_masks = [mask.cpu().numpy() for mask in masks ] + # save_one_video("masks.mp4", saved_masks, fps=target_fps, quality=8, macro_block_size=None) preproc = None preproc_outside = None gc.collect() torch.cuda.empty_cache() + if pad_frames > 0: + masked_frames = masked_frames[0] * pad_frames + masked_frames + if any_mask: masked_frames = masks[0] * pad_frames + masks + masked_frames = torch.stack(masked_frames).permute(-1,0,1,2).float().div_(127.5).sub_(1.) + masks = torch.stack(masks).permute(-1,0,1,2).float().div_(255) if any_mask else None - return torch.stack(masked_frames), torch.stack(masks) if any_mask else None + return masked_frames, masks -def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, target_fps = 16, block_size = 16): +def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_canvas = None, fit_crop = False, target_fps = 16, block_size = 16): frames_list = get_resampled_video(video_in, start_frame, max_frames, target_fps) if len(frames_list) == 0: return None - if fit_canvas == None: + if fit_canvas == None or fit_crop: new_height = height new_width = width else: @@ -3745,7 +3998,10 @@ def preprocess_video(height, width, video_in, max_frames, start_frame=0, fit_can processed_frames_list = [] for frame in frames_list: frame = Image.fromarray(np.clip(frame.cpu().numpy(), 0, 255).astype(np.uint8)) - frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) + if fit_crop: + frame = rescale_and_crop(frame, new_width, new_height) + else: + frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS) processed_frames_list.append(frame) np_frames = [np.array(frame) for frame in processed_frames_list] @@ -3844,7 +4100,7 @@ def perform_spatial_upsampling(sample, spatial_upsampling): frames_to_upsample = [sample[:, i] for i in range( sample.shape[1]) ] def upsample_frames(frame): return resize_lanczos(frame, h, w).unsqueeze(1) - sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False), dim=1) + sample = torch.cat(process_images_multithread(upsample_frames, frames_to_upsample, "upsample", wrap_in_list = False, max_workers=get_default_workers()), dim=1) frames_to_upsample = None return sample @@ -4085,9 +4341,10 @@ def process_prompt_enhancer(prompt_enhancer, original_prompts, image_start, ori prompt_images = [] if "I" in prompt_enhancer: if image_start != None: - prompt_images.append(image_start) + prompt_images += image_start if original_image_refs != None: - prompt_images += original_image_refs[:1] + prompt_images += original_image_refs[:1] + prompt_images = [Image.open(img) if isinstance(img,str) else img for img in 
prompt_images] if len(original_prompts) == 0 and not "T" in prompt_enhancer: return None else: @@ -4137,7 +4394,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri if image_start is None or not "I" in prompt_enhancer: image_start = [None] * num_prompts else: - image_start = [img[0] for img in image_start] + image_start = [convert_image(img[0]) for img in image_start] if len(image_start) == 1: image_start = image_start * num_prompts else: @@ -4157,6 +4414,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri download_models() gen = get_gen_info(state) + original_process_status = None while True: with gen_lock: process_status = gen.get("process_status", None) @@ -4192,7 +4450,7 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri original_image_refs = inputs["image_refs"] if original_image_refs is not None: original_image_refs = [ convert_image(tup[0]) for tup in original_image_refs ] - is_image = inputs["image_mode"] == 1 + is_image = inputs["image_mode"] > 0 seed = inputs["seed"] seed = set_seed(seed) enhanced_prompts = [] @@ -4224,6 +4482,9 @@ def enhance_prompt(state, prompt, prompt_enhancer, multi_images_gen_type, overri gr.Info(f'Prompt "{original_prompts[0][:100]}" has been enhanced') return prompt, prompt +def get_outpainting_dims(video_guide_outpainting): + return None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")] + def generate_video( task, send_cmd, @@ -4273,6 +4534,7 @@ def generate_video( image_mask, control_net_weight, control_net_weight2, + control_net_weight_alt, mask_expand, audio_guide, audio_guide2, @@ -4325,6 +4587,7 @@ def remove_temp_filenames(temp_filenames_list): processes_names = { "pose": "Open Pose", "depth": "Depth Mask", "scribble" : "Shapes", "flow" : "Flow Map", "gray" : "Gray Levels", "inpaint" : "Inpaint Mask", "identity": "Identity Mask", "raw" : "Raw Format", "canny" : "Canny Edges"} global wan_model, offloadobj, reload_needed + sliding_window_keep_only_longest = server_config.get("sliding_window_keep_only_longest", False) gen = get_gen_info(state) torch.set_grad_enabled(False) if mode.startswith("edit_"): @@ -4336,7 +4599,10 @@ def remove_temp_filenames(temp_filenames_list): model_def = get_model_def(model_type) - is_image = image_mode == 1 + is_image = image_mode > 0 + set_video_prompt_type = model_def.get("set_video_prompt_type", None) + if set_video_prompt_type is not None: + video_prompt_type = add_to_sequence(video_prompt_type, set_video_prompt_type) if is_image: if min_frames_if_references >= 1000: video_length = min_frames_if_references - 1000 @@ -4347,18 +4613,17 @@ def remove_temp_filenames(temp_filenames_list): temp_filenames_list = [] if image_guide is not None and isinstance(image_guide, Image.Image): - video_guide = convert_image_to_video(image_guide) - temp_filenames_list.append(video_guide) - image_guide = None + video_guide = image_guide + image_guide = None if image_mask is not None and isinstance(image_mask, Image.Image): - video_mask = convert_image_to_video(image_mask) - temp_filenames_list.append(video_mask) - image_mask = None + video_mask = image_mask + image_mask = None + if model_def.get("no_background_removal", False): remove_background_images_ref = 0 + base_model_type = get_base_model_type(model_type) model_family = get_model_family(base_model_type) - fit_canvas = 
server_config.get("fit_canvas", 0) model_handler = get_model_handler(base_model_type) block_size = model_handler.get_vae_block_size(base_model_type) if hasattr(model_handler, "get_vae_block_size") else 16 @@ -4384,7 +4649,7 @@ def remove_temp_filenames(temp_filenames_list): return width, height = resolution.split("x") - width, height = int(width), int(height) + width, height = int(width) // block_size * block_size, int(height) // block_size * block_size default_image_size = (height, width) if slg_switch == 0: @@ -4443,22 +4708,14 @@ def remove_temp_filenames(temp_filenames_list): current_video_length = video_length # VAE Tiling device_mem_capacity = torch.cuda.get_device_properties(None).total_memory / 1048576 - - i2v = test_class_i2v(model_type) - diffusion_forcing = "diffusion_forcing" in model_filename - t2v = base_model_type in ["t2v"] - ltxv = "ltxv" in model_filename - vace = test_vace_module(base_model_type) - hunyuan_t2v = "hunyuan_video_720" in model_filename - hunyuan_i2v = "hunyuan_video_i2v" in model_filename + guide_inpaint_color = model_def.get("guide_inpaint_color", 127.5) + extract_guide_from_window_start = model_def.get("extract_guide_from_window_start", False) hunyuan_custom = "hunyuan_video_custom" in model_filename hunyuan_custom_audio = hunyuan_custom and "audio" in model_filename hunyuan_custom_edit = hunyuan_custom and "edit" in model_filename hunyuan_avatar = "hunyuan_video_avatar" in model_filename fantasy = base_model_type in ["fantasy"] multitalk = model_def.get("multitalk_class", False) - standin = model_def.get("standin_class", False) - infinitetalk = base_model_type in ["infinitetalk"] if "B" in audio_prompt_type or "X" in audio_prompt_type: from models.wan.multitalk.multitalk import parse_speakers_locations @@ -4487,50 +4744,34 @@ def remove_temp_filenames(temp_filenames_list): if test_any_sliding_window(model_type) : if video_source is not None: - current_video_length += sliding_window_overlap + current_video_length += sliding_window_overlap - 1 sliding_window = current_video_length > sliding_window_size reuse_frames = min(sliding_window_size - 4, sliding_window_overlap) else: sliding_window = False + sliding_window_size = current_video_length reuse_frames = 0 - _, latent_size = get_model_min_frames_and_step(model_type) - if diffusion_forcing: latent_size = 4 + _, _, latent_size = get_model_min_frames_and_step(model_type) original_image_refs = image_refs + image_refs = None if image_refs is None else ([] + image_refs) # work on a copy as it is going to be modified # image_refs = None # nb_frames_positions= 0 - frames_to_inject = [] - any_background_ref = False - outpainting_dims = None if video_guide_outpainting== None or len(video_guide_outpainting) == 0 or video_guide_outpainting == "0 0 0 0" or video_guide_outpainting.startswith("#") else [int(v) for v in video_guide_outpainting.split(" ")] # Output Video Ratio Priorities: # Source Video or Start Image > Control Video > Image Ref (background or positioned frames only) > UI Width, Height # Image Ref (non background and non positioned frames) are boxed in a white canvas in order to keep their own width/height ratio + frames_to_inject = [] + any_background_ref = 0 + if "K" in video_prompt_type: + any_background_ref = 2 if model_def.get("all_image_refs_are_background_ref", False) else 1 + + outpainting_dims = get_outpainting_dims(video_guide_outpainting) + fit_canvas = server_config.get("fit_canvas", 0) + fit_crop = fit_canvas == 2 + if fit_crop and outpainting_dims is not None: + fit_crop = False + 
fit_canvas = 0 - if image_refs is not None and len(image_refs) > 0: - frames_positions_list = [ int(pos)-1 for pos in frames_positions.split(" ")] if frames_positions is not None and len(frames_positions)> 0 else [] - frames_positions_list = frames_positions_list[:len(image_refs)] - nb_frames_positions = len(frames_positions_list) - if nb_frames_positions > 0: - frames_to_inject = [None] * (max(frames_positions_list) + 1) - for i, pos in enumerate(frames_positions_list): - frames_to_inject[pos] = image_refs[i] - if video_guide == None and video_source == None and not "L" in image_prompt_type and (nb_frames_positions > 0 or "K" in video_prompt_type) : - from shared.utils.utils import get_outpainting_full_area_dimensions - w, h = image_refs[0].size - if outpainting_dims != None: - h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims) - default_image_size = calculate_new_dimensions(height, width, h, w, fit_canvas) - fit_canvas = None - # if there is a source video and a background image ref, the height/width ratio will need to be processed later by the code for the model (we dont know the source video dimensions at this point) - if len(image_refs) > nb_frames_positions: - any_background_ref = "K" in video_prompt_type - if remove_background_images_ref > 0: - send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) - os.environ["U2NET_HOME"] = os.path.join(os.getcwd(), "ckpts", "rembg") - from shared.utils.utils import resize_and_remove_background - image_refs[nb_frames_positions:] = resize_and_remove_background(image_refs[nb_frames_positions:] , width, height, remove_background_images_ref > 0, any_background_ref, fit_into_canvas= not (any_background_ref or vace or standin) ) # no fit for vace ref images as it is done later - update_task_thumbnails(task, locals()) - send_cmd("output") joint_pass = boost ==1 #and profile != 1 and profile != 3 skip_steps_cache = None if len(skip_steps_cache_type) == 0 else DynamicClass(cache_type = skip_steps_cache_type) @@ -4551,33 +4792,51 @@ def remove_temp_filenames(temp_filenames_list): raise Exception(f"unknown cache type {skip_steps_cache_type}") trans.cache = skip_steps_cache if trans2 is not None: trans2.cache = skip_steps_cache - + face_arc_embeds = None output_new_audio_data = None output_new_audio_filepath = None original_audio_guide = audio_guide + original_audio_guide2 = audio_guide2 audio_proj_split = None audio_proj_full = None audio_scale = None audio_context_lens = None if (fantasy or multitalk or hunyuan_avatar or hunyuan_custom_audio) and audio_guide != None: from models.wan.fantasytalking.infer import parse_audio + from preprocessing.extract_vocals import get_vocals import librosa duration = librosa.get_duration(path=audio_guide) combination_type = "add" + clean_audio_files = "V" in audio_prompt_type if audio_guide2 is not None: duration2 = librosa.get_duration(path=audio_guide2) if "C" in audio_prompt_type: duration += duration2 else: duration = min(duration, duration2) combination_type = "para" if "P" in audio_prompt_type else "add" + if clean_audio_files: + audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, audio_guide, "_clean", ".wav")) + audio_guide2 = get_vocals(original_audio_guide2, get_available_filename(save_path, audio_guide2, "_clean2", ".wav")) + temp_filenames_list += [audio_guide, audio_guide2] else: if "X" in audio_prompt_type: + # dual speaker, voice separation from preprocessing.speakers_separator import extract_dual_audio combination_type = 
"para" if args.save_speakers: audio_guide, audio_guide2 = "speaker1.wav", "speaker2.wav" else: audio_guide, audio_guide2 = get_available_filename(save_path, audio_guide, "_tmp1", ".wav"), get_available_filename(save_path, audio_guide, "_tmp2", ".wav") - extract_dual_audio(original_audio_guide, audio_guide, audio_guide2 ) + temp_filenames_list += [audio_guide, audio_guide2] + if clean_audio_files: + clean_audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, original_audio_guide, "_clean", ".wav")) + temp_filenames_list += [clean_audio_guide] + extract_dual_audio(clean_audio_guide if clean_audio_files else original_audio_guide, audio_guide, audio_guide2) + + elif clean_audio_files: + # Single Speaker + audio_guide = get_vocals(original_audio_guide, get_available_filename(save_path, audio_guide, "_clean", ".wav")) + temp_filenames_list += [audio_guide] + output_new_audio_filepath = original_audio_guide current_video_length = min(int(fps * duration //latent_size) * latent_size + latent_size + 1, current_video_length) @@ -4589,10 +4848,10 @@ def remove_temp_filenames(temp_filenames_list): # pad audio_proj_full if aligned to beginning of window to simulate source window overlap min_audio_duration = current_video_length/fps if reset_control_aligment else video_source_duration + current_video_length/fps audio_proj_full, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = audio_guide, audio_guide2= audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0), min_audio_duration = min_audio_duration) - if output_new_audio_data is not None: output_new_audio_filepath= None # need to build original speaker track if it changed size (due to padding at the end) or if it has been combined - if not args.save_speakers and "X" in audio_prompt_type: - os.remove(audio_guide) - os.remove(audio_guide2) + if output_new_audio_data is not None: # not none if modified + if clean_audio_files: # need to rebuild the sum of audios with original audio + _, output_new_audio_data = get_full_audio_embeddings(audio_guide1 = original_audio_guide, audio_guide2= original_audio_guide2, combination_type= combination_type , num_frames= max_source_video_frames, sr= audio_sampling_rate, fps =fps, padded_frames_for_embeddings = (reuse_frames if reset_control_aligment else 0), min_audio_duration = min_audio_duration, return_sum_only= True) + output_new_audio_filepath= None # need to build original speaker track if it changed size (due to padding at the end) or if it has been combined if hunyuan_custom_edit and video_guide != None: import cv2 @@ -4600,6 +4859,7 @@ def remove_temp_filenames(temp_filenames_list): length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) current_video_length = min(current_video_length, length) + seed = set_seed(seed) torch.set_grad_enabled(False) @@ -4616,9 +4876,9 @@ def remove_temp_filenames(temp_filenames_list): repeat_no = 0 extra_generation = 0 initial_total_windows = 0 - discard_last_frames = sliding_window_discard_last_frames default_requested_frames_to_generate = current_video_length + nb_frames_positions = 0 if sliding_window: initial_total_windows= compute_sliding_window_no(default_requested_frames_to_generate, sliding_window_size, discard_last_frames, reuse_frames) current_video_length = sliding_window_size @@ -4637,7 +4897,7 @@ def remove_temp_filenames(temp_filenames_list): if repeat_no >= total_generation: break repeat_no 
+=1 gen["repeat_no"] = repeat_no - src_video, src_mask, src_ref_images = None, None, None + src_video = src_video2 = src_mask = src_mask2 = src_faces = src_ref_images = src_ref_masks = sparse_video_image = None prefix_video = pre_video_frame = None source_video_overlap_frames_count = 0 # number of frames overalapped in source video for first window source_video_frames_count = 0 # number of frames to use in source video (processing starts source_video_overlap_frames_count frames before ) @@ -4646,6 +4906,7 @@ def remove_temp_filenames(temp_filenames_list): context_scale = None window_no = 0 extra_windows = 0 + previous_video_path = None # Track previous video for cleanup when sliding_window_keep_only_longest is enabled guide_start_frame = 0 # pos of of first control video frame of current window (reuse_frames later than the first processed frame) keep_frames_parsed = [] # aligned to the first control frame of current window (therefore ignore previous reuse_frames) pre_video_guide = None # reuse_frames of previous window @@ -4670,8 +4931,7 @@ def remove_temp_filenames(temp_filenames_list): while not abort: enable_RIFLEx = RIFLEx_setting == 0 and current_video_length > (6* get_model_fps(base_model_type)+1) or RIFLEx_setting == 1 - if sliding_window: - prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1] + prompt = prompts[window_no] if window_no < len(prompts) else prompts[-1] new_extra_windows = gen.get("extra_windows",0) gen["extra_windows"] = 0 extra_windows += new_extra_windows @@ -4692,25 +4952,22 @@ def remove_temp_filenames(temp_filenames_list): return_latent_slice = None if reuse_frames > 0: return_latent_slice = slice(-(reuse_frames - 1 + discard_last_frames ) // latent_size - 1, None if discard_last_frames == 0 else -(discard_last_frames // latent_size) ) - refresh_preview = {"image_guide" : None, "image_mask" : None} + refresh_preview = {"image_guide" : image_guide, "image_mask" : image_mask} if image_mode >= 1 else {} - src_ref_images = image_refs image_start_tensor = image_end_tensor = None if window_no == 1 and (video_source is not None or image_start is not None): if image_start is not None: - new_height, new_width = calculate_new_dimensions(height, width, image_start.height, image_start.width, sample_fit_canvas, block_size = block_size) - image_start_tensor = image_start.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + image_start_tensor, new_height, new_width = calculate_dimensions_and_resize_image(image_start, height, width, sample_fit_canvas, fit_crop, block_size = block_size) + if fit_crop: refresh_preview["image_start"] = image_start_tensor image_start_tensor = convert_image_to_tensor(image_start_tensor) pre_video_guide = prefix_video = image_start_tensor.unsqueeze(1) - if image_end is not None: - image_end_tensor = image_end.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) - image_end_tensor = convert_image_to_tensor(image_end_tensor) else: - if "L" in image_prompt_type: - refresh_preview["video_source"] = get_video_frame(video_source, 0) - prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, target_fps = fps, block_size = block_size ) + prefix_video = preprocess_video(width=width, height=height,video_in=video_source, max_frames= parsed_keep_frames_video_source , start_frame = 0, fit_canvas= sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, block_size = block_size ) prefix_video = 
prefix_video.permute(3, 0, 1, 2) prefix_video = prefix_video.float().div_(127.5).sub_(1.) # c, f, h, w + if fit_crop or "L" in image_prompt_type: refresh_preview["video_source"] = convert_tensor_to_image(prefix_video, 0) + + new_height, new_width = prefix_video.shape[-2:] pre_video_guide = prefix_video[:, -reuse_frames:] pre_video_frame = convert_tensor_to_image(prefix_video[:, -1]) source_video_overlap_frames_count = pre_video_guide.shape[1] @@ -4719,7 +4976,15 @@ def remove_temp_filenames(temp_filenames_list): image_size = pre_video_guide.shape[-2:] sample_fit_canvas = None guide_start_frame = prefix_video.shape[1] - + if image_end is not None: + image_end_list= image_end if isinstance(image_end, list) else [image_end] + if len(image_end_list) >= window_no: + new_height, new_width = image_size + image_end_tensor, _, _ = calculate_dimensions_and_resize_image(image_end_list[window_no-1], new_height, new_width, sample_fit_canvas, fit_crop, block_size = block_size) + # image_end_tensor =image_end_list[window_no-1].resize((new_width, new_height), resample=Image.Resampling.LANCZOS) + refresh_preview["image_end"] = image_end_tensor + image_end_tensor = convert_image_to_tensor(image_end_tensor) + image_end_list= None window_start_frame = guide_start_frame - (reuse_frames if window_no > 1 else source_video_overlap_frames_count) guide_end_frame = guide_start_frame + current_video_length - (source_video_overlap_frames_count if window_no == 1 else reuse_frames) alignment_shift = source_video_frames_count if reset_control_aligment else 0 @@ -4733,118 +4998,178 @@ def remove_temp_filenames(temp_filenames_list): # special treatment for start frame pos when alignement to first frame requested as otherwise the start frame number will be negative due to overlapped frames (has been previously compensated later with padding) audio_proj_split = get_window_audio_embeddings(audio_proj_full, audio_start_idx= aligned_window_start_frame + (source_video_overlap_frames_count if reset_control_aligment else 0 ), clip_length = current_video_length) + if repeat_no == 1 and window_no == 1 and image_refs is not None and len(image_refs) > 0: + frames_positions_list = [] + if frames_positions is not None and len(frames_positions)> 0: + positions = frames_positions.replace(","," ").split(" ") + cur_end_pos = -1 + (source_video_frames_count - source_video_overlap_frames_count) + last_frame_no = requested_frames_to_generate + source_video_frames_count - source_video_overlap_frames_count + joker_used = False + project_window_no = 1 + for pos in positions : + if len(pos) > 0: + if pos in ["L", "l"]: + cur_end_pos += sliding_window_size if project_window_no > 1 else current_video_length + if cur_end_pos >= last_frame_no-1 and not joker_used: + joker_used = True + cur_end_pos = last_frame_no -1 + project_window_no += 1 + frames_positions_list.append(cur_end_pos) + cur_end_pos -= sliding_window_discard_last_frames + reuse_frames + else: + frames_positions_list.append(int(pos)-1 + alignment_shift) + frames_positions_list = frames_positions_list[:len(image_refs)] + nb_frames_positions = len(frames_positions_list) + if nb_frames_positions > 0: + frames_to_inject = [None] * (max(frames_positions_list) + 1) + for i, pos in enumerate(frames_positions_list): + frames_to_inject[pos] = image_refs[i] + + + video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = sparse_video_image = None if video_guide is not None: - keep_frames_parsed, error = parse_keep_frames_video_guide(keep_frames_video_guide, 
source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) + keep_frames_parsed_full, error = parse_keep_frames_video_guide(keep_frames_video_guide, source_video_frames_count -source_video_overlap_frames_count + requested_frames_to_generate) if len(error) > 0: raise gr.Error(f"invalid keep frames {keep_frames_video_guide}") - keep_frames_parsed = keep_frames_parsed[aligned_guide_start_frame: aligned_guide_end_frame ] - if infinitetalk and video_guide is not None: - src_image = get_video_frame(video_guide, aligned_guide_start_frame-1, return_last_if_missing = True, return_PIL = True) - new_height, new_width = calculate_new_dimensions(image_size[0], image_size[1], src_image.height, src_image.width, sample_fit_canvas, block_size = block_size) - src_image = src_image.resize((new_width, new_height), resample=Image.Resampling.LANCZOS) - refresh_preview["video_guide"] = src_image - src_video = convert_image_to_tensor(src_image).unsqueeze(1) - if sample_fit_canvas != None: - image_size = src_video.shape[-2:] - sample_fit_canvas = None - if ltxv and video_guide is not None: - preprocess_type = process_map_video_guide.get(filter_letters(video_prompt_type, "PED"), "raw") + guide_frames_extract_start = aligned_window_start_frame if extract_guide_from_window_start else aligned_guide_start_frame + extra_control_frames = model_def.get("extra_control_frames", 0) + if extra_control_frames > 0 and aligned_guide_start_frame >= extra_control_frames: guide_frames_extract_start -= extra_control_frames + + keep_frames_parsed = [True] * -guide_frames_extract_start if guide_frames_extract_start <0 else [] + keep_frames_parsed += keep_frames_parsed_full[max(0, guide_frames_extract_start): aligned_guide_end_frame ] + guide_frames_extract_count = len(keep_frames_parsed) + + # Extract Faces to video + if "B" in video_prompt_type: + send_cmd("progress", [0, get_latest_status(state, "Extracting Face Movements")]) + src_faces = extract_faces_from_video_with_mask(video_guide, video_mask, max_frames= guide_frames_extract_count, start_frame= guide_frames_extract_start, size= 512, target_fps = fps) + if src_faces is not None and src_faces.shape[1] < current_video_length: + src_faces = torch.cat([src_faces, torch.full( (3, current_video_length - src_faces.shape[1], 512, 512 ), -1, dtype = src_faces.dtype, device= src_faces.device) ], dim=1) + + # Sparse Video to Video + sparse_video_image = get_video_frame(video_guide, aligned_guide_start_frame, return_last_if_missing = True, target_fps = fps, return_PIL = True) if "R" in video_prompt_type else None + + # Generic Video Preprocessing + process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) + preprocess_type, preprocess_type2 = "raw", None + for process_num, process_letter in enumerate( filter_letters(video_prompt_type, video_guide_processes)): + if process_num == 0: + preprocess_type = process_map_video_guide.get(process_letter, "raw") + else: + preprocess_type2 = process_map_video_guide.get(process_letter, None) status_info = "Extracting " + processes_names[preprocess_type] - send_cmd("progress", [0, get_latest_status(state, status_info)]) - # start one frame ealier to facilitate latents merging later - src_video, _ = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) + (0 if aligned_guide_start_frame == 0 else 1), start_frame = aligned_guide_start_frame - (0 if aligned_guide_start_frame == 0 else 1), fit_canvas = 
sample_fit_canvas, target_fps = fps, process_type = preprocess_type, inpaint_color = 0, proc_no =1, negate_mask = "N" in video_prompt_type, process_outside_mask = "inpaint" if "X" in video_prompt_type else "identity", block_size =block_size ) - if src_video != None: - src_video = src_video[ :(len(src_video)-1)// latent_size * latent_size +1 ] - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) - src_video = src_video.permute(3, 0, 1, 2) - src_video = src_video.float().div_(127.5).sub_(1.) # c, f, h, w - if sample_fit_canvas != None: - image_size = src_video.shape[-2:] - sample_fit_canvas = None - - if t2v and "G" in video_prompt_type: - video_guide_processed = preprocess_video(width = image_size[1], height=image_size[0], video_in=video_guide, max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas= sample_fit_canvas, target_fps = fps) - if video_guide_processed == None: - src_video = pre_video_guide - else: - if sample_fit_canvas != None: - image_size = video_guide_processed.shape[-3: -1] - sample_fit_canvas = None - src_video = video_guide_processed.float().div_(127.5).sub_(1.).permute(-1,0,1,2) - if pre_video_guide != None: - src_video = torch.cat( [pre_video_guide, src_video], dim=1) - - if vace : - image_refs_copy = image_refs[nb_frames_positions:].copy() if image_refs != None and len(image_refs) > nb_frames_positions else None # required since prepare_source do inplace modifications - context_scale = [ control_net_weight] - video_guide_processed = video_mask_processed = video_guide_processed2 = video_mask_processed2 = None - if "V" in video_prompt_type: - process_outside_mask = process_map_outside_mask.get(filter_letters(video_prompt_type, "YWX"), None) - preprocess_type, preprocess_type2 = "raw", None - for process_num, process_letter in enumerate( filter_letters(video_prompt_type, "PDSLCMU")): - if process_num == 0: - preprocess_type = process_map_video_guide.get(process_letter, "raw") + extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) + if len(extra_process_list) == 1: + status_info += " and " + processes_names[extra_process_list[0]] + elif len(extra_process_list) == 2: + status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] + context_scale = [control_net_weight /2, control_net_weight2 /2] if preprocess_type2 is not None else [control_net_weight] + if not (preprocess_type == "identity" and preprocess_type2 is None and video_mask is None):send_cmd("progress", [0, get_latest_status(state, status_info)]) + inpaint_color = 0 if preprocess_type=="pose" and process_outside_mask == "inpaint" else guide_inpaint_color + video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide if sparse_video_image is None else sparse_video_image, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1, inpaint_color =inpaint_color, block_size = block_size, to_bbox = "H" in video_prompt_type ) + if preprocess_type2 != None: + video_guide_processed2, video_mask_processed2 = 
preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= guide_frames_extract_count, start_frame = guide_frames_extract_start, fit_canvas = sample_fit_canvas, fit_crop = fit_crop, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2, block_size = block_size, to_bbox = "H" in video_prompt_type ) + + if video_guide_processed is not None and sample_fit_canvas is not None: + image_size = video_guide_processed.shape[-2:] + sample_fit_canvas = None + + if window_no == 1 and image_refs is not None and len(image_refs) > 0: + if sample_fit_canvas is not None and (nb_frames_positions > 0 or "K" in video_prompt_type) : + from shared.utils.utils import get_outpainting_full_area_dimensions + w, h = image_refs[0].size + if outpainting_dims != None: + h, w = get_outpainting_full_area_dimensions(h,w, outpainting_dims) + image_size = calculate_new_dimensions(height, width, h, w, fit_canvas) + sample_fit_canvas = None + if repeat_no == 1: + if fit_crop: + if any_background_ref == 2: + end_ref_position = len(image_refs) + elif any_background_ref == 1: + end_ref_position = nb_frames_positions + 1 else: - preprocess_type2 = process_map_video_guide.get(process_letter, None) - status_info = "Extracting " + processes_names[preprocess_type] - extra_process_list = ([] if preprocess_type2==None else [preprocess_type2]) + ([] if process_outside_mask==None or process_outside_mask == preprocess_type else [process_outside_mask]) - if len(extra_process_list) == 1: - status_info += " and " + processes_names[extra_process_list[0]] - elif len(extra_process_list) == 2: - status_info += ", " + processes_names[extra_process_list[0]] + " and " + processes_names[extra_process_list[1]] - if preprocess_type2 is not None: - context_scale = [ control_net_weight /2, control_net_weight2 /2] - send_cmd("progress", [0, get_latest_status(state, status_info)]) - video_guide_processed, video_mask_processed = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed) , start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =1 ) - if preprocess_type2 != None: - video_guide_processed2, video_mask_processed2 = preprocess_video_with_mask(video_guide, video_mask, height=image_size[0], width = image_size[1], max_frames= len(keep_frames_parsed), start_frame = aligned_guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type = preprocess_type2, expand_scale = mask_expand, RGB_Mask = True, negate_mask = "N" in video_prompt_type, process_outside_mask = process_outside_mask, outpainting_dims = outpainting_dims, proc_no =2 ) - - if video_guide_processed != None: - if sample_fit_canvas != None: - image_size = video_guide_processed.shape[-3: -1] - sample_fit_canvas = None - refresh_preview["video_guide"] = Image.fromarray(video_guide_processed[0].cpu().numpy()) - if video_guide_processed2 != None: - refresh_preview["video_guide"] = [refresh_preview["video_guide"], Image.fromarray(video_guide_processed2[0].cpu().numpy())] - if video_mask_processed != None: - refresh_preview["video_mask"] = 
Image.fromarray(video_mask_processed[0].cpu().numpy()) - frames_to_inject_parsed = frames_to_inject[aligned_guide_start_frame: aligned_guide_end_frame] - - src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide_processed] if video_guide_processed2 == None else [video_guide_processed, video_guide_processed2], - [video_mask_processed] if video_guide_processed2 == None else [video_mask_processed, video_mask_processed2], - [image_refs_copy] if video_guide_processed2 == None else [image_refs_copy, image_refs_copy], - current_video_length, image_size = image_size, device ="cpu", - keep_video_guide_frames=keep_frames_parsed, - start_frame = aligned_guide_start_frame, - pre_src_video = [pre_video_guide] if video_guide_processed2 == None else [pre_video_guide, pre_video_guide], - fit_into_canvas = sample_fit_canvas, - inject_frames= frames_to_inject_parsed, - outpainting_dims = outpainting_dims, - any_background_ref = any_background_ref - ) - if len(frames_to_inject_parsed) or any_background_ref: - new_image_refs = [convert_tensor_to_image(src_video[0], frame_no) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] - if any_background_ref: - new_image_refs += [convert_tensor_to_image(image_refs_copy[0], 0)] + image_refs[nb_frames_positions+1:] + end_ref_position = nb_frames_positions + for i, img in enumerate(image_refs[:end_ref_position]): + image_refs[i] = rescale_and_crop(img, default_image_size[1], default_image_size[0]) + refresh_preview["image_refs"] = image_refs + + if len(image_refs) > nb_frames_positions: + src_ref_images = image_refs[nb_frames_positions:] + if "Q" in video_prompt_type: + from preprocessing.arc.face_encoder import FaceEncoderArcFace, get_landmarks_from_image + image_pil = src_ref_images[-1] + face_encoder = FaceEncoderArcFace() + face_encoder.init_encoder_model(processing_device) + face_arc_embeds = face_encoder(image_pil, need_proc=True, landmarks=get_landmarks_from_image(image_pil)) + face_arc_embeds = face_arc_embeds.squeeze(0).cpu() + face_encoder = image_pil = None + gc.collect() + torch.cuda.empty_cache() + + if remove_background_images_ref > 0: + send_cmd("progress", [0, get_latest_status(state, "Removing Images References Background")]) + + src_ref_images, src_ref_masks = resize_and_remove_background(src_ref_images , image_size[1], image_size[0], + remove_background_images_ref > 0, any_background_ref, + fit_into_canvas= model_def.get("fit_into_canvas_image_refs", 1), + block_size=block_size, + outpainting_dims =outpainting_dims, + background_ref_outpainted = model_def.get("background_ref_outpainted", True), + return_tensor= model_def.get("return_image_refs_tensor", False) ) + + frames_to_inject_parsed = frames_to_inject[ window_start_frame if extract_guide_from_window_start else guide_start_frame: guide_end_frame] + if video_guide is not None or len(frames_to_inject_parsed) > 0 or model_def.get("forced_guide_mask_inputs", False): + any_mask = video_mask is not None or model_def.get("forced_guide_mask_inputs", False) + any_guide_padding = model_def.get("pad_guide_video", False) + from shared.utils.utils import prepare_video_guide_and_mask + src_videos, src_masks = prepare_video_guide_and_mask( [video_guide_processed] + ([] if video_guide_processed2 is None else [video_guide_processed2]), + [video_mask_processed] + ([] if video_guide_processed2 is None else [video_mask_processed2]), + None if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else pre_video_guide, + 
image_size, current_video_length, latent_size, + any_mask, any_guide_padding, guide_inpaint_color, + keep_frames_parsed, frames_to_inject_parsed , outpainting_dims) + video_guide_processed = video_guide_processed2 = video_mask_processed = video_mask_processed2 = None + if len(src_videos) == 1: + src_video, src_video2, src_mask, src_mask2 = src_videos[0], None, src_masks[0], None + else: + src_video, src_video2 = src_videos + src_mask, src_mask2 = src_masks + src_videos = src_masks = None + if src_video is None: + abort = True + break + if src_faces is not None: + if src_faces.shape[1] < src_video.shape[1]: + src_faces = torch.concat( [src_faces, src_faces[:, -1:].repeat(1, src_video.shape[1] - src_faces.shape[1], 1,1)], dim =1) else: - new_image_refs += image_refs[nb_frames_positions:] - refresh_preview["image_refs"] = new_image_refs - new_image_refs = None - - if sample_fit_canvas != None: - image_size = src_video[0].shape[-2:] - sample_fit_canvas = None - elif hunyuan_custom_edit: - if "P" in video_prompt_type: - progress_args = [0, get_latest_status(state,"Extracting Open Pose Information and Expanding Mask")] + src_faces = src_faces[:, :src_video.shape[1]] + if video_guide is not None or len(frames_to_inject_parsed) > 0: + if args.save_masks: + if src_video is not None: + save_video( src_video, "masked_frames.mp4", fps) + if any_mask: save_video( src_mask, "masks.mp4", fps, value_range=(0, 1)) + if src_video2 is not None: + save_video( src_video2, "masked_frames2.mp4", fps) + if any_mask: save_video( src_mask2, "masks2.mp4", fps, value_range=(0, 1)) + if video_guide is not None: + preview_frame_no = 0 if extract_guide_from_window_start or model_def.get("dont_cat_preguide", False) or sparse_video_image is not None else (guide_start_frame - window_start_frame) + refresh_preview["video_guide"] = convert_tensor_to_image(src_video, preview_frame_no) + if src_video2 is not None: + refresh_preview["video_guide"] = [refresh_preview["video_guide"], convert_tensor_to_image(src_video2, preview_frame_no)] + if src_mask is not None and video_mask is not None: + refresh_preview["video_mask"] = convert_tensor_to_image(src_mask, preview_frame_no, mask_levels = True) + + if src_ref_images is not None or nb_frames_positions: + if len(frames_to_inject_parsed): + new_image_refs = [convert_tensor_to_image(src_video, frame_no + (0 if extract_guide_from_window_start else (aligned_guide_start_frame - aligned_window_start_frame)) ) for frame_no, inject in enumerate(frames_to_inject_parsed) if inject] else: - progress_args = [0, get_latest_status(state,"Extracting Video and Mask")] + new_image_refs = [] + if src_ref_images is not None: + new_image_refs += [convert_tensor_to_image(img) if torch.is_tensor(img) else img for img in src_ref_images ] + refresh_preview["image_refs"] = new_image_refs + new_image_refs = None - send_cmd("progress", progress_args) - src_video, src_mask = preprocess_video_with_mask(video_guide, video_mask, height=height, width = width, max_frames= current_video_length if window_no == 1 else current_video_length - reuse_frames, start_frame = guide_start_frame, fit_canvas = sample_fit_canvas, target_fps = fps, process_type= "pose" if "P" in video_prompt_type else "inpaint", negate_mask = "N" in video_prompt_type, inpaint_color =0) - refresh_preview["video_guide"] = Image.fromarray(src_video[0].cpu().numpy()) - if src_mask != None: - refresh_preview["video_mask"] = Image.fromarray(src_mask[0].cpu().numpy()) if len(refresh_preview) > 0: new_inputs= locals() new_inputs.update(refresh_preview) 
@@ -4881,17 +5206,21 @@ def set_header_text(txt):
 input_prompt = prompt,
 image_start = image_start_tensor,
 image_end = image_end_tensor,
- input_frames = src_video,
+ input_frames = src_video,
+ input_frames2 = src_video2,
 input_ref_images= src_ref_images,
+ input_ref_masks = src_ref_masks,
 input_masks = src_mask,
+ input_masks2 = src_mask2,
 input_video= pre_video_guide,
+ input_faces = src_faces,
 denoising_strength=denoising_strength,
 prefix_frames_count = source_video_overlap_frames_count if window_no <= 1 else reuse_frames,
 frame_num= (current_video_length // latent_size)* latent_size + 1,
 batch_size = batch_size,
- height = height,
- width = width,
- fit_into_canvas = fit_canvas == 1,
+ height = image_size[0],
+ width = image_size[1],
+ fit_into_canvas = fit_canvas,
 shift=flow_shift,
 sample_solver=sample_solver,
 sampling_steps=num_inference_steps,
@@ -4922,6 +5251,7 @@ def set_header_text(txt):
 audio_scale= audio_scale,
 audio_context_lens= audio_context_lens,
 context_scale = context_scale,
+ control_scale_alt = control_net_weight_alt,
 model_mode = model_mode,
 causal_block_size = 5,
 causal_attention = True,
@@ -4948,6 +5278,8 @@ def set_header_text(txt):
 pre_video_frame = pre_video_frame,
 original_input_ref_images = original_image_refs[nb_frames_positions:] if original_image_refs is not None else [],
 image_refs_relative_size = image_refs_relative_size,
+ outpainting_dims = outpainting_dims,
+ face_arc_embeds = face_arc_embeds,
 )
 except Exception as e:
 if len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0:
@@ -5015,6 +5347,7 @@ def set_header_text(txt):
 send_cmd("output")
 else:
 sample = samples.cpu()
+ abort = not is_image and sample.shape[1] < current_video_length
 # if True: # for testing
 # torch.save(sample, "output.pt")
 # else:
@@ -5087,6 +5420,11 @@ def set_header_text(txt):
 video_path= new_image_path
 elif len(control_audio_tracks) > 0 or len(source_audio_tracks) > 0 or output_new_audio_filepath is not None or any_mmaudio or output_new_audio_data is not None or audio_source is not None:
 video_path = os.path.join(save_path, file_name)
+
+ # Delete previous video if sliding_window_keep_only_longest is enabled
+ cleanup_previous_video(previous_video_path, file_list, file_settings_list, lock,
+ sliding_window, sliding_window_keep_only_longest)
+
 save_path_tmp = video_path[:-4] + "_tmp.mp4"
 save_video( tensor=sample[None], save_file=save_path_tmp, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type = server_config.get("video_output_codec", None))
 output_new_audio_temp_filepath = None
@@ -5115,6 +5453,10 @@ def set_header_text(txt):
 if output_new_audio_temp_filepath is not None:
 os.remove(output_new_audio_temp_filepath)
 else:
+ # Delete previous video if sliding_window_keep_only_longest is enabled
+ cleanup_previous_video(previous_video_path, file_list, file_settings_list, lock,
+ sliding_window, sliding_window_keep_only_longest)
+
 save_video( tensor=sample[None], save_file=video_path, fps=output_fps, nrow=1, normalize=True, value_range=(-1, 1), codec_type= server_config.get("video_output_codec", None))
 end_time = time.time()
@@ -5165,6 +5507,14 @@ def set_header_text(txt):
 with lock:
 file_list.append(path)
 file_settings_list.append(configs if no > 0 else configs.copy())
+
+ # Update previous video path for cleanup in next iteration (only for videos, not images)
+ if should_cleanup_video(sliding_window, sliding_window_keep_only_longest, is_image):
+ if isinstance(video_path, list):
+ # Handle case where video_path might be a list (for images)
+ previous_video_path = video_path[0] if len(video_path) > 0 else None
+ else:
+ previous_video_path = video_path
 # Play notification sound for single video
 try:
@@ -5272,7 +5622,7 @@ def process_tasks(state):
 while True:
 with gen_lock:
 process_status = gen.get("process_status", None)
- if process_status is None:
+ if process_status is None or process_status == "process:main":
 gen["process_status"] = "process:main"
 break
 time.sleep(1)
@@ -5889,6 +6239,8 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
 if target == "settings":
 return inputs
+ image_outputs = inputs.get("image_mode",0) > 0
+
 pop=[]
 if "force_fps" in inputs and len(inputs["force_fps"])== 0:
 pop += ["force_fps"]
@@ -5903,13 +6255,13 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
 pop += ["MMAudio_setting", "MMAudio_prompt", "MMAudio_neg_prompt"]
 video_prompt_type = inputs["video_prompt_type"]
- if not base_model_type in ["t2v"]:
+ if not "G" in video_prompt_type:
 pop += ["denoising_strength"]
 if not (server_config.get("enhancer_enabled", 0) > 0 and server_config.get("enhancer_mode", 0) == 0):
 pop += ["prompt_enhancer"]
- if not recammaster and not diffusion_forcing and not flux:
+ if model_def.get("model_modes", None) is None:
 pop += ["model_mode"]
 if not vace and not phantom and not hunyuan_video_custom:
@@ -5923,8 +6275,14 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
 pop += ["image_refs_relative_size"]
 if not vace:
- pop += ["frames_positions", "video_guide_outpainting", "control_net_weight", "control_net_weight2"]
-
+ pop += ["frames_positions", "control_net_weight", "control_net_weight2"]
+
+ if not len(model_def.get("control_net_weight_alt_name", "")) >0:
+ pop += ["control_net_alt"]
+
+ if model_def.get("video_guide_outpainting", None) is None:
+ pop += ["video_guide_outpainting"]
+
 if not (vace or t2v):
 pop += ["min_frames_if_references"]
@@ -5954,7 +6312,7 @@ def prepare_inputs_dict(target, inputs, model_type = None, model_filename = None
 if guidance_max_phases < 3 or guidance_phases < 3:
 pop += ["guidance3_scale", "switch_threshold2", "model_switch_phase"]
- if ltxv:
+ if ltxv or image_outputs:
 pop += ["flow_shift"]
 if model_def.get("no_negative_prompt", False) :
@@ -6013,10 +6371,16 @@ def video_to_source_video(state, input_file_list, choice):
 def image_to_ref_image_add(state, input_file_list, choice, target, target_name):
 file_list, file_settings_list = get_file_list(state, input_file_list)
 if len(file_list) == 0 or choice == None or choice < 0 or choice > len(file_list): return gr.update()
- gr.Info(f"Selected Image was added to {target_name}")
- if target == None:
- target =[]
- target.append( file_list[choice])
+ model_type = state["model_type"]
+ model_def = get_model_def(model_type)
+ if model_def.get("one_image_ref_needed", False):
+ gr.Info(f"Selected Image was set to {target_name}")
+ target =[file_list[choice]]
+ else:
+ gr.Info(f"Selected Image was added to {target_name}")
+ if target == None:
+ target =[]
+ target.append( file_list[choice])
 return target
 def image_to_ref_image_set(state, input_file_list, choice, target, target_name):
@@ -6025,6 +6389,18 @@ def image_to_ref_image_set(state, input_file_list, choice, target, target_name):
 gr.Info(f"Selected Image was copied to {target_name}")
 return file_list[choice]
+def image_to_ref_image_guide(state, input_file_list, choice):
+ file_list, file_settings_list = get_file_list(state, input_file_list)
+ if len(file_list) == 0 or choice == None or choice
< 0 or choice > len(file_list): return gr.update(), gr.update() + ui_settings = get_current_model_settings(state) + gr.Info(f"Selected Image was copied to Control Image") + new_image = file_list[choice] + if ui_settings["image_mode"]==2 or True: + return new_image, new_image + else: + return new_image, None + + def apply_post_processing(state, input_file_list, choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation): gen = get_gen_info(state) @@ -6091,14 +6467,6 @@ def eject_video_from_gallery(state, input_file_list, choice): choice = min(choice, len(file_list)) return gr.Gallery(value = file_list, selected_index= choice), gr.update() if len(file_list) >0 else get_default_video_info(), gr.Row(visible= len(file_list) > 0) -def has_video_file_extension(filename): - extension = os.path.splitext(filename)[-1] - return extension in [".mp4"] - -def has_image_file_extension(filename): - extension = os.path.splitext(filename)[-1] - return extension in [".jpeg", ".jpg", ".png", ".webp", ".bmp", ".tiff"] - def add_videos_to_gallery(state, input_file_list, choice, files_to_load): gen = get_gen_info(state) if files_to_load == None: @@ -6229,6 +6597,7 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw if not "WanGP" in configs.get("type", ""): configs = None except: configs = None + if configs is None: return None, False current_model_filename = state["model_filename"] @@ -6258,7 +6627,7 @@ def get_settings_from_file(state, file_path, allow_json, merge_with_defaults, sw return configs, any_image_or_video def record_image_mode_tab(state, evt:gr.SelectData): - state["image_mode_tab"] = 0 if evt.index ==0 else 1 + state["image_mode_tab"] = evt.index def switch_image_mode(state): image_mode = state.get("image_mode_tab", 0) @@ -6266,19 +6635,30 @@ def switch_image_mode(state): ui_defaults = get_model_settings(state, model_type) ui_defaults["image_mode"] = image_mode - + video_prompt_type = ui_defaults.get("video_prompt_type", "") + model_def = get_model_def( model_type) + inpaint_support = model_def.get("inpaint_support", False) + if inpaint_support: + if image_mode == 1: + video_prompt_type = del_in_sequence(video_prompt_type, "VAG" + all_guide_processes) + video_prompt_type = add_to_sequence(video_prompt_type, "KI") + elif image_mode == 2: + video_prompt_type = del_in_sequence(video_prompt_type, "KI" + all_guide_processes) + video_prompt_type = add_to_sequence(video_prompt_type, "VAG") + ui_defaults["video_prompt_type"] = video_prompt_type + return str(time.time()) def load_settings_from_file(state, file_path): gen = get_gen_info(state) if file_path==None: - return gr.update(), gr.update(), None + return gr.update(), gr.update(), None, gr.update() configs, any_video_or_image_file = get_settings_from_file(state, file_path, True, True, True) if configs == None: gr.Info("File not supported") - return gr.update(), gr.update(), None + return gr.update(), gr.update(), None, gr.update() current_model_type = state["model_type"] model_type = configs["model_type"] @@ -6292,13 +6672,14 @@ def load_settings_from_file(state, file_path): if model_type == current_model_type: set_model_settings(state, current_model_type, configs) - return gr.update(), gr.update(), str(time.time()), None + return gr.update(), gr.update(), str(time.time()), gr.update() else: set_model_settings(state, model_type, configs) - return *generate_dropdown_model_list(model_type), gr.update(), None + return *generate_dropdown_model_list(model_type), gr.update(), 
gr.update() def save_inputs( target, + image_mask_guide, lset_name, image_mode, prompt, @@ -6346,6 +6727,7 @@ def save_inputs( image_mask, control_net_weight, control_net_weight2, + control_net_weight_alt, mask_expand, audio_guide, audio_guide2, @@ -6384,13 +6766,19 @@ def save_inputs( state, ): - - # if state.get("validate_success",0) != 1: - # return + model_filename = state["model_filename"] model_type = state["model_type"] + if image_mask_guide is not None and image_mode >= 1 and video_prompt_type is not None and "A" in video_prompt_type and not "U" in video_prompt_type: + # if image_mask_guide is not None and image_mode == 2: + if "background" in image_mask_guide: + image_guide = image_mask_guide["background"] + if "layers" in image_mask_guide and len(image_mask_guide["layers"])>0: + image_mask = image_mask_guide["layers"][0] + image_mask_guide = None inputs = get_function_arguments(save_inputs, locals()) inputs.pop("target") + inputs.pop("image_mask_guide") cleaned_inputs = prepare_inputs_dict(target, inputs) if target == "settings": defaults_filename = get_settings_file_name(model_type) @@ -6494,11 +6882,16 @@ def change_model(state, model_choice): return header -def fill_inputs(state): +def get_current_model_settings(state): model_type = state["model_type"] - ui_defaults = get_model_settings(state, model_type) + ui_defaults = get_model_settings(state, model_type) if ui_defaults == None: ui_defaults = get_default_settings(model_type) + set_model_settings(state, model_type, ui_defaults) + return ui_defaults + +def fill_inputs(state): + ui_defaults = get_current_model_settings(state) return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults) @@ -6536,11 +6929,13 @@ def any_letters(source_str, letters): return True return False -def filter_letters(source_str, letters): +def filter_letters(source_str, letters, default= ""): ret = "" for letter in letters: if letter in source_str: ret += letter + if len(ret) == 0: + return default return ret def add_to_sequence(source_str, letters): @@ -6562,16 +6957,34 @@ def refresh_audio_prompt_type_remux(state, audio_prompt_type, remux): audio_prompt_type = add_to_sequence(audio_prompt_type, remux) return audio_prompt_type +def refresh_remove_background_sound(state, audio_prompt_type, remove_background_sound): + audio_prompt_type = del_in_sequence(audio_prompt_type, "V") + if remove_background_sound: + audio_prompt_type = add_to_sequence(audio_prompt_type, "V") + return audio_prompt_type + + def refresh_audio_prompt_type_sources(state, audio_prompt_type, audio_prompt_type_sources): audio_prompt_type = del_in_sequence(audio_prompt_type, "XCPAB") audio_prompt_type = add_to_sequence(audio_prompt_type, audio_prompt_type_sources) - return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in audio_prompt_type)) - -def refresh_image_prompt_type(state, image_prompt_type): - any_video_source = len(filter_letters(image_prompt_type, "VLG"))>0 - return gr.update(visible = "S" in image_prompt_type ), gr.update(visible = "E" in image_prompt_type ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source) - -def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs): + return audio_prompt_type, gr.update(visible = "A" in audio_prompt_type), gr.update(visible = "B" in audio_prompt_type), gr.update(visible = ("B" in audio_prompt_type or "X" in 
audio_prompt_type)), gr.update(visible= any_letters(audio_prompt_type, "ABX")) + +def refresh_image_prompt_type_radio(state, image_prompt_type, image_prompt_type_radio): + image_prompt_type = del_in_sequence(image_prompt_type, "VLTS") + image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio) + any_video_source = len(filter_letters(image_prompt_type, "VL"))>0 + model_def = get_model_def(state["model_type"]) + image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "") + end_visible = "E" in image_prompt_types_allowed and any_letters(image_prompt_type, "SVL") + return image_prompt_type, gr.update(visible = "S" in image_prompt_type ), gr.update(visible = end_visible and ("E" in image_prompt_type) ), gr.update(visible = "V" in image_prompt_type) , gr.update(visible = any_video_source), gr.update(visible = end_visible) + +def refresh_image_prompt_type_endcheckbox(state, image_prompt_type, image_prompt_type_radio, end_checkbox): + image_prompt_type = del_in_sequence(image_prompt_type, "E") + if end_checkbox: image_prompt_type += "E" + image_prompt_type = add_to_sequence(image_prompt_type, image_prompt_type_radio) + return image_prompt_type, gr.update(visible = "E" in image_prompt_type ) + +def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_type_image_refs, image_mode): model_type = state["model_type"] model_def = get_model_def(model_type) image_ref_choices = model_def.get("image_ref_choices", None) @@ -6581,49 +6994,128 @@ def refresh_video_prompt_type_image_refs(state, video_prompt_type, video_prompt_ video_prompt_type = del_in_sequence(video_prompt_type, "KFI") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_image_refs) visible = "I" in video_prompt_type - vace= test_vace_module(state["model_type"]) - + any_outpainting= image_mode in model_def.get("video_guide_outpainting", []) rm_bg_visible= visible and not model_def.get("no_background_removal", False) img_rel_size_visible = visible and model_def.get("any_image_refs_relative_size", False) - return video_prompt_type, gr.update(visible = visible),gr.update(visible = rm_bg_visible), gr.update(visible = img_rel_size_visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and vace ) + return video_prompt_type, gr.update(visible = visible),gr.update(visible = rm_bg_visible), gr.update(visible = img_rel_size_visible), gr.update(visible = visible and "F" in video_prompt_type_image_refs), gr.update(visible= ("F" in video_prompt_type_image_refs or "K" in video_prompt_type_image_refs or "V" in video_prompt_type) and any_outpainting ) + +def switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): + if image_mode == 0: return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) + mask_in_old = "A" in old_video_prompt_type and not "U" in old_video_prompt_type + mask_in_new = "A" in video_prompt_type and not "U" in video_prompt_type + image_mask_guide_value, image_mask_value, image_guide_value = {}, {}, {} + visible = "V" in video_prompt_type + if mask_in_old != mask_in_new: + if mask_in_new: + if old_image_mask_value is None: + image_mask_guide_value["value"] = old_image_guide_value + else: + image_mask_guide_value["value"] = {"background" : old_image_guide_value, "composite" : None, "layers": 
[rgb_bw_to_rgba_mask(old_image_mask_value)]} + image_guide_value["value"] = image_mask_value["value"] = None + else: + if old_image_mask_guide_value is not None and "background" in old_image_mask_guide_value: + image_guide_value["value"] = old_image_mask_guide_value["background"] + if "layers" in old_image_mask_guide_value: + image_mask_value["value"] = old_image_mask_guide_value["layers"][0] if len(old_image_mask_guide_value["layers"]) >=1 else None + image_mask_guide_value["value"] = {"background" : None, "composite" : None, "layers": []} + + image_mask_guide = gr.update(visible= visible and mask_in_new, **image_mask_guide_value) + image_guide = gr.update(visible = visible and not mask_in_new, **image_guide_value) + image_mask = gr.update(visible = False, **image_mask_value) + return image_mask_guide, image_guide, image_mask -def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask, image_mode): +def refresh_video_prompt_type_video_mask(state, video_prompt_type, video_prompt_type_video_mask, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): + old_video_prompt_type = video_prompt_type video_prompt_type = del_in_sequence(video_prompt_type, "XYZWNA") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_mask) visible= "A" in video_prompt_type model_type = state["model_type"] model_def = get_model_def(model_type) - image_outputs = image_mode == 1 - return video_prompt_type, gr.update(visible= visible and not image_outputs), gr.update(visible= visible and image_outputs), gr.update(visible= visible ) + image_outputs = image_mode > 0 + image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ) + return video_prompt_type, gr.update(visible= visible and not image_outputs), image_mask_guide, image_guide, image_mask, gr.update(visible= visible ) def refresh_video_prompt_type_alignment(state, video_prompt_type, video_prompt_type_video_guide): video_prompt_type = del_in_sequence(video_prompt_type, "T") video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) return video_prompt_type -def refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide, image_mode): - video_prompt_type = del_in_sequence(video_prompt_type, "PDESLCMGUV") +all_guide_processes ="PDESLCMUVBH" +video_guide_processes = "PEDSLCMU" + +def refresh_video_prompt_type_video_guide(state, filter_type, video_prompt_type, video_prompt_type_video_guide, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): + model_type = state["model_type"] + model_def = get_model_def(model_type) + old_video_prompt_type = video_prompt_type + if filter_type == "alt": + guide_custom_choices = model_def.get("guide_custom_choices",{}) + letter_filter = guide_custom_choices.get("letters_filter","") + else: + letter_filter = all_guide_processes + video_prompt_type = del_in_sequence(video_prompt_type, letter_filter) video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide) visible = "V" in video_prompt_type - model_type = state["model_type"] - base_model_type = get_base_model_type(model_type) + any_outpainting= image_mode in model_def.get("video_guide_outpainting", []) mask_visible = visible and "A" in video_prompt_type and not "U" in video_prompt_type - model_def = get_model_def(model_type) - 
image_outputs = image_mode == 1 - vace= test_vace_module(model_type) + image_outputs = image_mode > 0 keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False) - return video_prompt_type, gr.update(visible = visible and not image_outputs), gr.update(visible = visible and image_outputs), gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and vace), gr.update(visible= visible and not "U" in video_prompt_type ), gr.update(visible= mask_visible and not image_outputs), gr.update(visible= mask_visible and image_outputs), gr.update(visible= mask_visible) - -def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt): - video_prompt_type = del_in_sequence(video_prompt_type, "UVQKI") - video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt) - control_video_visible = "V" in video_prompt_type + image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ) + # mask_video_input_visible = image_mode == 0 and mask_visible + mask_preprocessing = model_def.get("mask_preprocessing", None) + if mask_preprocessing is not None: + mask_selector_visible = mask_preprocessing.get("visible", True) + else: + mask_selector_visible = True ref_images_visible = "I" in video_prompt_type - return video_prompt_type, gr.update(visible = control_video_visible), gr.update(visible = ref_images_visible ) - -# def refresh_video_prompt_video_guide_trigger(state, video_prompt_type, video_prompt_type_video_guide): -# video_prompt_type_video_guide = video_prompt_type_video_guide.split("#")[0] -# return refresh_video_prompt_type_video_guide(state, video_prompt_type, video_prompt_type_video_guide) + custom_options = custom_checkbox = False + custom_video_selection = model_def.get("custom_video_selection", None) + if custom_video_selection is not None: + custom_trigger = custom_video_selection.get("trigger","") + if len(custom_trigger) == 0 or custom_trigger in video_prompt_type: + custom_options = True + custom_checkbox = custom_video_selection.get("type","") == "checkbox" + + return video_prompt_type, gr.update(visible = visible and not image_outputs), image_guide, gr.update(visible = keep_frames_video_guide_visible), gr.update(visible = visible and "G" in video_prompt_type), gr.update(visible= (visible or "F" in video_prompt_type or "K" in video_prompt_type) and any_outpainting), gr.update(visible= visible and mask_selector_visible and not "U" in video_prompt_type ) , gr.update(visible= mask_visible and not image_outputs), image_mask, image_mask_guide, gr.update(visible= mask_visible), gr.update(visible = ref_images_visible ), gr.update(visible= custom_options and not custom_checkbox ), gr.update(visible= custom_options and custom_checkbox ) + +def refresh_video_prompt_type_video_custom_dropbox(state, video_prompt_type, video_prompt_type_video_custom_dropbox): + model_type = state["model_type"] + model_def = get_model_def(model_type) + custom_video_selection = model_def.get("custom_video_selection", None) + if custom_video_selection is None: return gr.update() + letters_filter = custom_video_selection.get("letters_filter", "") + video_prompt_type = del_in_sequence(video_prompt_type, letters_filter) + 
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_custom_dropbox) + return video_prompt_type + +def refresh_video_prompt_type_video_custom_checkbox(state, video_prompt_type, video_prompt_type_video_custom_checkbox): + model_type = state["model_type"] + model_def = get_model_def(model_type) + custom_video_selection = model_def.get("custom_video_selection", None) + if custom_video_selection is None: return gr.update() + letters_filter = custom_video_selection.get("letters_filter", "") + video_prompt_type = del_in_sequence(video_prompt_type, letters_filter) + if video_prompt_type_video_custom_checkbox: + video_prompt_type = add_to_sequence(video_prompt_type, custom_video_selection["choices"][1][1]) + return video_prompt_type + + + +# def refresh_video_prompt_type_video_guide_alt(state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ): +# old_video_prompt_type = video_prompt_type +# model_def = get_model_def(state["model_type"]) +# guide_custom_choices = model_def.get("guide_custom_choices",{}) +# video_prompt_type = del_in_sequence(video_prompt_type, guide_custom_choices.get("letters_filter","")) +# video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide_alt) +# image_outputs = image_mode > 0 +# control_video_visible = "V" in video_prompt_type +# ref_images_visible = "I" in video_prompt_type +# denoising_strength_visible = "G" in video_prompt_type +# mask_expand_visible = control_video_visible and "A" in video_prompt_type and not "U" in video_prompt_type +# mask_video_input_visible = image_mode == 0 and mask_expand_visible +# image_mask_guide, image_guide, image_mask = switch_image_guide_editor(image_mode, old_video_prompt_type , video_prompt_type, old_image_mask_guide_value, old_image_guide_value, old_image_mask_value ) +# keep_frames_video_guide_visible = not image_outputs and visible and not model_def.get("keep_frames_video_guide_not_supported", False) + +# return video_prompt_type, gr.update(visible = control_video_visible and image_mode ==0), gr.update(visible = control_video_visible and image_mode >=1), gr.update(visible = ref_images_visible ), gr.update(visible = denoising_strength_visible ), gr.update(visible = mask_video_input_visible ), gr.update(visible = mask_expand_visible), image_mask_guide, image_guide, image_mask, gr.update(visible = keep_frames_video_guide_visible) def refresh_preview(state): gen = get_gen_info(state) @@ -6645,9 +7137,12 @@ def get_prompt_labels(multi_prompts_gen_type, image_outputs = False): new_line_text = "each new line of prompt will be used for a window" if multi_prompts_gen_type != 0 else "each new line of prompt will generate " + ("a new image" if image_outputs else "a new video") return "Prompts (" + new_line_text + ", # lines = comments, ! 
lines = macros)", "Prompts (" + new_line_text + ", # lines = comments)" +def get_image_end_label(multi_prompts_gen_type): + return "Images as ending points for new Videos in the Generation Queue" if multi_prompts_gen_type == 0 else "Images as ending points for each new Window of the same Video Generation" + def refresh_prompt_labels(multi_prompts_gen_type, image_mode): - prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type, image_mode == 1) - return gr.update(label=prompt_label), gr.update(label = wizard_prompt_label) + prompt_label, wizard_prompt_label = get_prompt_labels(multi_prompts_gen_type, image_mode > 0) + return gr.update(label=prompt_label), gr.update(label = wizard_prompt_label), gr.update(label=get_image_end_label(multi_prompts_gen_type)) def show_preview_column_modal(state, column_no): column_no = int(column_no) @@ -6779,28 +7274,38 @@ def categorize_resolution(resolution_str): return group return "1440p" -def group_resolutions(resolutions, selected_resolution): - - grouped_resolutions = {} - for resolution in resolutions: - group = categorize_resolution(resolution[1]) - if group not in grouped_resolutions: - grouped_resolutions[group] = [] - grouped_resolutions[group].append(resolution) - - available_groups = [group for group in group_thresholds if group in grouped_resolutions] +def group_resolutions(model_def, resolutions, selected_resolution): + + model_resolutions = model_def.get("resolutions", None) + if model_resolutions is not None: + selected_group ="Locked" + available_groups = [selected_group ] + selected_group_resolutions = model_resolutions + else: + grouped_resolutions = {} + for resolution in resolutions: + group = categorize_resolution(resolution[1]) + if group not in grouped_resolutions: + grouped_resolutions[group] = [] + grouped_resolutions[group].append(resolution) + + available_groups = [group for group in group_thresholds if group in grouped_resolutions] - selected_group = categorize_resolution(selected_resolution) - selected_group_resolutions = grouped_resolutions.get(selected_group, []) - available_groups.reverse() + selected_group = categorize_resolution(selected_resolution) + selected_group_resolutions = grouped_resolutions.get(selected_group, []) + available_groups.reverse() return available_groups, selected_group_resolutions, selected_group def change_resolution_group(state, selected_group): model_type = state["model_type"] model_def = get_model_def(model_type) model_resolutions = model_def.get("resolutions", None) - resolution_choices, _ = get_resolution_choices(None, model_resolutions) - group_resolution_choices = [ resolution for resolution in resolution_choices if categorize_resolution(resolution[1]) == selected_group ] + resolution_choices, _ = get_resolution_choices(None, model_resolutions) + if model_resolutions is None: + group_resolution_choices = [ resolution for resolution in resolution_choices if categorize_resolution(resolution[1]) == selected_group ] + else: + last_resolution = group_resolution_choices[0][1] + return gr.update(choices= group_resolution_choices, value= last_resolution) last_resolution_per_group = state["last_resolution_per_group"] last_resolution = last_resolution_per_group.get(selected_group, "") @@ -6811,6 +7316,11 @@ def change_resolution_group(state, selected_group): def record_last_resolution(state, resolution): + + model_type = state["model_type"] + model_def = get_model_def(model_type) + model_resolutions = model_def.get("resolutions", None) + if model_resolutions is not None: return 
server_config["last_resolution_choice"] = resolution selected_group = categorize_resolution(resolution) last_resolution_per_group = state["last_resolution_per_group"] @@ -6846,11 +7356,21 @@ def detect_auto_save_form(state, evt:gr.SelectData): return gr.update() def compute_video_length_label(fps, current_video_length): - return f"Number of frames ({fps} frames = 1s), current duration: {(current_video_length / fps):.1f}s", + if fps is None: + return f"Number of frames" + else: + return f"Number of frames ({fps} frames = 1s), current duration: {(current_video_length / fps):.1f}s", + +def refresh_video_length_label(state, current_video_length, force_fps, video_guide, video_source): + base_model_type = get_base_model_type(state["model_type"]) + computed_fps = get_computed_fps(force_fps, base_model_type , video_guide, video_source ) + return gr.update(label= compute_video_length_label(computed_fps, current_video_length)) -def refresh_video_length_label(state, current_video_length): - fps = get_model_fps(get_base_model_type(state["model_type"])) - return gr.update(label= compute_video_length_label(fps, current_video_length)) +def get_default_value(choices, current_value, default_value = None): + for label, value in choices: + if value == current_value: + return current_value + return default_value def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_family = None, model_choice = None, header = None, main = None, main_tabs= None): global inputs_names #, advanced @@ -6963,7 +7483,6 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non diffusion_forcing = "diffusion_forcing" in model_filename ltxv = "ltxv" in model_filename lock_inference_steps = model_def.get("lock_inference_steps", False) - model_reference_image = model_def.get("reference_image", False) any_tea_cache = model_def.get("tea_cache", False) any_mag_cache = model_def.get("mag_cache", False) recammaster = base_model_type in ["recam_1.3B"] @@ -6971,13 +7490,9 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non phantom = base_model_type in ["phantom_1.3B", "phantom_14B"] fantasy = base_model_type in ["fantasy"] multitalk = model_def.get("multitalk_class", False) - standin = model_def.get("standin_class", False) infinitetalk = base_model_type in ["infinitetalk"] hunyuan_t2v = "hunyuan_video_720" in model_filename hunyuan_i2v = "hunyuan_video_i2v" in model_filename - hunyuan_video_custom = "hunyuan_video_custom" in model_filename - hunyuan_video_custom = base_model_type in ["hunyuan_custom", "hunyuan_custom_audio", "hunyuan_custom_edit"] - hunyuan_video_custom_audio = base_model_type in ["hunyuan_custom_audio"] hunyuan_video_custom_edit = base_model_type in ["hunyuan_custom_edit"] hunyuan_video_avatar = "hunyuan_video_avatar" in model_filename flux = base_model_family in ["flux"] @@ -6990,321 +7505,267 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non fps = get_model_fps(base_model_type) image_prompt_type_value = "" video_prompt_type_value = "" - any_start_image = False - any_end_image = False - any_reference_image = False - v2i_switch_supported = (vace or t2v or standin) and not image_outputs + any_start_image = any_end_image = any_reference_image = any_image_mask = False + v2i_switch_supported = model_def.get("v2i_switch_supported", False) and not image_outputs ti2v_2_2 = base_model_type in ["ti2v_2_2"] + gallery_height = 350 + def get_image_gallery(label ="", value = None, single_image_mode = False, visible 
= False ): + with gr.Row(visible = visible) as gallery_row: + gallery_amg = AdvancedMediaGallery(media_mode="image", height=gallery_height, columns=4, label=label, initial = value , single_image_mode = single_image_mode ) + gallery_amg.mount(update_form=update_form) + return gallery_row, gallery_amg.gallery, [gallery_row] + gallery_amg.get_toggable_elements() image_mode_value = ui_defaults.get("image_mode", 1 if image_outputs else 0 ) if not v2i_switch_supported and not image_outputs: image_mode_value = 0 else: - image_outputs = image_mode_value == 1 + image_outputs = image_mode_value > 0 + inpaint_support = model_def.get("inpaint_support", False) image_mode = gr.Number(value =image_mode_value, visible = False) - - with gr.Tabs(visible = v2i_switch_supported, selected= "t2i" if image_mode_value == 1 else "t2v" ) as image_mode_tabs: - with gr.Tab("Text to Video", id = "t2v", elem_classes="compact_tab"): + image_mode_tab_selected= "t2i" if image_mode_value == 1 else ("inpaint" if image_mode_value == 2 else "t2v") + with gr.Tabs(visible = v2i_switch_supported or inpaint_support, selected= image_mode_tab_selected ) as image_mode_tabs: + with gr.Tab("Text to Video", id = "t2v", elem_classes="compact_tab", visible = v2i_switch_supported) as tab_t2v: pass with gr.Tab("Text to Image", id = "t2i", elem_classes="compact_tab"): pass + with gr.Tab("Image Inpainting", id = "inpaint", elem_classes="compact_tab", visible=inpaint_support) as tab_inpaint: + pass - - with gr.Column(visible= test_class_i2v(model_type) or hunyuan_i2v or diffusion_forcing or ltxv or recammaster or vace or ti2v_2_2) as image_prompt_column: - if vace or infinitetalk: - image_prompt_type_value= ui_defaults.get("image_prompt_type","") - image_prompt_type_value = "" if image_prompt_type_value == "S" else image_prompt_type_value - image_prompt_type = gr.Radio( [("New Video", ""),("Continue Video File", "V"),("Continue Last Video", "L")], value =image_prompt_type_value, label="Source Video", show_label= False, visible= not image_outputs , scale= 3) - - image_start = gr.Gallery(visible = False) - image_end = gr.Gallery(visible = False) - video_source = gr.Video(label= "Video Source", visible = "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None)) - model_mode = gr.Dropdown(visible = False) - keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VLG"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) + image_prompt_types_allowed = model_def.get("image_prompt_types_allowed", "") + model_mode_choices = model_def.get("model_modes", None) + model_modes_visibility = [0,1,2] + if model_mode_choices is not None: model_modes_visibility= model_mode_choices.get("image_modes", model_modes_visibility) + + with gr.Column(visible= image_mode_value == 0 and len(image_prompt_types_allowed)> 0 or model_mode_choices is not None and image_mode_value in model_modes_visibility ) as image_prompt_column: + # Video Continue / Start Frame / End Frame + image_prompt_type_value= ui_defaults.get("image_prompt_type","") + image_prompt_type = gr.Text(value= image_prompt_type_value, visible= False) + image_prompt_type_choices = [] + if "T" in image_prompt_types_allowed: + image_prompt_type_choices += [("Text Prompt Only" if "S" in image_prompt_types_allowed else "New Video", "")] + if "S" in image_prompt_types_allowed: + image_prompt_type_choices += [("Start Video with Image", 
"S")] + any_start_image = True + if "V" in image_prompt_types_allowed: any_video_source = True - - elif diffusion_forcing or ltxv or ti2v_2_2: - image_prompt_type_value= ui_defaults.get("image_prompt_type","T") - # image_prompt_type = gr.Radio( [("Start Video with Image", "S"),("Start and End Video with Images", "SE"), ("Continue Video", "V"),("Text Prompt Only", "T")], value =image_prompt_type_value, label="Location", show_label= False, visible= True, scale= 3) - image_prompt_type_choices = [("Text Prompt Only", "T"),("Start Video with Image", "S")] - if ltxv: - image_prompt_type_choices += [("Use both a Start and an End Image", "SE")] - if sliding_window_enabled: - any_video_source = True - image_prompt_type_choices += [("Continue Video", "V")] - image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= True , scale= 3) - - # image_start = gr.Image(label= "Image as a starting point for a new video", type ="pil",value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) - image_start = gr.Gallery(preview= True, - label="Images as starting points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) - image_end = gr.Gallery(preview= True, - label="Images as ending points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) - video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) - if not diffusion_forcing: - model_mode = gr.Dropdown( - choices=[ - ], value=None, - visible= False - ) - else: - model_mode = gr.Dropdown( - choices=[ - ("Synchronous", 0), - ("Asynchronous (better quality but around 50% extra steps added)", 5), - ], - value=ui_defaults.get("model_mode", 0), - label="Generation Type", scale = 3, - visible= True - ) - keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= "V" in image_prompt_type_value, scale = 2, label= "Truncate Video beyond this number of Frames of Video (empty=Keep All)" ) - elif recammaster: - image_prompt_type = gr.Radio(choices=[("Source Video", "V")], value="V") - image_start = gr.Gallery(value = None, visible = False) - image_end = gr.Gallery(value = None, visible= False) - video_source = gr.Video(label= "Video Source", visible = True, value= ui_defaults.get("video_source", None),) - model_mode = gr.Dropdown( - choices=[ - ("Pan Right", 1), - ("Pan Left", 2), - ("Tilt Up", 3), - ("Tilt Down", 4), - ("Zoom In", 5), - ("Zoom Out", 6), - ("Translate Up (with rotation)", 7), - ("Translate Down (with rotation)", 8), - ("Arc Left (with rotation)", 9), - ("Arc Right (with rotation)", 10), - ], - value=ui_defaults.get("model_mode", 1), - label="Camera Movement Type", scale = 3, - visible= True - ) - keep_frames_video_source = gr.Text(visible=False) - else: - if test_class_i2v(model_type) or hunyuan_i2v: - # image_prompt_type_value= ui_defaults.get("image_prompt_type","SE" if flf2v else "S" ) - image_prompt_type_value= ui_defaults.get("image_prompt_type","S" ) - image_prompt_type_choices = [("Start Video with Image", "S")] - image_prompt_type_choices += [("Use both a Start and 
an End Image", "SE")] - if not hunyuan_i2v: - any_video_source = True - image_prompt_type_choices += [("Continue Video", "V")] - - image_prompt_type = gr.Radio( image_prompt_type_choices, value =image_prompt_type_value, label="Location", show_label= False, visible= not hunyuan_i2v, scale= 3) - any_start_image = True - any_end_image = True - image_start = gr.Gallery(preview= True, - label="Images as starting points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, value= ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value) - - image_end = gr.Gallery(preview= True, - label="Images as ending points for new videos", type ="pil", #file_types= "image", - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None)) - if hunyuan_i2v: - video_source = gr.Video(value=None, visible=False) + image_prompt_type_choices += [("Continue Video", "V")] + if "L" in image_prompt_types_allowed: + any_video_source = True + image_prompt_type_choices += [("Continue Last Video", "L")] + with gr.Group(visible= len(image_prompt_types_allowed)>1 and image_mode_value == 0) as image_prompt_type_group: + with gr.Row(): + image_prompt_type_radio_allowed_values= filter_letters(image_prompt_types_allowed, "SVL") + image_prompt_type_radio_value = filter_letters(image_prompt_type_value, image_prompt_type_radio_allowed_values, image_prompt_type_choices[0][1] if len(image_prompt_type_choices) > 0 else "") + if len(image_prompt_type_choices) > 0: + image_prompt_type_radio = gr.Radio( image_prompt_type_choices, value = image_prompt_type_radio_value, label="Location", show_label= False, visible= len(image_prompt_types_allowed)>1, scale= 3) else: - video_source = gr.Video(label= "Video to Continue", visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) - else: - image_prompt_type = gr.Radio(choices=[("", "")], value="") - image_start = gr.Gallery(value=None) - image_end = gr.Gallery(value=None) - video_source = gr.Video(value=None, visible=False) + image_prompt_type_radio = gr.Radio(choices=[("", "")], value="", visible= False) + if "E" in image_prompt_types_allowed: + image_prompt_type_endcheckbox = gr.Checkbox( value ="E" in image_prompt_type_value, label="End Image(s)", show_label= False, visible= any_letters(image_prompt_type_radio_value, "SVL") and not image_outputs , scale= 1, elem_classes="cbx_centered") + any_end_image = True + else: + image_prompt_type_endcheckbox = gr.Checkbox( value =False, show_label= False, visible= False , scale= 1) + image_start_row, image_start, image_start_extra = get_image_gallery(label= "Images as starting points for new Videos in the Generation Queue", value = ui_defaults.get("image_start", None), visible= "S" in image_prompt_type_value ) + video_source = gr.Video(label= "Video to Continue", height = gallery_height, visible= "V" in image_prompt_type_value, value= ui_defaults.get("video_source", None),) + image_end_row, image_end, image_end_extra = get_image_gallery(label= get_image_end_label(ui_defaults.get("multi_prompts_gen_type", 0)), value = ui_defaults.get("image_end", None), visible= any_letters(image_prompt_type_value, "SVL") and ("E" in image_prompt_type_value) ) + if model_mode_choices is None or image_mode_value not in model_modes_visibility: model_mode = gr.Dropdown(value=None, visible=False) - 
keep_frames_video_source = gr.Text(visible=False) + else: + model_mode_value = get_default_value(model_mode_choices["choices"], ui_defaults.get("model_mode", None), model_mode_choices["default"] ) + model_mode = gr.Dropdown(choices=model_mode_choices["choices"], value=model_mode_value, label=model_mode_choices["label"], visible=True) + keep_frames_video_source = gr.Text(value=ui_defaults.get("keep_frames_video_source","") , visible= len(filter_letters(image_prompt_type_value, "VL"))>0 , scale = 2, label= "Truncate Video beyond this number of resampled Frames (empty=Keep All, negative truncates from End)" ) + + any_control_video = any_control_image = False + if image_mode_value ==2: + guide_preprocessing = { "selection": ["V", "VG"]} + mask_preprocessing = { "selection": ["A"]} + else: + guide_preprocessing = model_def.get("guide_preprocessing", None) + mask_preprocessing = model_def.get("mask_preprocessing", None) + guide_custom_choices = model_def.get("guide_custom_choices", None) + image_ref_choices = model_def.get("image_ref_choices", None) - with gr.Column(visible= vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or hunyuan_video_custom_edit or t2v or standin or ltxv or infinitetalk or flux and model_reference_image or qwen and model_reference_image) as video_prompt_column: + with gr.Column(visible= guide_preprocessing is not None or mask_preprocessing is not None or guide_custom_choices is not None or image_ref_choices is not None) as video_prompt_column: video_prompt_type_value= ui_defaults.get("video_prompt_type","") video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False) - any_control_video = True - any_control_image = image_outputs - with gr.Row(): - if t2v: - video_prompt_type_video_guide = gr.Dropdown( - choices=[ - ("Use Text Prompt Only", ""), - ("Image to Image guided by Text Prompt" if image_outputs else "Video to Video guided by Text Prompt", "GUV"), - ], - value=filter_letters(video_prompt_type_value, "GUV"), - label="Video to Video", scale = 2, show_label= False, visible= True - ) - elif vace : + with gr.Row(visible = image_mode_value!=2) as guide_selection_row: + # Control Video Preprocessing + if guide_preprocessing is None: + video_prompt_type_video_guide = gr.Dropdown(choices=[("","")], value="", label="Control Video", scale = 2, visible= False, show_label= True, ) + else: pose_label = "Pose" if image_outputs else "Motion" + guide_preprocessing_labels_all = { + "": "No Control Video", + "UV": "Keep Control Video Unchanged", + "PV": f"Transfer Human {pose_label}", + "DV": "Transfer Depth", + "EV": "Transfer Canny Edges", + "SV": "Transfer Shapes", + "LV": "Transfer Flow", + "CV": "Recolorize", + "MV": "Perform Inpainting", + "V": "Use Vace raw format", + "PDV": f"Transfer Human {pose_label} & Depth", + "PSV": f"Transfer Human {pose_label} & Shapes", + "PLV": f"Transfer Human {pose_label} & Flow" , + "DSV": "Transfer Depth & Shapes", + "DLV": "Transfer Depth & Flow", + "SLV": "Transfer Shapes & Flow", + } + guide_preprocessing_choices = [] + guide_preprocessing_labels = guide_preprocessing.get("labels", {}) + for process_type in guide_preprocessing["selection"]: + process_label = guide_preprocessing_labels.get(process_type, None) + process_label = guide_preprocessing_labels_all.get(process_type,process_type) if process_label is None else process_label + if image_outputs: process_label = process_label.replace("Video", "Image") + guide_preprocessing_choices.append( (process_label, process_type) ) + + video_prompt_type_video_guide_label = 
guide_preprocessing.get("label", "Control Video Process") + if image_outputs: video_prompt_type_video_guide_label = video_prompt_type_video_guide_label.replace("Video", "Image") video_prompt_type_video_guide = gr.Dropdown( - choices=[ - ("No Control Image" if image_outputs else "No Control Video", ""), - ("Keep Control Image Unchanged" if image_outputs else "Keep Control Video Unchanged", "UV"), - (f"Transfer Human {pose_label}" , "PV"), - ("Transfer Depth", "DV"), - ("Transfer Shapes", "SV"), - ("Transfer Flow", "LV"), - ("Recolorize", "CV"), - ("Perform Inpainting", "MV"), - ("Use Vace raw format", "V"), - (f"Transfer Human {pose_label} & Depth", "PDV"), - (f"Transfer Human {pose_label} & Shapes", "PSV"), - (f"Transfer Human {pose_label} & Flow", "PLV"), - ("Transfer Depth & Shapes", "DSV"), - ("Transfer Depth & Flow", "DLV"), - ("Transfer Shapes & Flow", "SLV"), - ], - value=filter_letters(video_prompt_type_value, "PDSLCMGUV"), - label="Control Image Process" if image_outputs else "Control Video Process", scale = 2, visible= True, show_label= True, - ) - elif ltxv: - video_prompt_type_video_guide = gr.Dropdown( - choices=[ - ("No Control Video", ""), - ("Transfer Human Motion", "PV"), - ("Transfer Depth", "DV"), - ("Transfer Canny Edges", "EV"), - ("Use LTXV raw format", "V"), - ], - value=filter_letters(video_prompt_type_value, "PDEV"), - label="Control Video Process", scale = 2, visible= True, show_label= True, + guide_preprocessing_choices, + value=filter_letters(video_prompt_type_value, all_guide_processes, guide_preprocessing.get("default", "") ), + label= video_prompt_type_video_guide_label , scale = 1, visible= guide_preprocessing.get("visible", True) , show_label= True, ) + any_control_video = True + any_control_image = image_outputs - elif hunyuan_video_custom_edit: - video_prompt_type_video_guide = gr.Dropdown( - choices=[ - ("Inpaint Control Image" if image_outputs else "Inpaint Control Video", "MV"), - ("Transfer Human Motion", "PMV"), - ], - value=filter_letters(video_prompt_type_value, "PDSLCMUV"), - label="Image to Image" if image_outputs else "Video to Video", scale = 3, visible= True, show_label= True, - ) - elif infinitetalk: - video_prompt_type_video_guide = gr.Dropdown(value="", choices = [("","")], visible=False) + # Alternate Control Video Preprocessing / Options + if guide_custom_choices is None: + video_prompt_type_video_guide_alt = gr.Dropdown(choices=[("","")], value="", label="Control Video", visible= False, scale = 1 ) else: - any_control_video = False - any_control_image = False - video_prompt_type_video_guide = gr.Dropdown(visible= False) - - if infinitetalk: + video_prompt_type_video_guide_alt_label = guide_custom_choices.get("label", "Control Video Process") + if image_outputs: video_prompt_type_video_guide_alt_label = video_prompt_type_video_guide_alt_label.replace("Video", "Image") + video_prompt_type_video_guide_alt_choices = [(label.replace("Video", "Image") if image_outputs else label, value) for label,value in guide_custom_choices["choices"] ] + guide_custom_choices_value = get_default_value(video_prompt_type_video_guide_alt_choices, filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"]), guide_custom_choices.get("default", "") ) video_prompt_type_video_guide_alt = gr.Dropdown( - choices=[ - ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Smooth Transitions", "UV"), - ("Sparse Video to Video, one Image will by extracted from Video for each new Sliding Window - Sharp Transitions", 
"QUV"), - ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Smooth Transitions", "KI"), - ("Images to Video, each Reference Image will start a new shot with a new Sliding Window - Sharp Transitions", "QKI"), - ], - value=filter_letters(video_prompt_type_value, "UVQKI"), - label="Video to Video", scale = 3, visible= True, show_label= False, - ) + choices= video_prompt_type_video_guide_alt_choices, + # value=filter_letters(video_prompt_type_value, guide_custom_choices["letters_filter"], guide_custom_choices.get("default", "") ), + value=guide_custom_choices_value, + visible = guide_custom_choices.get("visible", True), + label= video_prompt_type_video_guide_alt_label, show_label= guide_custom_choices.get("show_label", True), scale = guide_custom_choices.get("scale", 1), + ) + any_control_video = True + any_control_image = image_outputs + + # Custom dropdown box & checkbox + custom_video_selection = model_def.get("custom_video_selection", None) + custom_checkbox= False + if custom_video_selection is None: + video_prompt_type_video_custom_dropbox = gr.Dropdown(choices=[("","")], value="", label="Custom Dropdown", scale = 1, visible= False, show_label= True, ) + video_prompt_type_video_custom_checkbox = gr.Checkbox(value=False, label="Custom Checkbbox", scale = 1, visible= False, show_label= True, ) + else: - video_prompt_type_video_guide_alt = gr.Dropdown(value="", choices = [("","")], visible=False) + custom_video_choices = custom_video_selection["choices"] + custom_video_trigger = custom_video_selection.get("trigger", "") + custom_choices = len(custom_video_trigger) == 0 or custom_video_trigger in video_prompt_type_value + custom_checkbox = custom_video_selection.get("type","") == "checkbox" + + video_prompt_type_video_custom_label = custom_video_selection.get("label", "Custom Choices") + video_prompt_type_video_custom_dropbox = gr.Dropdown( + custom_video_choices, + value=filter_letters(video_prompt_type_value, custom_video_selection.get("letters_filter", ""), custom_video_selection.get("default", "")), + scale = custom_video_selection.get("scale", 1), + label= video_prompt_type_video_custom_label , visible= not custom_checkbox and custom_choices, + show_label= custom_video_selection.get("show_label", True), + ) + video_prompt_type_video_custom_checkbox = gr.Checkbox(value= custom_video_choices[1][1] in video_prompt_type_value , label=custom_video_choices[1][0] , scale = custom_video_selection.get("scale", 1), visible=custom_checkbox and custom_choices, show_label= True, elem_classes="cbx_centered" ) - # video_prompt_video_guide_trigger = gr.Text(visible=False, value="") - if t2v: - video_prompt_type_video_mask = gr.Dropdown(value = "", choices = [""], visible = False) - elif hunyuan_video_custom_edit: + # Control Mask Preprocessing + if mask_preprocessing is None: + video_prompt_type_video_mask = gr.Dropdown(choices=[("","")], value="", label="Video Mask", scale = 1, visible= False, show_label= True, ) + any_image_mask = image_outputs + else: + mask_preprocessing_labels_all = { + "": "Whole Frame", + "A": "Masked Area", + "NA": "Non Masked Area", + "XA": "Masked Area, rest Inpainted", + "XNA": "Non Masked Area, rest Inpainted", + "YA": "Masked Area, rest Depth", + "YNA": "Non Masked Area, rest Depth", + "WA": "Masked Area, rest Shapes", + "WNA": "Non Masked Area, rest Shapes", + "ZA": "Masked Area, rest Flow", + "ZNA": "Non Masked Area, rest Flow" + } + + mask_preprocessing_choices = [] + mask_preprocessing_labels = mask_preprocessing.get("labels", 
{}) + for process_type in mask_preprocessing["selection"]: + process_label = mask_preprocessing_labels.get(process_type, None) + process_label = mask_preprocessing_labels_all.get(process_type, process_type) if process_label is None else process_label + mask_preprocessing_choices.append( (process_label, process_type) ) + + video_prompt_type_video_mask_label = mask_preprocessing.get("label", "Area Processed") video_prompt_type_video_mask = gr.Dropdown( - choices=[ - ("Masked Area", "A"), - ("Non Masked Area", "NA"), - ], - value= filter_letters(video_prompt_type_value, "NA"), - visible= "V" in video_prompt_type_value, - label="Area Processed", scale = 2, show_label= True, + mask_preprocessing_choices, + value=filter_letters(video_prompt_type_value, "XYZWNA", mask_preprocessing.get("default", "")), + label= video_prompt_type_video_mask_label , scale = 1, visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and mask_preprocessing.get("visible", True), + show_label= True, ) - elif ltxv: - video_prompt_type_video_mask = gr.Dropdown( - choices=[ - ("Whole Frame", ""), - ("Masked Area", "A"), - ("Non Masked Area", "NA"), - ("Masked Area, rest Inpainted", "XA"), - ("Non Masked Area, rest Inpainted", "XNA"), - ], - value= filter_letters(video_prompt_type_value, "XNA"), - visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value, - label="Area Processed", scale = 2, show_label= True, + + + # Image Refs Selection + if image_ref_choices is None: + video_prompt_type_image_refs = gr.Dropdown( + # choices=[ ("None", ""),("Start", "KI"),("Ref Image", "I")], + choices=[ ("None", ""),], + value=filter_letters(video_prompt_type_value, ""), + visible = False, + label="Start / Reference Images", scale = 1 ) + any_reference_image = False else: - video_prompt_type_video_mask = gr.Dropdown( - choices=[ - ("Whole Frame", ""), - ("Masked Area", "A"), - ("Non Masked Area", "NA"), - ("Masked Area, rest Inpainted", "XA"), - ("Non Masked Area, rest Inpainted", "XNA"), - ("Masked Area, rest Depth", "YA"), - ("Non Masked Area, rest Depth", "YNA"), - ("Masked Area, rest Shapes", "WA"), - ("Non Masked Area, rest Shapes", "WNA"), - ("Masked Area, rest Flow", "ZA"), - ("Non Masked Area, rest Flow", "ZNA"), - ], - value= filter_letters(video_prompt_type_value, "XYZWNA"), - visible= "V" in video_prompt_type_value and not "U" in video_prompt_type_value and not hunyuan_video_custom and not ltxv, - label="Area Processed", scale = 2, show_label= True, - ) - image_ref_choices = model_def.get("image_ref_choices", None) - if image_ref_choices is not None: + any_reference_image = True video_prompt_type_image_refs = gr.Dropdown( choices= image_ref_choices["choices"], value=filter_letters(video_prompt_type_value, image_ref_choices["letters_filter"]), - visible = True, - label=image_ref_choices["label"], show_label= True, scale = 2 - ) - elif t2v: - video_prompt_type_image_refs = gr.Dropdown(value="", label="Ref Image", choices=[""], visible =False) - elif vace: - video_prompt_type_image_refs = gr.Dropdown( - choices=[ - ("None", ""), - ("Inject only People / Objects", "I"), - ("Inject Landscape and then People / Objects", "KI"), - ("Inject Frames and then People / Objects", "FI"), - ], - value=filter_letters(video_prompt_type_value, "KFI"), - visible = True, - label="Reference Images", show_label= True, scale = 2 - ) - elif standin: # and not vace - video_prompt_type_image_refs = gr.Dropdown( - choices=[ - ("No Reference Image", ""), - ("Reference Image is a Person Face", "I"), - ], - 
value=filter_letters(video_prompt_type_value, "I"), - visible = True, - show_label=False, - label="Reference Image", scale = 2 + visible = image_ref_choices.get("visible", True), + label=image_ref_choices.get("label", "Inject Reference Images"), show_label= True, scale = 1 ) - elif (flux or qwen) and model_reference_image: - video_prompt_type_image_refs = gr.Dropdown( - choices=[ - ("None", ""), - ("Conditional Images is first Main Subject / Landscape and may be followed by People / Objects", "KI"), - ("Conditional Images are People / Objects", "I"), - ], - value=filter_letters(video_prompt_type_value, "KI"), - visible = True, - show_label=False, - label="Reference Images Combination Method", scale = 2 - ) + image_guide = gr.Image(label= "Control Image", height = 800, type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and ("U" in video_prompt_type_value or not "A" in video_prompt_type_value ) , value= ui_defaults.get("image_guide", None)) + video_guide = gr.Video(label= "Control Video", height = gallery_height, visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None)) + if image_mode_value >= 1: + image_guide_value = ui_defaults.get("image_guide", None) + image_mask_value = ui_defaults.get("image_mask", None) + if image_guide_value is None: + image_mask_guide_value = None else: - video_prompt_type_image_refs = gr.Dropdown( - choices=[ ("None", ""),("Start", "KI"),("Ref Image", "I")], - value=filter_letters(video_prompt_type_value, "KI"), - visible = False, - label="Start / Reference Images", scale = 2 - ) - image_guide = gr.Image(label= "Control Image", type ="pil", visible= image_outputs and "V" in video_prompt_type_value, value= ui_defaults.get("image_guide", None)) - video_guide = gr.Video(label= "Control Video", visible= (not image_outputs) and "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None)) + image_mask_guide_value = { "background" : image_guide_value, "composite" : None} + image_mask_guide_value["layers"] = [] if image_mask_value is None else [rgb_bw_to_rgba_mask(image_mask_value)] + + image_mask_guide = gr.ImageEditor( + label="Control Image to be Inpainted" if image_mode_value == 2 else "Control Image and Mask", + value = image_mask_guide_value, + type='pil', + sources=["upload", "webcam"], + image_mode='RGB', + layers=False, + brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed"), + # fixed_canvas= True, + # width=800, + height=800, + # transforms=None, + # interactive=True, + elem_id="img_editor", + visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value + ) + any_control_image = True + else: + image_mask_guide = gr.ImageEditor(value = None, visible = False, elem_id="img_editor") - denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label="Denoising Strength (the Lower the Closer to the Control Video)", visible = "G" in video_prompt_type_value, show_reset_button= False) + + denoising_strength = gr.Slider(0, 1, value= ui_defaults.get("denoising_strength" ,0.5), step=0.01, label=f"Denoising Strength (the Lower the Closer to the Control {'Image' if image_outputs else 'Video'})", visible = "G" in video_prompt_type_value, show_reset_button= False) keep_frames_video_guide_visible = not image_outputs and "V" in video_prompt_type_value and not model_def.get("keep_frames_video_guide_not_supported", False) keep_frames_video_guide = 
gr.Text(value=ui_defaults.get("keep_frames_video_guide","") , visible= keep_frames_video_guide_visible , scale = 2, label= "Frames to keep in Control Video (empty=All, 1=first, a:b for a range, space to separate values)" ) #, -1=last - - with gr.Column(visible= ("V" in video_prompt_type_value or "K" in video_prompt_type_value or "F" in video_prompt_type_value) and vace) as video_guide_outpainting_col: + video_guide_outpainting_modes = model_def.get("video_guide_outpainting", []) + with gr.Column(visible= ("V" in video_prompt_type_value or "K" in video_prompt_type_value or "F" in video_prompt_type_value) and image_mode_value in video_guide_outpainting_modes) as video_guide_outpainting_col: video_guide_outpainting_value = ui_defaults.get("video_guide_outpainting","#") video_guide_outpainting = gr.Text(value=video_guide_outpainting_value , visible= False) with gr.Group(): - video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Injected Reference Frames", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) + video_guide_outpainting_checkbox = gr.Checkbox(label="Enable Spatial Outpainting on Control Video, Landscape or Positioned Reference Frames" if image_mode_value == 0 else "Enable Spatial Outpainting on Control Image", value=len(video_guide_outpainting_value)>0 and not video_guide_outpainting_value.startswith("#") ) with gr.Row(visible = not video_guide_outpainting_value.startswith("#")) as video_guide_outpainting_row: video_guide_outpainting_value = video_guide_outpainting_value[1:] if video_guide_outpainting_value.startswith("#") else video_guide_outpainting_value video_guide_outpainting_list = [0] * 4 if len(video_guide_outpainting_value) == 0 else [int(v) for v in video_guide_outpainting_value.split(" ")] @@ -7312,29 +7773,29 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non video_guide_outpainting_bottom = gr.Slider(0, 100, value= video_guide_outpainting_list[1], step=5, label="Bottom %", show_reset_button= False) video_guide_outpainting_left = gr.Slider(0, 100, value= video_guide_outpainting_list[2], step=5, label="Left %", show_reset_button= False) video_guide_outpainting_right = gr.Slider(0, 100, value= video_guide_outpainting_list[3], step=5, label="Right %", show_reset_button= False) - any_image_mask = image_outputs and vace - image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_outputs and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("image_mask", None)) - video_mask = gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , value= ui_defaults.get("video_mask", None)) + # image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= image_mode_value==1 and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("image_mask", None)) + image_mask = gr.Image(label= "Image Mask Area (for Inpainting, white = Control Area, black = Unchanged)", type ="pil", visible= False, height = gallery_height, value= ui_defaults.get("image_mask", None)) + video_mask 
= gr.Video(label= "Video Mask Area (for Inpainting, white = Control Area, black = Unchanged)", visible= (not image_outputs) and "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value , height = gallery_height, value= ui_defaults.get("video_mask", None)) mask_expand = gr.Slider(-10, 50, value=ui_defaults.get("mask_expand", 0), step=1, label="Expand / Shrink Mask Area", visible= "V" in video_prompt_type_value and "A" in video_prompt_type_value and not "U" in video_prompt_type_value ) - any_reference_image = vace or phantom or hunyuan_video_custom or hunyuan_video_avatar or infinitetalk or (flux or qwen) and model_reference_image - image_refs = gr.Gallery(preview= True, label ="Start Image" if hunyuan_video_avatar else "Reference Images" + (" (each Image will start a new Clip)" if infinitetalk else ""), - type ="pil", show_label= True, - columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, - value= ui_defaults.get("image_refs", None), - ) - frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames separated by Spaces (1=first, no position for Objects / People)" ) + image_refs_single_image_mode = model_def.get("one_image_ref_needed", False) + image_refs_label = "Start Image" if hunyuan_video_avatar else ("Reference Image" if image_refs_single_image_mode else "Reference Images") + (" (each Image will be associated to a Sliding Window)" if infinitetalk else "") + image_refs_row, image_refs, image_refs_extra = get_image_gallery(label= image_refs_label, value = ui_defaults.get("image_refs", None), visible= "I" in video_prompt_type_value, single_image_mode=image_refs_single_image_mode) + + frames_positions = gr.Text(value=ui_defaults.get("frames_positions","") , visible= "F" in video_prompt_type_value, scale = 2, label= "Positions of Injected Frames (1=first, L=last of a window) no position for other Image Refs)" ) image_refs_relative_size = gr.Slider(20, 100, value=ui_defaults.get("image_refs_relative_size", 50), step=1, label="Rescale Internaly Image Ref (% in relation to Output Video) to change Output Composition", visible = model_def.get("any_image_refs_relative_size", False) and image_outputs) - no_background_removal = model_def.get("no_background_removal", False) + no_background_removal = model_def.get("no_background_removal", False) or image_ref_choices is None + background_removal_label = model_def.get("background_removal_label", "Remove Background behind People / Objects") + remove_background_images_ref = gr.Dropdown( choices=[ ("Keep Backgrounds behind all Reference Images", 0), - ("Remove Backgrounds only behind People / Objects except main Subject / Landscape" if (flux or qwen) else ("Remove Backgrounds behind People / Objects, keep it for Landscape or positioned Frames" if vace else "Remove Backgrounds behind People / Objects") , 1), + (background_removal_label, 1), ], value=0 if no_background_removal else ui_defaults.get("remove_background_images_ref",1), - label="Automatic Removal of Background of People or Objects (Only)", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal + label="Automatic Removal of Background behind People or Objects in Reference Images", scale = 3, visible= "I" in video_prompt_type_value and not no_background_removal ) any_audio_voices_support = any_audio_track(base_model_type) @@ -7349,7 +7810,7 
@@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non speaker_choices=[("None", "")] if any_single_speaker: speaker_choices += [("One Person Speaking Only", "A")] if any_multi_speakers:speaker_choices += [ - ("Two speakers, Auto Separation of Speakers (will work only if there is little background noise)", "XA"), + ("Two speakers, Auto Separation of Speakers (will work only if Voices are distinct)", "XA"), ("Two speakers, Speakers Audio sources are assumed to be played in a Row", "CAB"), ("Two speakers, Speakers Audio sources are assumed to be played in Parallel", "PAB") ] @@ -7364,6 +7825,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non with gr.Row(visible = any_audio_voices_support and not image_outputs) as audio_guide_row: audio_guide = gr.Audio(value= ui_defaults.get("audio_guide", None), type="filepath", label="Voice to follow", show_download_button= True, visible= any_audio_voices_support and "A" in audio_prompt_type_value ) audio_guide2 = gr.Audio(value= ui_defaults.get("audio_guide2", None), type="filepath", label="Voice to follow #2", show_download_button= True, visible= any_audio_voices_support and "B" in audio_prompt_type_value ) + remove_background_sound = gr.Checkbox(label="Video Motion ignores Background Music (to get a better LipSync)", value="V" in audio_prompt_type_value, visible = any_audio_voices_support and any_letters(audio_prompt_type_value, "ABX") and not image_outputs) with gr.Row(visible = any_audio_voices_support and ("B" in audio_prompt_type_value or "X" in audio_prompt_type_value) and not image_outputs ) as speakers_locations_row: speakers_locations = gr.Text( ui_defaults.get("speakers_locations", "0:45 55:100"), label="Speakers Locations separated by a Space. 
Each Location = Left:Right or a BBox Left:Top:Right:Bottom", visible= True) @@ -7413,14 +7875,17 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non visible= True, show_label= not on_demand_prompt_enhancer, ) with gr.Row(): - if server_config.get("fit_canvas", 0) == 1: - label = "Max Resolution (As it maybe less depending on video width / height ratio)" + fit_canvas = server_config.get("fit_canvas", 0) + if fit_canvas == 1: + label = "Outer Box Resolution (one dimension may be less to preserve video W/H ratio)" + elif fit_canvas == 2: + label = "Output Resolution (Input Images will be Cropped if the W/H ratio is different)" else: - label = "Max Resolution (Pixels will be reallocated depending on the output width / height ratio)" + label = "Resolution Budget (Pixels will be reallocated to preserve Inputs W/H ratio)" current_resolution_choice = ui_defaults.get("resolution","832x480") if update_form or last_resolution is None else last_resolution model_resolutions = model_def.get("resolutions", None) resolution_choices, current_resolution_choice = get_resolution_choices(current_resolution_choice, model_resolutions) - available_groups, selected_group_resolutions, selected_group = group_resolutions(resolution_choices, current_resolution_choice) + available_groups, selected_group_resolutions, selected_group = group_resolutions(model_def,resolution_choices, current_resolution_choice) resolution_group = gr.Dropdown( choices = available_groups, value= selected_group, @@ -7439,12 +7904,13 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non elif recammaster: video_length = gr.Slider(5, 193, value=ui_defaults.get("video_length", get_max_frames(81)), step=4, label="Number of frames (16 = 1s), locked", interactive= False, visible = True) else: - min_frames, frames_step = get_model_min_frames_and_step(base_model_type) + min_frames, frames_step, _ = get_model_min_frames_and_step(base_model_type) current_video_length = ui_defaults.get("video_length", 81 if get_model_family(base_model_type)=="wan" else 97) + computed_fps = get_computed_fps(ui_defaults.get("force_fps",""), base_model_type , ui_defaults.get("video_guide", None), ui_defaults.get("video_source", None)) video_length = gr.Slider(min_frames, get_max_frames(737 if test_any_sliding_window(base_model_type) else 337), value=current_video_length, - step=frames_step, label=compute_video_length_label(fps, current_video_length) , visible = True, interactive= True) + step=frames_step, label=compute_video_length_label(computed_fps, current_video_length) , visible = True, interactive= True) with gr.Row(visible = not lock_inference_steps) as inference_steps_row: num_inference_steps = gr.Slider(1, 100, value=ui_defaults.get("num_inference_steps",30), step=1, label="Number of Inference Steps", visible = True) @@ -7502,10 +7968,11 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non choices= sample_solver_choices, visible= True, label= "Sampler Solver / Scheduler" ) flow_shift = gr.Slider(1.0, 25.0, value=ui_defaults.get("flow_shift",3), step=0.1, label="Shift Scale", visible = not image_outputs) - - with gr.Row(visible = vace) as control_net_weights_row: - control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Control Net Weight #1", visible=vace) - control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Control Net Weight #2", visible=vace) +
control_net_weight_alt_name = model_def.get("control_net_weight_alt_name", "") + with gr.Row(visible = vace or len(control_net_weight_alt_name) >0 ) as control_net_weights_row: + control_net_weight = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight",1), step=0.1, label="Vace Weight #1", visible=vace) + control_net_weight2 = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label="Vace Weight #2", visible=vace) + control_net_weight_alt = gr.Slider(0.0, 2.0, value=ui_defaults.get("control_net_weight2",1), step=0.1, label=control_net_weight_alt_name + " Weight", visible=len(control_net_weight_alt_name) >0) negative_prompt = gr.Textbox(label="Negative Prompt (ignored if no Guidance that is if CFG = 1)", value=ui_defaults.get("negative_prompt", ""), visible = not (hunyuan_t2v or hunyuan_i2v or no_negative_prompt) ) with gr.Column(visible = vace or t2v or test_class_i2v(model_type)) as NAG_col: gr.Markdown("NAG enforces Negative Prompt even if no Guidance is set (CFG = 1), set NAG Scale to > 1 to enable it") @@ -7616,8 +8083,8 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai MMAudio_neg_prompt = gr.Text(ui_defaults.get("MMAudio_neg_prompt", ""), label="Negative Prompt (1 or 2 keywords)") - with gr.Column(visible = (t2v or vace) and not fantasy) as audio_prompt_type_remux_row: - gr.Markdown("You may transfer the exising audio tracks of a Control Video") + with gr.Column(visible = any_control_video) as audio_prompt_type_remux_row: + gr.Markdown("You may transfer the existing audio tracks of a Control Video") audio_prompt_type_remux = gr.Dropdown( choices=[ ("No Remux", ""), @@ -7692,7 +8159,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai with gr.Row(): cfg_zero_step = gr.Slider(-1, 39, value=ui_defaults.get("cfg_zero_step",-1), step=1, label="CFG Zero below this Layer (Extra Process)", visible = any_cfg_zero) - with gr.Column(visible = (vace or t2v or standin) and image_outputs) as min_frames_if_references_col: + with gr.Column(visible = v2i_switch_supported and image_outputs) as min_frames_if_references_col: gr.Markdown("Generating a single Frame alone may not be sufficient to preserve Reference Image Identity / Control Image Information or simply to get a good Image Quality. 
A workaround is to generate a short Video and keep the First Frame.") min_frames_if_references = gr.Dropdown( choices=[ @@ -7706,7 +8173,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai ("Generate always a 13 Frames long Video (x2.5 slower)",1013), ("Generate always a 17 Frames long Video (x3.0 slower)",1017), ], - value=ui_defaults.get("min_frames_if_references",5 if vace else 1), + value=ui_defaults.get("min_frames_if_references",9 if vace else 1), visible=True, scale = 1, label="Generate more frames to preserve Reference Image Identity / Control Image Information or improve" @@ -7725,7 +8192,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, visible = False) elif ltxv: sliding_window_size = gr.Slider(41, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=8, label="Sliding Window Size") - sliding_window_overlap = gr.Slider(9, 97, value=ui_defaults.get("sliding_window_overlap",9), step=8, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") + sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",9), step=8, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") sliding_window_color_correction_strength = gr.Slider(0, 1, visible=False, value =0) sliding_window_overlap_noise = gr.Slider(0, 100, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=8, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) @@ -7736,9 +8203,10 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20), step=1, label="Noise to be added to overlapped frames to reduce blur effect", visible = False) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) else: # Vace, Multitalk + sliding_window_defaults = model_def.get("sliding_window_defaults", {}) sliding_window_size = gr.Slider(5, get_max_frames(257), value=ui_defaults.get("sliding_window_size", 129), step=4, label="Sliding Window Size") - sliding_window_overlap = gr.Slider(1, 97, value=ui_defaults.get("sliding_window_overlap",5), step=4, label="Windows Frames Overlap (needed to maintain continuity between windows, a higher value will require more windows)") - sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",1), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)") + sliding_window_overlap = gr.Slider(sliding_window_defaults.get("overlap_min", 1), sliding_window_defaults.get("overlap_max", 97), value=ui_defaults.get("sliding_window_overlap",sliding_window_defaults.get("overlap_default", 5)), step=sliding_window_defaults.get("overlap_step", 4), label="Windows Frames Overlap (needed to maintain continuity between windows, a 
higher value will require more windows)") + sliding_window_color_correction_strength = gr.Slider(0, 1, value=ui_defaults.get("sliding_window_color_correction_strength",0), step=0.01, label="Color Correction Strength (match colors of new window with previous one, 0 = disabled)", visible = True) sliding_window_overlap_noise = gr.Slider(0, 150, value=ui_defaults.get("sliding_window_overlap_noise",20 if vace else 0), step=1, label="Noise to be added to overlapped frames to reduce blur effect" , visible = vace) sliding_window_discard_last_frames = gr.Slider(0, 20, value=ui_defaults.get("sliding_window_discard_last_frames", 0), step=4, label="Discard Last Frames of a Window (that may have bad quality)", visible = True) @@ -7748,19 +8216,19 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai ("Aligned to the beginning of the First Window of the new Video Sample", "T"), ], value=filter_letters(video_prompt_type_value, "T"), - label="Control Video / Control Audio temporal alignment when any Source Video", + label="Control Video / Control Audio / Positioned Frames Temporal Alignment when any Video to continue", visible = vace or ltxv or t2v or infinitetalk ) multi_prompts_gen_type = gr.Dropdown( choices=[ - ("Will create new generated Video", 0), + ("Will create a new generated Video added to the Generation Queue", 0), ("Will be used for a new Sliding Window of the same Video Generation", 1), ], value=ui_defaults.get("multi_prompts_gen_type",0), visible=True, scale = 1, - label="Text Prompts separated by a Carriage Return" + label="Images & Text Prompts separated by a Carriage Return" if (any_start_image or any_end_image) else "Text Prompts separated by a Carriage Return" ) with gr.Tab("Misc.", visible = True) as misc_tab: @@ -7831,7 +8299,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai output_trigger = gr.Text(interactive= False, visible=False) refresh_form_trigger = gr.Text(interactive= False, visible=False) fill_wizard_prompt_trigger = gr.Text(interactive= False, visible=False) - saveform_trigger = gr.Text(interactive= False, visible=False) + save_form_trigger = gr.Text(interactive= False, visible=False) with gr.Accordion("Video Info and Late Post Processing & Audio Remuxing", open=False) as video_info_accordion: with gr.Tabs() as video_info_tabs: @@ -7840,16 +8308,16 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai video_info = gr.HTML(visible=True, min_height=100, value=get_default_video_info()) with gr.Row(**default_visibility) as video_buttons_row: video_info_extract_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm") - video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video ) video_info_to_video_source_btn = gr.Button("To Video Source", min_width= 1, size ="sm", visible = any_video_source) + video_info_to_control_video_btn = gr.Button("To Control Video", min_width= 1, size ="sm", visible = any_control_video ) video_info_eject_video_btn = gr.Button("Eject Video", min_width= 1, size ="sm") with gr.Row(**default_visibility) as image_buttons_row: video_info_extract_image_settings_btn = gr.Button("Extract Settings", min_width= 1, size ="sm") video_info_to_start_image_btn = gr.Button("To Start Image", size ="sm", min_width= 1, visible = any_start_image ) video_info_to_end_image_btn = gr.Button("To End Image", size ="sm", min_width= 1, visible = any_end_image) - video_info_to_image_guide_btn = gr.Button("To Control 
Image", min_width= 1, size ="sm", visible = any_control_image ) - video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask) + video_info_to_image_mask_btn = gr.Button("To Mask Image", min_width= 1, size ="sm", visible = any_image_mask and False) video_info_to_reference_image_btn = gr.Button("To Reference Image", min_width= 1, size ="sm", visible = any_reference_image) + video_info_to_image_guide_btn = gr.Button("To Control Image", min_width= 1, size ="sm", visible = any_control_image ) video_info_eject_image_btn = gr.Button("Eject Image", min_width= 1, size ="sm") with gr.Tab("Post Processing", id= "post_processing", visible = True) as video_postprocessing_tab: with gr.Group(elem_classes= "postprocess"): @@ -7889,7 +8357,7 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai add_to_queue_trigger = gr.Text(visible = False) with gr.Column(visible= False) as current_gen_column: - with gr.Accordion("Preview", open=False) as queue_accordion: + with gr.Accordion("Preview", open=False): preview = gr.Image(label="Preview", height=200, show_label= False) preview_trigger = gr.Text(visible= False) gen_info = gr.HTML(visible=False, min_height=1) @@ -7926,16 +8394,17 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai hidden_countdown_state = gr.Number(value=-1, visible=False, elem_id="hidden_countdown_state_num") single_hidden_trigger_btn = gr.Button("trigger_countdown", visible=False, elem_id="trigger_info_single_btn") - extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, + extra_inputs = prompt_vars + [wizard_prompt, wizard_variables_var, wizard_prompt_activated_var, video_prompt_column, image_prompt_column, image_prompt_type_group, image_prompt_type_radio, image_prompt_type_endcheckbox, prompt_column_advanced, prompt_column_wizard_vars, prompt_column_wizard, lset_name, save_lset_prompt_drop, advanced_row, speed_tab, audio_tab, mmaudio_col, quality_tab, sliding_window_tab, misc_tab, prompt_enhancer_row, inference_steps_row, skip_layer_guidance_row, audio_guide_row, RIFLEx_setting_col, - video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, apg_col, audio_prompt_type_sources, audio_prompt_type_remux_row, + video_prompt_type_video_guide, video_prompt_type_video_guide_alt, video_prompt_type_video_mask, video_prompt_type_image_refs, video_prompt_type_video_custom_dropbox, video_prompt_type_video_custom_checkbox, + apg_col, audio_prompt_type_sources, audio_prompt_type_remux, audio_prompt_type_remux_row, video_guide_outpainting_col,video_guide_outpainting_top, video_guide_outpainting_bottom, video_guide_outpainting_left, video_guide_outpainting_right, video_guide_outpainting_checkbox, video_guide_outpainting_row, show_advanced, video_info_to_control_video_btn, video_info_to_video_source_btn, sample_solver_row, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab, PP_MMAudio_row, PP_custom_audio_row, video_info_to_start_image_btn, video_info_to_end_image_btn, video_info_to_reference_image_btn, video_info_to_image_guide_btn, video_info_to_image_mask_btn, - NAG_col, speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, image_mode_tabs, - min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn] # 
presets_column, + NAG_col, remove_background_sound , speakers_locations_row, embedded_guidance_row, guidance_phases_row, guidance_row, resolution_group, cfg_free_guidance_col, control_net_weights_row, guide_selection_row, image_mode_tabs, + min_frames_if_references_col, video_prompt_type_alignment, prompt_enhancer_btn, tab_inpaint, tab_t2v] + image_start_extra + image_end_extra + image_refs_extra # presets_column, if update_form: locals_dict = locals() gen_inputs = [state_dict if k=="state" else locals_dict[k] for k in inputs_names] + [state_dict] + extra_inputs @@ -7945,22 +8414,28 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai target_settings = gr.Text(value = "settings", interactive= False, visible= False) last_choice = gr.Number(value =-1, interactive= False, visible= False) - resolution_group.input(fn=change_resolution_group, inputs=[state, resolution_group], outputs=[resolution]) + resolution_group.input(fn=change_resolution_group, inputs=[state, resolution_group], outputs=[resolution], show_progress="hidden") resolution.change(fn=record_last_resolution, inputs=[state, resolution]) - video_length.release(fn=refresh_video_length_label, inputs=[state, video_length ], outputs = video_length, trigger_mode="always_last" ) + # video_length.release(fn=refresh_video_length_label, inputs=[state, video_length ], outputs = video_length, trigger_mode="always_last" ) + gr.on(triggers=[video_length.release, force_fps.change, video_guide.change, video_source.change], fn=refresh_video_length_label, inputs=[state, video_length, force_fps, video_guide, video_source] , outputs = video_length, trigger_mode="always_last", show_progress="hidden" ) guidance_phases.change(fn=change_guidance_phases, inputs= [state, guidance_phases], outputs =[model_switch_phase, guidance_phases_row, switch_threshold, switch_threshold2, guidance2_scale, guidance3_scale ]) audio_prompt_type_remux.change(fn=refresh_audio_prompt_type_remux, inputs=[state, audio_prompt_type, audio_prompt_type_remux], outputs=[audio_prompt_type]) - audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row]) - image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end, video_source, keep_frames_video_source] ) - # video_prompt_video_guide_trigger.change(fn=refresh_video_prompt_video_guide_trigger, inputs=[state, video_prompt_type, video_prompt_video_guide_trigger], outputs=[video_prompt_type, video_prompt_type_video_guide, video_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, mask_expand]) - video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col]) - video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, video_prompt_type, video_prompt_type_video_guide, image_mode], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, mask_expand]) - video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = 
[state, video_prompt_type, video_prompt_type_video_guide_alt], outputs = [video_prompt_type, video_guide, image_refs ]) - video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode], outputs = [video_prompt_type, video_mask, image_mask, mask_expand]) + remove_background_sound.change(fn=refresh_remove_background_sound, inputs=[state, audio_prompt_type, remove_background_sound], outputs=[audio_prompt_type]) + audio_prompt_type_sources.change(fn=refresh_audio_prompt_type_sources, inputs=[state, audio_prompt_type, audio_prompt_type_sources], outputs=[audio_prompt_type, audio_guide, audio_guide2, speakers_locations_row, remove_background_sound]) + image_prompt_type_radio.change(fn=refresh_image_prompt_type_radio, inputs=[state, image_prompt_type, image_prompt_type_radio], outputs=[image_prompt_type, image_start_row, image_end_row, video_source, keep_frames_video_source, image_prompt_type_endcheckbox], show_progress="hidden" ) + image_prompt_type_endcheckbox.change(fn=refresh_image_prompt_type_endcheckbox, inputs=[state, image_prompt_type, image_prompt_type_radio, image_prompt_type_endcheckbox], outputs=[image_prompt_type, image_end_row] ) + video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [state, video_prompt_type, video_prompt_type_image_refs,image_mode], outputs = [video_prompt_type, image_refs_row, remove_background_images_ref, image_refs_relative_size, frames_positions,video_guide_outpainting_col], show_progress="hidden") + video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State(""), video_prompt_type, video_prompt_type_video_guide, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row, video_prompt_type_video_custom_dropbox, video_prompt_type_video_custom_checkbox], show_progress="hidden") + video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide, inputs = [state, gr.State("alt"),video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, keep_frames_video_guide, denoising_strength, video_guide_outpainting_col, video_prompt_type_video_mask, video_mask, image_mask, image_mask_guide, mask_expand, image_refs_row, video_prompt_type_video_custom_dropbox, video_prompt_type_video_custom_checkbox], show_progress="hidden") + # video_prompt_type_video_guide_alt.input(fn=refresh_video_prompt_type_video_guide_alt, inputs = [state, video_prompt_type, video_prompt_type_video_guide_alt, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_guide, image_guide, image_refs_row, denoising_strength, video_mask, mask_expand, image_mask_guide, image_guide, image_mask, keep_frames_video_guide ], show_progress="hidden") + video_prompt_type_video_custom_dropbox.input(fn= refresh_video_prompt_type_video_custom_dropbox, inputs=[state, video_prompt_type, video_prompt_type_video_custom_dropbox], outputs = video_prompt_type) + video_prompt_type_video_custom_checkbox.input(fn= refresh_video_prompt_type_video_custom_checkbox, inputs=[state, video_prompt_type, video_prompt_type_video_custom_checkbox], outputs = video_prompt_type) + + 
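The handlers wired above all write letter codes into the hidden `video_prompt_type` field, and the options they expose now come from declarative `model_def` entries instead of per-model `if/elif` branches. Below is a minimal sketch of such an entry, assuming only the key names this diff actually reads (`guide_preprocessing`, `mask_preprocessing`, `guide_custom_choices`, `image_ref_choices`); the concrete letter codes and labels are illustrative, not copied from any real finetune definition.

```python
# Hypothetical model_def fragment (illustration only): the keys are the ones
# consumed by the new generate_video_tab code; the values are made-up examples.
example_model_def = {
    # Feeds the "Control Video Process" dropdown. Letter codes are the same ones
    # stored in video_prompt_type (e.g. "PV" = transfer human pose, "MV" = inpainting).
    "guide_preprocessing": {
        "selection": ["", "UV", "PV", "DV", "MV"],
        "default": "",
        "labels": {"MV": "Perform Inpainting"},  # optional per-code label overrides
    },
    # Feeds the "Area Processed" dropdown shown when a mask is in use.
    "mask_preprocessing": {
        "selection": ["", "A", "NA"],
        "default": "",
    },
    # Optional alternate dropdown with fully custom (label, letter codes) pairs.
    "guide_custom_choices": {
        "choices": [("Use Text Prompt Only", ""),
                    ("Video to Video guided by Text Prompt", "GUV")],
        "letters_filter": "GUV",
        "default": "",
        "label": "Video to Video",
    },
    # Controls the "Inject Reference Images" dropdown and the reference gallery.
    "image_ref_choices": {
        "choices": [("None", ""), ("Inject People / Objects", "I")],
        "letters_filter": "I",
        "label": "Reference Images",
    },
}
```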
video_prompt_type_video_mask.input(fn=refresh_video_prompt_type_video_mask, inputs = [state, video_prompt_type, video_prompt_type_video_mask, image_mode, image_mask_guide, image_guide, image_mask], outputs = [video_prompt_type, video_mask, image_mask_guide, image_guide, image_mask, mask_expand], show_progress="hidden") video_prompt_type_alignment.input(fn=refresh_video_prompt_type_alignment, inputs = [state, video_prompt_type, video_prompt_type_alignment], outputs = [video_prompt_type]) - multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt]) + multi_prompts_gen_type.select(fn=refresh_prompt_labels, inputs=[multi_prompts_gen_type, image_mode], outputs=[prompt, wizard_prompt, image_end], show_progress="hidden") video_guide_outpainting_top.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_top, gr.State(0)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) video_guide_outpainting_bottom.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_bottom,gr.State(1)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) video_guide_outpainting_left.input(fn=update_video_guide_outpainting, inputs=[video_guide_outpainting, video_guide_outpainting_left,gr.State(2)], outputs = [video_guide_outpainting], trigger_mode="multiple" ) @@ -7969,8 +8444,8 @@ def gen_upsampling_dropdowns(temporal_upsampling, spatial_upsampling , film_grai show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name]).then( fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars]) queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container]) - gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab]) - preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview]) + gr.on( triggers=[output.change, output.select], fn=select_video, inputs=[state, output], outputs=[last_choice, video_info, video_buttons_row, image_buttons_row, video_postprocessing_tab, audio_remuxing_tab], show_progress="hidden") + preview_trigger.change(refresh_preview, inputs= [state], outputs= [preview], show_progress="hidden") PP_MMAudio_setting.change(fn = lambda value : [gr.update(visible = value == 1), gr.update(visible = value == 0)] , inputs = [PP_MMAudio_setting], outputs = [PP_MMAudio_row, PP_custom_audio_row] ) def refresh_status_async(state, progress=gr.Progress()): gen = get_gen_info(state) @@ -7996,13 +8471,15 @@ def activate_status(state): gen["status_display"] = True return time.time() - start_quit_timer_js, cancel_quit_timer_js, trigger_zip_download_js, trigger_settings_download_js = get_js() + start_quit_timer_js, cancel_quit_timer_js, trigger_zip_download_js, trigger_settings_download_js, click_brush_js = get_js() status_trigger.change(refresh_status_async, inputs= [state] , outputs= [gen_status], show_progress_on= [gen_status]) output_trigger.change(refresh_gallery, inputs = [state], - outputs = 
[output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, current_gen_buttons_row, queue_df, abort_btn, onemorewindow_btn]) + outputs = [output, gen_info, generate_btn, add_to_queue_btn, current_gen_column, current_gen_buttons_row, queue_df, abort_btn, onemorewindow_btn], + show_progress="hidden" + ) preview_column_no.input(show_preview_column_modal, inputs=[state, preview_column_no], outputs=[preview_column_no, modal_image_display, modal_container]) @@ -8018,7 +8495,8 @@ def activate_status(state): gr.on( triggers=[video_info_extract_settings_btn.click, video_info_extract_image_settings_btn.click], fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None @@ -8027,21 +8505,23 @@ def activate_status(state): prompt_enhancer_btn.click(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None ).then( fn=enhance_prompt, inputs =[state, prompt, prompt_enhancer, multi_images_gen_type, override_profile ] , outputs= [prompt, wizard_prompt]) - saveform_trigger.change(fn=validate_wizard_prompt, + save_form_trigger.change(fn=validate_wizard_prompt, inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , - outputs= [prompt] + outputs= [prompt], + show_progress="hidden", ).then(fn=save_inputs, inputs =[target_state] + gen_inputs, outputs= None ) - main_tabs.select(fn=detect_auto_save_form, inputs= [state], outputs= saveform_trigger, trigger_mode="multiple") + main_tabs.select(fn=detect_auto_save_form, inputs= [state], outputs= save_form_trigger, trigger_mode="multiple") video_info_add_videos_btn.click(fn=add_videos_to_gallery, inputs =[state, output, last_choice, files_to_load], outputs = [output, files_to_load, video_info_tabs] ) @@ -8050,14 +8530,14 @@ def activate_status(state): video_info_to_video_source_btn.click(fn=video_to_source_video, inputs =[state, output, last_choice], outputs = [video_source] ) video_info_to_start_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_start, gr.State("Start Image")], outputs = [image_start] ) video_info_to_end_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_end, gr.State("End Image")], outputs = [image_end] ) - video_info_to_image_guide_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_guide, gr.State("Control Image")], outputs = [image_guide] ) + video_info_to_image_guide_btn.click(fn=image_to_ref_image_guide, inputs =[state, output, last_choice], outputs = [image_guide, image_mask_guide]).then(fn=None, inputs=[], outputs=[], js=click_brush_js ) video_info_to_image_mask_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_mask, gr.State("Image Mask")], outputs = [image_mask] ) video_info_to_reference_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_refs, gr.State("Ref Image")], outputs = [image_refs] ) video_info_postprocessing_btn.click(fn=apply_post_processing, inputs =[state, output, last_choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation], outputs = 
@@ -8050,14 +8530,14 @@ def activate_status(state):
     video_info_to_video_source_btn.click(fn=video_to_source_video, inputs =[state, output, last_choice], outputs = [video_source] )
     video_info_to_start_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_start, gr.State("Start Image")], outputs = [image_start] )
     video_info_to_end_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_end, gr.State("End Image")], outputs = [image_end] )
-    video_info_to_image_guide_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_guide, gr.State("Control Image")], outputs = [image_guide] )
+    video_info_to_image_guide_btn.click(fn=image_to_ref_image_guide, inputs =[state, output, last_choice], outputs = [image_guide, image_mask_guide]).then(fn=None, inputs=[], outputs=[], js=click_brush_js )
     video_info_to_image_mask_btn.click(fn=image_to_ref_image_set, inputs =[state, output, last_choice, image_mask, gr.State("Image Mask")], outputs = [image_mask] )
     video_info_to_reference_image_btn.click(fn=image_to_ref_image_add, inputs =[state, output, last_choice, image_refs, gr.State("Ref Image")], outputs = [image_refs] )
     video_info_postprocessing_btn.click(fn=apply_post_processing, inputs =[state, output, last_choice, PP_temporal_upsampling, PP_spatial_upsampling, PP_film_grain_intensity, PP_film_grain_saturation], outputs = [mode, generate_trigger, add_to_queue_trigger ] )
     video_info_remux_audio_btn.click(fn=remux_audio, inputs =[state, output, last_choice, PP_MMAudio_setting, PP_MMAudio_prompt, PP_MMAudio_neg_prompt, PP_MMAudio_seed, PP_repeat_generation, PP_custom_audio], outputs = [mode, generate_trigger, add_to_queue_trigger ] )
     save_lset_btn.click(validate_save_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_save_lset_btn, cancel_lset_btn, save_lset_prompt_drop])
     delete_lset_btn.click(validate_delete_lset, inputs=[state, lset_name], outputs=[apply_lset_btn, refresh_lora_btn, delete_lset_btn, save_lset_btn,confirm_delete_lset_btn, cancel_lset_btn ])
-    confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt]).then(
+    confirm_save_lset_btn.click(fn=validate_wizard_prompt, inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] , outputs= [prompt], show_progress="hidden",).then(
         fn=save_inputs,
         inputs =[target_state] + gen_inputs,
         outputs= None).then(
@@ -8072,7 +8552,8 @@ def activate_status(state):
     lset_name.select(fn=update_lset_type, inputs=[state, lset_name], outputs=save_lset_prompt_drop)
     export_settings_from_file_btn.click(fn=validate_wizard_prompt,
         inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
-        outputs= [prompt]
+        outputs= [prompt],
+        show_progress="hidden",
     ).then(fn=save_inputs,
         inputs =[target_state] + gen_inputs,
         outputs= None
@@ -8089,7 +8570,8 @@ def activate_status(state):
     image_mode_tabs.select(fn=record_image_mode_tab, inputs=[state], outputs= None
     ).then(fn=validate_wizard_prompt,
         inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
-        outputs= [prompt]
+        outputs= [prompt],
+        show_progress="hidden",
     ).then(fn=save_inputs,
         inputs =[target_state] + gen_inputs,
         outputs= None
@@ -8097,7 +8579,8 @@ def activate_status(state):
     settings_file.upload(fn=validate_wizard_prompt,
         inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
-        outputs= [prompt]
+        outputs= [prompt],
+        show_progress="hidden",
     ).then(fn=save_inputs,
         inputs =[target_state] + gen_inputs,
         outputs= None
@@ -8111,17 +8594,20 @@ def activate_status(state):
     refresh_form_trigger.change(fn= fill_inputs,
         inputs=[state],
-        outputs=gen_inputs + extra_inputs
+        outputs=gen_inputs + extra_inputs,
+        show_progress= "full" if args.debug_gen_form else "hidden",
     ).then(fn=validate_wizard_prompt,
         inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars],
-        outputs= [prompt]
+        outputs= [prompt],
+        show_progress="hidden",
     )
     model_family.input(fn=change_model_family, inputs=[state, model_family], outputs= [model_choice])
     model_choice.change(fn=validate_wizard_prompt,
         inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
-        outputs= [prompt]
+        outputs= [prompt],
+        show_progress="hidden",
     ).then(fn=save_inputs,
         inputs =[target_state] + gen_inputs,
         outputs= None
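The video_info_to_* buttons in the first hunk of this group reuse one handler and pass a gr.State("...") literal as an extra input, so the shared callback knows which slot it is filling. A minimal sketch of that convention; the handler below is illustrative, not the actual image_to_ref_image_add:

import gradio as gr

def copy_image(current, new_image, slot_name):
    # slot_name arrives from the gr.State(...) constant wired to each button
    print(f"copying into {slot_name}")
    return new_image

with gr.Blocks() as demo:
    source = gr.Image(label="Last Generated Image")
    start_image = gr.Image(label="Start Image")
    end_image = gr.Image(label="End Image")
    to_start = gr.Button("To Start Image")
    to_end = gr.Button("To End Image")
    to_start.click(copy_image, inputs=[start_image, source, gr.State("Start Image")], outputs=[start_image])
    to_end.click(copy_image, inputs=[end_image, source, gr.State("End Image")], outputs=[end_image])

demo.launch()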
@@ -8130,7 +8616,8 @@ def activate_status(state):
         outputs= [header]
     ).then(fn= fill_inputs,
         inputs=[state],
-        outputs=gen_inputs + extra_inputs
+        outputs=gen_inputs + extra_inputs,
+        show_progress="full" if args.debug_gen_form else "hidden",
     ).then(fn= preload_model_when_switching,
         inputs=[state],
         outputs=[gen_status])
@@ -8139,26 +8626,25 @@ def activate_status(state):
     generate_trigger.change(fn=validate_wizard_prompt,
         inputs= [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
-        outputs= [prompt]
+        outputs= [prompt],
+        show_progress="hidden",
     ).then(fn=save_inputs,
         inputs =[target_state] + gen_inputs,
         outputs= None
     ).then(fn=process_prompt_and_add_tasks,
         inputs = [state, model_choice],
-        outputs= queue_df
+        outputs= [queue_df, queue_accordion],
+        show_progress="hidden",
     ).then(fn=prepare_generate_video,
         inputs= [state],
         outputs= [generate_btn, add_to_queue_btn, current_gen_column, current_gen_buttons_row]
     ).then(fn=activate_status,
         inputs= [state],
         outputs= [status_trigger],
-    ).then(
-        fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(),
-        inputs=[state],
-        outputs=[queue_accordion]
     ).then(fn=process_tasks,
         inputs= [state],
         outputs= [preview_trigger, output_trigger],
+        show_progress="hidden",
     ).then(finalize_generation,
         inputs= [state],
         outputs= [output, abort_btn, generate_btn, add_to_queue_btn, current_gen_column, gen_info]
@@ -8265,17 +8751,15 @@ def activate_status(state):
     # gr.on(triggers=[add_to_queue_btn.click, add_to_queue_trigger.change],fn=validate_wizard_prompt,
     add_to_queue_trigger.change(fn=validate_wizard_prompt,
         inputs =[state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars] ,
-        outputs= [prompt]
+        outputs= [prompt],
+        show_progress="hidden",
     ).then(fn=save_inputs,
         inputs =[target_state] + gen_inputs,
         outputs= None
     ).then(fn=process_prompt_and_add_tasks,
         inputs = [state, model_choice],
-        outputs=queue_df
-    ).then(
-        fn=lambda s: gr.Accordion(open=True) if len(get_gen_info(s).get("queue", [])) > 1 else gr.update(),
-        inputs=[state],
-        outputs=[queue_accordion]
+        outputs=[queue_df, queue_accordion],
+        show_progress="hidden",
     ).then(
         fn=update_status,
         inputs = [state],
@@ -8287,8 +8771,8 @@ def activate_status(state):
         outputs=[modal_container]
     )
-    return ( state, loras_choices, lset_name, resolution,
-             video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger
+    return ( state, loras_choices, lset_name, resolution, refresh_form_trigger, save_form_trigger,
+             # video_guide, image_guide, video_mask, image_mask, image_refs,
     )
@@ -8324,8 +8808,9 @@ def generate_configuration_tab(state, blocks, header, model_family, model_choice
     fit_canvas_choice = gr.Dropdown(
         choices=[
-            ("Dimensions correspond to the Pixels Budget (as the Prompt Image/Video will be resized to match this pixels budget, output video height or width may exceed the requested dimensions )", 0),
-            ("Dimensions correspond to the Maximum Width and Height (as the Prompt Image/Video will be resized to fit into these dimensions, the output video may be smaller)", 1),
+            ("Dimensions correspond to the Pixels Budget (as the Prompt Image/Video will be Resized to match this pixels Budget, output video height or width may exceed the requested dimensions )", 0),
+            ("Dimensions correspond to the Maximum Width and Height (as the Prompt Image/Video will be Resized to fit into these dimensions, the output video may be smaller)", 1),
+            ("Dimensions correspond to the Output Width and Height (as the Prompt Image/Video will be Cropped to fit exactly these dimensions)", 2),
         ],
         value= server_config.get("fit_canvas", 0),
         label="Generated Video Dimensions when Prompt contains an Image or a Video",
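The generate and add-to-queue chains above drop the dedicated lambda step and instead let process_prompt_and_add_tasks return the queue_accordion update alongside the queue table, so a single round trip both refreshes the queue and opens the accordion when more than one task is pending. A minimal sketch of a handler that conditionally opens an accordion through its return values; the queue handling below is simplified and hypothetical:

import gradio as gr

def add_task(queue):
    queue = queue + [f"task {len(queue) + 1}"]
    rows = [[t] for t in queue]
    # open the accordion only once more than one task is queued, otherwise leave it untouched
    accordion_update = gr.Accordion(open=True) if len(queue) > 1 else gr.update()
    return rows, accordion_update, queue

with gr.Blocks() as demo:
    queue_state = gr.State([])
    with gr.Accordion("Queue", open=False) as queue_accordion:
        queue_df = gr.Dataframe(headers=["task"])
    add_btn = gr.Button("Add to queue")
    add_btn.click(add_task, inputs=[queue_state], outputs=[queue_df, queue_accordion, queue_state])

demo.launch()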
@@ -8347,8 +8832,9 @@ def check(mode):
     ("Flash" + check("flash")+ ": good quality - requires additional install (usually complex to set up on Windows without WSL)", "flash"),
     ("Xformers" + check("xformers")+ ": good quality - requires additional install (usually complex, may consume less VRAM to set up on Windows without WSL)", "xformers"),
     ("Sage" + check("sage")+ ": 30% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage"),
-    ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2"),
-    ],
+    ("Sage2/2++" + check("sage2")+ ": 40% faster but slightly worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage2")]\
+    + ([("Radial" + check("radial")+ ": x? faster but ? quality - requires Sparge & Sage 2 Attn (usually complex to set up on Windows without WSL)", "radial")] if args.betatest else [])\
+    + [("Sage3" + check("sage3")+ ": x2 faster but worse quality - requires additional install (usually complex to set up on Windows without WSL)", "sage3")],
     value= attention_mode,
     label="Attention Type",
     interactive= not lock_ui_attention
@@ -8414,6 +8900,15 @@ def check(mode):
         label="User Interface Theme. You will need to restart the App the see new Theme."
     )
+    sliding_window_keep_only_longest_choice = gr.Dropdown(
+        choices=[
+            ("Disabled", False),
+            ("Enabled", True),
+        ],
+        value=server_config.get("sliding_window_keep_only_longest", False),
+        label="Sliding Window Video Cleanup Policy (automatically deletes shorter intermediate videos during generation)"
+    )
+
     with gr.Tab("Performance"):
@@ -8471,7 +8966,7 @@ def check(mode):
         ("Off", "" ),
     ],
     value= compile,
-    label="Compile Transformer (up to 50% faster and 30% more frames but requires Linux / WSL and Flash or Sage attention)",
+    label="Compile Transformer : up to 10-20% faster, useful only if multiple gens use the same number of frames / resolution",
     interactive= not lock_ui_compile
     )
@@ -8650,6 +9145,7 @@ def check(mode):
         image_output_codec_choice,
         audio_output_codec_choice,
         resolution,
+        sliding_window_keep_only_longest_choice,
     ],
     outputs= [msg , header, model_family, model_choice, refresh_form_trigger]
     )
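The new sliding_window_keep_only_longest_choice dropdown follows the same convention as the other configuration widgets: (label, value) tuples for the choices and a default pulled from the persisted server configuration. A minimal standalone sketch of that convention; the config dict below is a mocked stand-in for the app's saved settings:

import gradio as gr

server_config = {"sliding_window_keep_only_longest": False}  # assumption: stand-in for the persisted server settings

with gr.Blocks() as demo:
    cleanup_choice = gr.Dropdown(
        # each choice is a (label, value) pair; the value is what callbacks receive and what gets saved
        choices=[("Disabled", False), ("Enabled", True)],
        value=server_config.get("sliding_window_keep_only_longest", False),
        label="Sliding Window Video Cleanup Policy",
    )

demo.launch()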
@@ -8663,7 +9159,7 @@ def generate_about_tab():
     gr.Markdown("- Blackforest Labs for the innovative Flux image generators (https://github.com/black-forest-labs/flux)")
     gr.Markdown("- Alibaba Qwen Team for their state of the art Qwen Image generators (https://github.com/QwenLM/Qwen-Image)")
     gr.Markdown("- Lightricks for their super fast LTX Video models (https://github.com/Lightricks/LTX-Video)")
-    gr.Markdown("- Hugging Face for the providing hosting for the models and developing must have open source libraries such as Tranformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)")
+    gr.Markdown("- Hugging Face for providing hosting for the models and developing must have open source libraries such as Tranformers, Diffusers, Accelerate and Gradio (https://huggingface.co/)")
     gr.Markdown("Huge acknowledgments to these great open source projects used in WanGP:")
     gr.Markdown("- Rife: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
     gr.Markdown("- DwPose: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
@@ -8672,7 +9168,7 @@
     gr.Markdown("- Pyannote: speaker diarization (https://github.com/pyannote/pyannote-audio)")
     gr.Markdown("Special thanks to the following people for their support:")
     gr.Markdown("- Cocktail Peanuts : QA and simple installation via Pinokio.computer")
     gr.Markdown("- Tophness : created (former) multi tabs and queuing frameworks")
     gr.Markdown("- AmericanPresidentJimmyCarter : added original support for Skip Layer Guidance")
     gr.Markdown("- Remade_AI : for their awesome Loras collection")
@@ -8776,12 +9272,19 @@ def set_new_tab(tab_state, new_tab_no):
         tab_state["tab_no"] = 0
         return gr.Tabs(selected="video_gen")
     else:
+        if not download_shared_done:
+            download_models()
         vmc_event_handler(True)
     tab_state["tab_no"] = new_tab_no
     return gr.Tabs()

 def select_tab(tab_state, evt:gr.SelectData):
-    return set_new_tab(tab_state, evt.index)
+    old_tab_no = tab_state.get("tab_no",0)
+    if old_tab_no == 0:
+        saveform_trigger = get_unique_id()
+    else:
+        saveform_trigger = gr.update()
+    return set_new_tab(tab_state, evt.index), saveform_trigger

 def get_js():
     start_quit_timer_js = """
@@ -8884,7 +9387,21 @@ def get_js():
         }
     }
     """
-    return start_quit_timer_js, cancel_quit_timer_js, trigger_zip_download_js, trigger_settings_download_js
+
+    click_brush_js = """
+    () => {
+        setTimeout(() => {
+            const brushButton = document.querySelector('button[aria-label="Brush"]');
+            if (brushButton) {
+                brushButton.click();
+                console.log('Brush button clicked');
+            } else {
+                console.log('Brush button not found');
+            }
+        }, 1000);
+    } """
+
+    return start_quit_timer_js, cancel_quit_timer_js, trigger_zip_download_js, trigger_settings_download_js, click_brush_js

 def create_ui():
     global vmc_event_handler
@@ -9194,6 +9711,7 @@ def create_ui():
         transition: visibility 0s linear 1s, opacity 0.3s linear 1s; /* 1s delay before showing */
     }
     .btn_centered {margin-top:10px; text-wrap-mode: nowrap;}
+    .cbx_centered label {margin-top:8px; text-wrap-mode: nowrap;}
     """
     UI_theme = server_config.get("UI_theme", "default")
     UI_theme = args.theme if len(args.theme) > 0 else UI_theme
@@ -9215,9 +9733,17 @@ def create_ui():
             console.log('Events dispatched for column:', index);
         }
     };
-    console.log('sendColIndex function attached to window');
-    }
+
+    // cancel wheel usage inside image editor
+    const hit = n => n?.id === "img_editor" || n?.classList?.contains("wheel-pass");
+    addEventListener("wheel", e => {
+        const path = e.composedPath?.() || (() => { let a=[],n=e.target; for(;n;n=n.parentNode||n.host) a.push(n); return a; })();
+        if (path.some(hit)) e.stopImmediatePropagation();
+    }, { capture: true, passive: true });
+
+    }
+    """
     if server_config.get("display_stats", 0) == 1:
         from shared.utils.stats import SystemStatsApp
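click_brush_js above is wired through the js= argument of an event listener, which runs a client-side snippet in the browser; with fn=None the listener never makes a server round trip at all. A minimal sketch of that wiring; the aria-label selector is an assumption about how the image editor renders its brush button, mirroring the snippet above:

import gradio as gr

click_brush_js = """
() => {
    // assumption: the image editor exposes its brush tool as a button with this aria-label
    const brushButton = document.querySelector('button[aria-label="Brush"]');
    if (brushButton) brushButton.click();
}
"""

with gr.Blocks() as demo:
    editor = gr.ImageEditor(label="Control Image")
    select_brush = gr.Button("Select brush")
    # fn=None: only the client-side snippet runs when the button is clicked
    select_brush.click(fn=None, inputs=[], outputs=[], js=click_brush_js)

demo.launch()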
@@ -9248,13 +9774,13 @@ def create_ui():
             stats_element = stats_app.get_gradio_element()
         with gr.Row():
-            ( state, loras_choices, lset_name, resolution,
-              video_guide, image_guide, video_mask, image_mask, image_refs, refresh_form_trigger
+            ( state, loras_choices, lset_name, resolution, refresh_form_trigger, save_form_trigger
+              # video_guide, image_guide, video_mask, image_mask, image_refs,
             ) = generate_video_tab(model_family=model_family, model_choice=model_choice, header=header, main = main, main_tabs =main_tabs)
         with gr.Tab("Guides", id="info") as info_tab:
             generate_info_tab()
         with gr.Tab("Video Mask Creator", id="video_mask_creator") as video_mask_creator:
-            matanyone_app.display(main_tabs, tab_state, server_config, video_guide, image_guide, video_mask, image_mask, image_refs)
+            matanyone_app.display(main_tabs, tab_state, state, refresh_form_trigger, server_config, get_current_model_settings) #, video_guide, image_guide, video_mask, image_mask, image_refs)
         if not args.lock_config:
             with gr.Tab("Downloads", id="downloads") as downloads_tab:
                 generate_download_tab(lset_name, loras_choices, state)
@@ -9264,7 +9790,7 @@ def create_ui():
                 generate_about_tab()
         if stats_app is not None:
             stats_app.setup_events(main, state)
-        main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= main_tabs, trigger_mode="multiple")
+        main_tabs.select(fn=select_tab, inputs= [tab_state], outputs= [main_tabs, save_form_trigger], trigger_mode="multiple")
         return main

 if __name__ == "__main__":
@@ -9290,4 +9816,4 @@ def create_ui():
     else:
         url = "http://" + server_name
         webbrowser.open(url + ":" + str(server_port), new = 0, autoraise = True)
-    demo.launch(favicon_path="favicon.png", server_name=server_name, server_port=server_port, share=args.share, allowed_paths=list({save_path, image_save_path}))
\ No newline at end of file
+    demo.launch(favicon_path="favicon.png", server_name=server_name, server_port=server_port, share=args.share, allowed_paths=list({save_path, image_save_path}))