[Inference] Add AdaSpa support for accelerating video generation #1158

base: main
@@ -0,0 +1,41 @@

# Wan2.1 Inference with AdaSpa

## Overview

This example shows how to run Wan2.1 inference with AdaSpa acceleration in FlagScale.

## Prerequisites

- Linux + NVIDIA GPU (SM80+ recommended)
- CUDA toolkit available (`nvcc -V`)
- Python environment compatible with FlagScale inference
- Install FlagScale and inference dependencies:
  ```bash
  pip install . --verbose
  pip install -r requirements/cuda/inference.txt
  ```
- AdaSpa extra dependency

  AdaSpa relies on the block sparse attention extension under `flagscale/inference/core/diffusion/adaspa/third_party/block_sparse_attention`:
  ```bash
  cd flagscale/inference/core/diffusion/adaspa/third_party/block_sparse_attention
  pip install -e .
  cd -
  ```
- Prepare model

  Set your local Wan2.1 model path in:

  `examples/wan2_1/conf/inference/1.3b.yaml`
  `examples/wan2_1/conf/inference/1.3b_adaspa.yaml`
- Config selection

  In `examples/wan2_1/conf/inference.yaml`, select exactly one config in the `defaults` list (comment out the other):
  ```yaml
  defaults:
    - _self_
    - inference: 1.3b_adaspa  # adaspa
    # - inference: 1.3b       # taylorseer
  ```
- Run inference
  ```bash
  flagscale inference wan2_1 -c examples/wan2_1/conf/inference.yaml
  ```
@@ -1,5 +1,5 @@

```diff
 engine:
-  model: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
+  model: /workspace/models/Wan2.1-T2V-1.3B-Diffusers
```

Suggested change (review):

```diff
-  model: /workspace/models/Wan2.1-T2V-1.3B-Diffusers
+  model: ${oc.env:WAN_MODEL_PATH, Wan-AI/Wan2.1-T2V-1.3B-Diffusers}
```
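The suggested `${oc.env:WAN_MODEL_PATH, Wan-AI/Wan2.1-T2V-1.3B-Diffusers}` interpolation resolves to the environment variable when it is set and falls back to the Hub id otherwise. A minimal pure-Python sketch of the same lookup semantics (the helper name `resolve_model_path` is illustrative, not part of FlagScale):

```python
import os


def resolve_model_path(env_var: str = "WAN_MODEL_PATH",
                       default: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers") -> str:
    # Mirrors OmegaConf's ${oc.env:VAR, default}: the env value wins, else the default.
    return os.environ.get(env_var, default)


os.environ.pop("WAN_MODEL_PATH", None)
print(resolve_model_path())   # the default Hub id when the variable is unset

os.environ["WAN_MODEL_PATH"] = "/workspace/models/Wan2.1-T2V-1.3B-Diffusers"
print(resolve_model_path())   # the local path once the variable is exported
```

This keeps the config portable: CI can run against the Hub id while a workstation with a local checkout exports `WAN_MODEL_PATH`.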
@@ -0,0 +1,59 @@

```yaml
engine:
  model: /workspace/models/Wan2.1-T2V-1.3B-Diffusers
  loader: diffusers
  pipeline:
    class: diffusers.WanPipeline
    from_pretrained:
      torch_dtype: bfloat16
    components:
      vae:
        class: diffusers.AutoencoderKLWan
        from_pretrained:
          subfolder: vae
          torch_dtype: float32
  device: cuda
  results_path: ${experiment.exp_dir}/results
  output_format: "video"
  transformations:
    AdaSpaTransformation:
      strict: true
      basic:
        model_id: Wan2.1
        prompt: "A cat walks on the grass, realistic"
        height: 480
        width: 832
        frames: 81
        num_steps: 50
        seed: 42
        fps: 15
      adaspa:
        enable_log: true
        num_layers: 30
        sparsity: 0.9
        search_steps: [0, 30]
        min_recall: 0.9
        block_size: 128
        sparsity_modes:
          - random_select
          - head_adaptive
          - cache_lse
          - row_wise
          - text_sink
          - first_frame_sink
    TimestepTrackerTransformation: {}
    StateScopeTransformation:
      state_scopes: ["cond", "uncond"]

generate:
  prompts: ["A cat walks on the grass, realistic"]
  negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
  height: 480
  width: 832
  num_frames: 81
  guidance_scale: 5.0
  generator:
    seed: 42
    device: cuda
```
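The `sparsity` and `block_size` settings above control how many key/value blocks each query block attends to. A toy NumPy sketch of block-level selection (mean-pooling attention logits per block tile and keeping the top `1 - sparsity` fraction per query block); this illustrates the idea only and is not the CUDA kernel shipped under `third_party/block_sparse_attention`:

```python
import numpy as np


def select_blocks(scores: np.ndarray, block_size: int, sparsity: float) -> np.ndarray:
    """Return a boolean (q_blocks, k_blocks) mask: True = keep the block.

    scores: dense (seq, seq) attention logits, a toy stand-in for profiled scores.
    """
    n = scores.shape[0]
    nb = n // block_size
    # Mean-pool logits within each (block_size x block_size) tile.
    pooled = scores[: nb * block_size, : nb * block_size]
    pooled = pooled.reshape(nb, block_size, nb, block_size).mean(axis=(1, 3))
    # Keep at least one key block per query block row.
    keep_per_row = max(1, int(round(nb * (1.0 - sparsity))))
    mask = np.zeros((nb, nb), dtype=bool)
    # Per query block, keep the highest-scoring key blocks (a row-wise choice,
    # loosely in the spirit of the row_wise / head_adaptive modes above).
    top = np.argsort(pooled, axis=1)[:, -keep_per_row:]
    np.put_along_axis(mask, top, True, axis=1)
    return mask


rng = np.random.default_rng(0)
mask = select_blocks(rng.normal(size=(1024, 1024)), block_size=128, sparsity=0.9)
print(mask.shape, mask.sum(axis=1))  # (8, 8) grid; 1 kept block per row at 90% sparsity
```

At `block_size: 128` and `sparsity: 0.9`, a 1024-token toy sequence yields an 8x8 block grid with a single kept key block per query block, which is why the `min_recall` check matters: it guards against dropping blocks that still carry attention mass.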
@@ -0,0 +1,66 @@

```yaml
engine:
  model: /workspace/models/Wan2.1-T2V-1.3B-Diffusers
  loader: diffusers
  pipeline:
    class: diffusers.WanPipeline
    from_pretrained:
      torch_dtype: bfloat16
    components:
      vae:
        class: diffusers.AutoencoderKLWan
        from_pretrained:
          subfolder: vae
          torch_dtype: float32
  device: cuda
  results_path: ${experiment.exp_dir}/results
  output_format: "video"
  transformations:
    AdaSpaTransformation:
      strict: true
      basic:
        model_id: Wan2.1
        prompt: "A cat walks on the grass, realistic"
        height: 480
        width: 832
        frames: 81
        num_steps: 50
        seed: 42
        fps: 15
      adaspa:
        enable_log: true
        num_layers: 30
        sparsity: 0.9
        search_steps: [0, 30]
        min_recall: 0.9
        block_size: 128
        sparsity_modes:
          - random_select
          - head_adaptive
          - cache_lse
          - row_wise
          - text_sink
          - first_frame_sink
    TaylorSeerTransformation:
      order: 1
      warmup_steps: 4
      skip_interval_steps: 3
      targets:
        by_name: ["*blocks.*.attn1", "*blocks.*.attn2", "*blocks.*.ffn"]
      use_timestep_delta: true
    TimestepTrackerTransformation: {}
    StateScopeTransformation:
      state_scopes: ["cond", "uncond"]

generate:
  prompts: ["A cat walks on the grass, realistic"]
  negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
  height: 480
  width: 832
  num_frames: 81
  guidance_scale: 5.0
  generator:
    seed: 42
    device: cuda
```
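With `warmup_steps: 4` and `skip_interval_steps: 3`, TaylorSeer-style caching fully computes the targeted modules on the first steps, then only recomputes every third step and reuses (extrapolates from) the cache in between. A hedged sketch of that schedule; the real transformation additionally maintains Taylor-series terms up to `order` for the extrapolation, which this helper does not model:

```python
def recompute_schedule(num_steps: int, warmup_steps: int, skip_interval_steps: int):
    """Return the denoising steps whose module outputs are fully recomputed;
    every other step reuses the cached (extrapolated) outputs."""
    steps = []
    for t in range(num_steps):
        # Warmup steps always pay full compute; afterwards, recompute on a fixed interval.
        if t < warmup_steps or (t - warmup_steps) % skip_interval_steps == 0:
            steps.append(t)
    return steps


sched = recompute_schedule(num_steps=50, warmup_steps=4, skip_interval_steps=3)
print(sched[:8])   # [0, 1, 2, 3, 4, 7, 10, 13]
print(len(sched))  # 20 of 50 steps pay full compute for the targeted modules
```

Under this config, roughly 60% of steps skip full attention/FFN compute for the modules matched by `targets.by_name`, which stacks with the AdaSpa sparsity applied on the recomputed steps.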
@@ -0,0 +1,12 @@

```python
from .adaspa_handler import AdaSpaHandler
from .processor import WanAdaSpaAttnProcessor

# Optional explicit registration when the symbol exists in current diffusers.
try:
    from diffusers.models.attention_processor import WanAttnProcessor2_0

    AdaSpaHandler.register_processor(WanAttnProcessor2_0, WanAdaSpaAttnProcessor)
except Exception:
    pass

__all__ = ["AdaSpaHandler", "WanAdaSpaAttnProcessor"]
```
@@ -0,0 +1,108 @@

```python
from typing import Dict, Type, Optional, Any, Tuple

import torch
from torch import nn
from diffusers.models.attention import Attention

ADASPA_PROCESSOR = None


def get_model_name():
    global ADASPA_PROCESSOR
    name_dict = {
        "HunyuanVideoAdaSpaAttnProcessor": "HunyuanVideo",
        "CogVideoXAdaSpaAttnProcessor": "CogVideoX",
        "WanAdaSpaAttnProcessor": "Wan2.1",
        "Wan2.1": "Wan2.1",
    }
    if ADASPA_PROCESSOR in name_dict:
        return name_dict[ADASPA_PROCESSOR]
    # Fallback for unknown processor names: keep behavior safe for new model adapters.
    return "Wan2.1"


class AdaSpaHandler(nn.Module):
    """
    A handler class that provides a generic way to replace attention processors
    with their sparse versions. It maintains a registry of original attention
    processors and their corresponding sparse versions.
    """

    _processor_registry: Dict[Type, Type] = {}
    adaspa_processor = None

    @classmethod
    def register_processor(cls, original_processor: Type, sparse_processor: Type) -> None:
        """
        Register a mapping between an original attention processor and its sparse version.

        Args:
            original_processor: The original attention processor class
            sparse_processor: The corresponding sparse attention processor class
        """
        cls._processor_registry[original_processor] = sparse_processor

    def __init__(self, model: nn.Module, **kwargs: Any):
        """
        Initialize the AdaSpaHandler with a model and optional shape information.

        Args:
            model: The model whose attention processors will be replaced
            **kwargs: Additional arguments to pass to the sparse processor constructor
        """
        super().__init__()
        self.model = model
        self.kwargs = kwargs
        self._original_processors = {}
        self._sparse_processors = {}
        self._replace_processors()

    def _replace_processors(self) -> None:
        """Replace attention processors in Attention modules."""

        def _replace_module(module: nn.Module) -> None:
            if isinstance(module, Attention) and hasattr(module, "processor"):
                original_processor = module.processor
                original_type = type(original_processor)

                if original_type in self._processor_registry:
                    sparse_type = self._processor_registry[original_type]

                    global ADASPA_PROCESSOR
                    ADASPA_PROCESSOR = sparse_type.__name__

                    # Store original processor
                    key = f"{module.__class__.__name__}.processor"
                    # ... (remainder of the 108-line file not shown in this excerpt)
```

**Copilot (Mar 20, 2026), on lines +74 to +75** — suggested change:

```diff
-                    # Store original processor
-                    key = f"{module.__class__.__name__}.processor"
+                    # Store original processor using a unique key per module instance
+                    key = f"{module.__class__.__name__}-{id(module)}.processor"
```

**Copilot (Mar 20, 2026):** `WanAdaSpaAttnProcessor` requires a `base_processor` instance (it raises at call time when `base_processor` is `None`), but the handler constructs it with no arguments. Construct the sparse processor with the original processor wired in (e.g., `sparse_type(base_processor=original_processor)`), and consider forwarding any handler kwargs through as well.

```diff
-                    sparse_processor = sparse_type()
+                    sparse_processor = sparse_type(base_processor=original_processor, **self.kwargs)
```
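The handler follows a plain class-level registry pattern: map original processor types to sparse replacements, then swap by exact type match while keeping the original around. A self-contained sketch of that mechanism with toy classes (the real handler walks `diffusers` `Attention` modules; the names below are illustrative stand-ins):

```python
from typing import Dict, Type


class OriginalProcessor:
    """Stand-in for an original attention processor such as WanAttnProcessor2_0."""


class SparseProcessor:
    """Stand-in for a sparse replacement such as WanAdaSpaAttnProcessor."""

    def __init__(self, base_processor=None):
        # Keep the original wired in so the sparse processor can delegate to it.
        self.base_processor = base_processor


class Registry:
    _processor_registry: Dict[Type, Type] = {}

    @classmethod
    def register(cls, original: Type, sparse: Type) -> None:
        cls._processor_registry[original] = sparse

    @classmethod
    def swap(cls, processor):
        sparse_type = cls._processor_registry.get(type(processor))
        # Unregistered processor types pass through unchanged.
        return sparse_type(base_processor=processor) if sparse_type else processor


Registry.register(OriginalProcessor, SparseProcessor)
orig = OriginalProcessor()
swapped = Registry.swap(orig)
print(type(swapped).__name__, swapped.base_processor is orig)  # SparseProcessor True
```

Keying the registry on exact type (rather than `isinstance`) keeps the swap predictable when diffusers ships several processor subclasses, at the cost of requiring one `register_processor` call per concrete type.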
@@ -0,0 +1,3 @@

```python
from .attn_func import AdaptiveSparseAttention, adaptive_sparse_attn

__all__ = ["AdaptiveSparseAttention", "adaptive_sparse_attn"]
```

**Copilot review:** The 'Config selection' snippet is labeled as bash but shows what looks like a Hydra YAML fragment, and the `- inference: 1.3b_adaspa (adaspa)` / `: 1.3b (taylorseer)` syntax is not valid YAML. Please replace this with a valid `defaults` example (and mark the code block as `yaml`) so users can copy-paste it successfully.
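One valid shape for the requested fix, assuming Hydra-style config groups (users switch configs by commenting out one line):

```yaml
defaults:
  - _self_
  - inference: 1.3b_adaspa  # adaspa
  # - inference: 1.3b       # taylorseer
```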