Conversation

@yiyixuxu
Collaborator

@yiyixuxu yiyixuxu commented Nov 22, 2025

https://huggingface.co/collections/hunyuanvideo-community/hunyuanvideo-15

Testing script:

import torch

dtype = torch.bfloat16
device = "cuda:0"
from diffusers import HunyuanVideo15Pipeline, HunyuanVideo15ImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

t2v_names = ["480p_t2v", "720p_t2v", "480p_t2v_distilled"]
num_frames = 31  # use a small number of frames for testing; 121 is the default

# test t2v
prompt="A close-up shot captures a scene on a polished, light-colored granite kitchen counter, illuminated by soft natural light from an unseen window. Initially, the frame focuses on a tall, clear glass filled with golden, translucent apple juice standing next to a single, shiny red apple with a green leaf still attached to its stem. The camera moves horizontally to the right. As the shot progresses, a white ceramic plate smoothly enters the frame, revealing a fresh arrangement of about seven or eight more apples, a mix of vibrant reds and greens, piled neatly upon it. A shallow depth of field keeps the focus sharply on the fruit and glass, while the kitchen backsplash in the background remains softly blurred. The scene is in a realistic style."
seed = 1
for name in t2v_names:
    print(f"Testing {name}...")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    pipe = HunyuanVideo15Pipeline.from_pretrained(f"hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-{name}", torch_dtype=dtype)
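    # reduce peak VRAM: offload idle components to CPU and decode the latents with the VAE in tiles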
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_tiling()

    generator = torch.Generator(device=device).manual_seed(seed)
    video = pipe(
        prompt=prompt,
        generator=generator,
        num_frames=num_frames,
        num_inference_steps=50,
    ).frames[0]
    export_to_video(video, f"yiyi_test_hy15_{name}_output.mp4", fps=24)
    max_allocated = torch.cuda.max_memory_allocated() / 1024**3  # GB
    print(f"Max Allocated Memory: {max_allocated:.2f} GB")
    
# test i2v
i2v_names = ["480p_i2v", "720p_i2v", "480p_i2v_distilled", "720p_i2v_distilled"]

image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG")
prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
seed = 1
for name in i2v_names:
    print(f"Testing {name}...")
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    pipe = HunyuanVideo15ImageToVideoPipeline.from_pretrained(f"hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-{name}", torch_dtype=dtype)
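    # same memory savers as in the T2V loop: CPU offload for idle components plus tiled VAE decoding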
    pipe.enable_model_cpu_offload()
    pipe.vae.enable_tiling()

    generator = torch.Generator(device=device).manual_seed(seed)
    video = pipe(
        prompt=prompt,
        generator=generator,
        image=image,
        num_frames=num_frames,
        num_inference_steps=50,
    ).frames[0]
    export_to_video(video, f"yiyi_test_hy15_{name}_output.mp4", fps=24)
    max_allocated = torch.cuda.max_memory_allocated() / 1024**3  # GB
    print(f"Max Allocated Memory: {max_allocated:.2f} GB")

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@tin2tin

tin2tin commented Nov 28, 2025

Thank you for working on this. Eagerly awaiting this one (Wan doesn't work for me).

@yiyixuxu
Collaborator Author

@tin2tin
I will merge this soon, do you want to test it out? All the checkpoints are uploaded; you can find the testing scripts in the PR description.

@tin2tin

tin2tin commented Nov 29, 2025

I don't have time right now, but I'll definitely check it out later.

@tin2tin

tin2tin commented Nov 29, 2025

720p_t2v seems to be loading, but it's too heavy for me to run (using your example code) - it ended with a crash.

[screenshot of the crash]

480p_t2v: [screenshot]

HunyuanVideo 1.5 runs just fine on my setup in ComfyUI, so they must have figured out how to optimize it significantly.

@yiyixuxu
Collaborator Author

@tin2tin do you want to try with group offloading?
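For reference, a minimal sketch of what group offloading could look like here, using diffusers' `enable_group_offload` / `apply_group_offloading` APIs; the exact settings (and whether they fit in 32 GB of system RAM for this model) are not verified in this thread:

```python
import torch
from diffusers import HunyuanVideo15Pipeline
from diffusers.hooks import apply_group_offloading

pipe = HunyuanVideo15Pipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", torch_dtype=torch.bfloat16
)

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

# stream the transformer weights to the GPU group by group instead of keeping the whole model resident
pipe.transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)

# the remaining torch modules (text encoders, VAE) can be offloaded the same way
for name, component in pipe.components.items():
    if name != "transformer" and isinstance(component, torch.nn.Module):
        apply_group_offloading(component, onload_device=onload_device, offload_type="leaf_level")

pipe.vae.enable_tiling()
```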

@tin2tin

tin2tin commented Nov 30, 2025

I only have 32 GB of RAM, which is usually not enough for group offloading.

I did try using a pre-quantized single-file checkpoint, but I couldn't get it working.

@yiyixuxu yiyixuxu requested a review from dg845 November 30, 2025 21:30
Member

@sayakpaul sayakpaul left a comment


My comments are mostly nits! I can help with the tests.

self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor

def disable_tiling(self) -> None:
Member

We could have the model class subclass AutoencoderMixin to get rid of these common methods. Example:

class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):

rope_theta: float = 256.0,
rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
# YiYi Notes: config based on target_size_config https://github.com/yiyixuxu/hy15/blob/main/hyvideo/pipelines/hunyuan_video_pipeline.py#L205
target_size: int = 640, # did not name sample_size since it is in pixel spaces
Member

Not that important, but we're still doing the AE decoding via:

video = self.vae.decode(latents, return_dict=False)[0]

I got the impression that the DiT might directly predict in pixel space.


@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
Member

We can get rid of these common attention-related methods if we subclass the model from AttentionMixin:

class Flux2Transformer2DModel(
    ModelMixin,
    ConfigMixin,
    PeftAdapterMixin,
    FromOriginalModelMixin,
    FluxTransformer2DLoadersMixin,
    CacheMixin,
    AttentionMixin,
):

Comment on lines +714 to +725
encoder_hidden_states_cond_emb = self.cond_type_embed(
    torch.zeros_like(encoder_hidden_states[:, :, 0], dtype=torch.long)
)
encoder_hidden_states = encoder_hidden_states + encoder_hidden_states_cond_emb

# byt5 text embedding
encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)

encoder_hidden_states_2_cond_emb = self.cond_type_embed(
    torch.ones_like(encoder_hidden_states_2[:, :, 0], dtype=torch.long)
)
encoder_hidden_states_2 = encoder_hidden_states_2 + encoder_hidden_states_2_cond_emb
Member

Interesting that for encoder_hidden_states_cond_emb we do zeros_like, and for encoder_hidden_states_2_cond_emb we do ones_like.


# image embed
encoder_hidden_states_3 = self.image_embedder(image_embeds)
is_t2v = torch.all(image_embeds == 0)
Member

(nit):

I would imagine that in the pure T2V case, image_embeds would be None. But if that's not the case (the current code suggests image_embeds needs to be all zeros), we could change the type hint to image_embeds: torch.Tensor and make it a positional argument?

Collaborator Author

They actually create all-zero image_embeds for T2V.

device=encoder_attention_mask.device,
)
encoder_hidden_states_3_cond_emb = self.cond_type_embed(
2
Member

I see. So it's probably:

- First condition embedding type: 0
- Second condition embedding type: 1
- Third condition embedding type: 2

(a rough sketch of this pattern is below)
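For illustration only, a tiny self-contained sketch of that pattern; the names and dimensions are made up and this is not the PR's actual module, just the type-id idea inferred from the excerpts above:

```python
import torch
import torch.nn as nn

dim = 8
cond_type_embed = nn.Embedding(3, dim)  # one learned vector per conditioning-stream type

text_states = torch.randn(2, 16, dim)   # main text encoder tokens  -> type 0
byt5_states = torch.randn(2, 10, dim)   # byt5 text tokens          -> type 1
image_states = torch.randn(2, 4, dim)   # image embedding tokens    -> type 2

# each stream gets its type embedding added to every token
text_states = text_states + cond_type_embed(torch.zeros_like(text_states[:, :, 0], dtype=torch.long))
byt5_states = byt5_states + cond_type_embed(torch.ones_like(byt5_states[:, :, 0], dtype=torch.long))
image_states = image_states + cond_type_embed(torch.full_like(image_states[:, :, 0], 2, dtype=torch.long))
```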

@sayakpaul
Member

Started tests in #12759 :) Completing now.


- HunyuanVideo1.5 uses attention masks with variable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently.

- **H100/H800:** `_flash_3_hub` or `_flash_varlen_3`
Member

Okay, I will work on adding the Hub variant for FA3 varlen so that we can ease the user experience a bit here.

Comment on lines 60 to 62
- **A100/A800/RTX 4090:** `flash` or `flash_varlen`
- **Other GPUs:** `sage`

Member

(nit): We could recommend flash_hub and sage_hub as backends here instead, to further promote the usage of the kernels-based backends. It would also keep things central to the Hub.
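As a usage sketch of picking one of these backends, assuming diffusers' `set_attention_backend` helper on the transformer (backend availability depends on the installed kernels/packages):

```python
import torch
from diffusers import HunyuanVideo15Pipeline

pipe = HunyuanVideo15Pipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v", torch_dtype=torch.bfloat16
)
# choose one of the recommended backends, e.g. "flash", "flash_varlen", "sage",
# or a Hub kernels variant such as "flash_hub" / "sage_hub" if it is suggested above
pipe.transformer.set_attention_backend("flash_varlen")
```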

@yiyixuxu
Collaborator Author

yiyixuxu commented Dec 1, 2025

@bot /style

@github-actions
Contributor

github-actions bot commented Dec 1, 2025

Style bot fixed some files and pushed the changes.

@yiyixuxu yiyixuxu merged commit 6156cf8 into main Dec 1, 2025
14 of 15 checks passed
@tin2tin

tin2tin commented Dec 1, 2025

Congratulations on the commit. Do you have any suggestions on what I could try to get it working on 24 GB VRAM (RTX 4090) and 32 GB RAM?
