41 changes: 41 additions & 0 deletions examples/wan2_1/conf/README.md
@@ -0,0 +1,41 @@
# Wan2.1 Inference with AdaSpa
## Overview
This example shows how to run Wan2.1 inference with AdaSpa acceleration in FlagScale.
## Prerequisites
- Linux + NVIDIA GPU (SM80+ recommended)
- CUDA toolkit available (`nvcc -V`)
- Python environment compatible with FlagScale inference
- Install FlagScale and inference dependencies:
```bash
pip install . --verbose
pip install -r requirements/cuda/inference.txt
```

- AdaSpa extra dependency
AdaSpa relies on the block sparse attention extension under `flagscale/inference/core/diffusion/adaspa/third_party/block_sparse_attention`. Build and install it:

```bash
cd flagscale/inference/core/diffusion/adaspa/third_party/block_sparse_attention
pip install -e .
cd -
```
- Prepare model
Set your local Wan2.1 model path in:

`examples/wan2_1/conf/inference/1.3b.yaml`
`examples/wan2_1/conf/inference/1.3b_adaspa.yaml`
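
For reference, the `engine.model` key in those files accepts either a Hub ID or an absolute local path (both variants appear elsewhere in this PR):

```yaml
engine:
  # Hugging Face Hub ID (portable across machines):
  model: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
  # ... or an absolute local path:
  # model: /workspace/models/Wan2.1-T2V-1.3B-Diffusers
```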
- Config selection
In `examples/wan2_1/conf/inference.yaml`:

```bash
defaults:
  - _self_
  - inference: 1.3b_adaspa (adaspa)
  : 1.3b (taylorseer)
```

Comment on lines +30 to +35
Copilot AI · Mar 20, 2026

The 'Config selection' snippet is labeled as bash but shows what looks like a Hydra YAML fragment, and the `- inference: 1.3b_adaspa (adaspa)` / `: 1.3b (taylorseer)` syntax is not valid YAML. Please replace this with a valid defaults example (and mark the code block as yaml) so users can copy-paste it successfully.

Suggested change:

```yaml
defaults:
  - _self_
  - inference: 1.3b_adaspa  # use AdaSpa-accelerated config
  # To use the standard Wan2.1 config instead, comment the line above and uncomment the line below:
  # - inference: 1.3b  # use non-AdaSpa config
```
- Run inference
```bash
flagscale inference wan2_1 -c examples/wan2_1/conf/inference.yaml
```

4 changes: 2 additions & 2 deletions examples/wan2_1/conf/inference.yaml
@@ -1,7 +1,7 @@
defaults:
  - _self_
  - inference: 1.3b

experiment:
  exp_name: wan2_1
  exp_dir: outputs/${experiment.exp_name}
@@ -13,7 +13,7 @@ experiment:
  runner:
    hostfile: null
    cmds:
      before_start: source /root/miniconda3/bin/activate flagscale-inference
      before_start: ""
    envs:
      CUDA_VISIBLE_DEVICES: 0
      CUDA_DEVICE_MAX_CONNECTIONS: 1
2 changes: 1 addition & 1 deletion examples/wan2_1/conf/inference/1.3b.yaml
@@ -1,5 +1,5 @@
engine:
  model: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
  model: /workspace/models/Wan2.1-T2V-1.3B-Diffusers
Copilot AI · Mar 20, 2026

This changes the model identifier from a portable Hub ID to a machine-specific absolute path. For a repo example config, prefer keeping the Hub ID (or using an env-var/placeholder like `${oc.env:WAN_MODEL_PATH, Wan-AI/Wan2.1-T2V-1.3B-Diffusers}`) so the config works out-of-the-box across environments while still allowing local overrides.

Suggested change:

```yaml
model: ${oc.env:WAN_MODEL_PATH, Wan-AI/Wan2.1-T2V-1.3B-Diffusers}
```
  loader: diffusers
  pipeline:
    class: diffusers.WanPipeline
59 changes: 59 additions & 0 deletions examples/wan2_1/conf/inference/1.3b_adaspa.yaml
@@ -0,0 +1,59 @@
engine:
  model: /workspace/models/Wan2.1-T2V-1.3B-Diffusers
  loader: diffusers
  pipeline:
    class: diffusers.WanPipeline
    from_pretrained:
      torch_dtype: bfloat16
    components:
      vae:
        class: diffusers.AutoencoderKLWan
        from_pretrained:
          subfolder: vae
          torch_dtype: float32
  device: cuda
  results_path: ${experiment.exp_dir}/results
  output_format: "video"
  transformations:
    AdaSpaTransformation:
      strict: true
      basic:
        model_id: Wan2.1
        prompt: "A cat walks on the grass, realistic"
        height: 480
        width: 832
        frames: 81
        num_steps: 50
        seed: 42
        fps: 15
      adaspa:
        enable_log: true
        num_layers: 30
        sparsity: 0.9
        search_steps: [0, 30]
        min_recall: 0.9
        block_size: 128
        sparsity_modes:
          - random_select
          - head_adaptive
          - cache_lse
          - row_wise
          - text_sink
          - first_frame_sink
    TimestepTrackerTransformation:
      {}
    StateScopeTransformation:
      {}
    state_scopes:
      ["cond", "uncond"]

generate:
  prompts: ["A cat walks on the grass, realistic"]
  negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
  height: 480
  width: 832
  num_frames: 81
  guidance_scale: 5.0
  generator:
    seed: 42
    device: cuda
66 changes: 66 additions & 0 deletions examples/wan2_1/conf/inference/1.3b_combine.yaml
@@ -0,0 +1,66 @@
engine:
  model: /workspace/models/Wan2.1-T2V-1.3B-Diffusers
  loader: diffusers
  pipeline:
    class: diffusers.WanPipeline
    from_pretrained:
      torch_dtype: bfloat16
    components:
      vae:
        class: diffusers.AutoencoderKLWan
        from_pretrained:
          subfolder: vae
          torch_dtype: float32
  device: cuda
  results_path: ${experiment.exp_dir}/results
  output_format: "video"
  transformations:
    AdaSpaTransformation:
      strict: true
      basic:
        model_id: Wan2.1
        prompt: "A cat walks on the grass, realistic"
        height: 480
        width: 832
        frames: 81
        num_steps: 50
        seed: 42
        fps: 15
      adaspa:
        enable_log: true
        num_layers: 30
        sparsity: 0.9
        search_steps: [0, 30]
        min_recall: 0.9
        block_size: 128
        sparsity_modes:
          - random_select
          - head_adaptive
          - cache_lse
          - row_wise
          - text_sink
          - first_frame_sink
    TaylorSeerTransformation:
      order: 1
      warmup_steps: 4
      skip_interval_steps: 3
      targets:
        by_name: ["*blocks.*.attn1", "*blocks.*.attn2", "*blocks.*.ffn"]
      use_timestep_delta: true
    TimestepTrackerTransformation:
      {}
    StateScopeTransformation:
      {}
    state_scopes:
      ["cond", "uncond"]

generate:
  prompts: ["A cat walks on the grass, realistic"]
  negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
  height: 480
  width: 832
  num_frames: 81
  guidance_scale: 5.0
  generator:
    seed: 42
    device: cuda
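
The `by_name` patterns in the TaylorSeer block above look like shell-style globs. Assuming fnmatch-style matching against dotted module paths (an assumption — the actual matcher is not part of this diff, and the module names below are hypothetical), target selection could be sketched as:

```python
from fnmatch import fnmatch

# Glob patterns as in the config's targets.by_name list.
patterns = ["*blocks.*.attn1", "*blocks.*.attn2", "*blocks.*.ffn"]

# Hypothetical module paths; real names depend on the WanPipeline transformer.
module_names = [
    "transformer.blocks.0.attn1",
    "transformer.blocks.0.attn2",
    "transformer.blocks.0.ffn",
    "transformer.blocks.0.norm1",  # not targeted by any pattern
]

# Keep a module when any glob pattern matches its dotted path.
selected = [name for name in module_names
            if any(fnmatch(name, pat) for pat in patterns)]
```

With these inputs, `selected` contains the three attention/FFN paths and skips the norm layer.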
12 changes: 12 additions & 0 deletions flagscale/inference/core/diffusion/adaspa/__init__.py
@@ -0,0 +1,12 @@
from .adaspa_handler import AdaSpaHandler
from .processor import WanAdaSpaAttnProcessor

# Optional explicit registration when the symbol exists in current diffusers.
try:
    from diffusers.models.attention_processor import WanAttnProcessor2_0

    AdaSpaHandler.register_processor(WanAttnProcessor2_0, WanAdaSpaAttnProcessor)
except Exception:
    pass

__all__ = ["AdaSpaHandler", "WanAdaSpaAttnProcessor"]
108 changes: 108 additions & 0 deletions flagscale/inference/core/diffusion/adaspa/adaspa_handler.py
@@ -0,0 +1,108 @@
from typing import Dict, Type, Optional, Any, Tuple

import torch
from torch import nn
from diffusers.models.attention import Attention

ADASPA_PROCESSOR = None


def get_model_name():
    global ADASPA_PROCESSOR
    name_dict = {
        "HunyuanVideoAdaSpaAttnProcessor": "HunyuanVideo",
        "CogVideoXAdaSpaAttnProcessor": "CogVideoX",
        "WanAdaSpaAttnProcessor": "Wan2.1",
        "Wan2.1": "Wan2.1",
    }
    if ADASPA_PROCESSOR in name_dict:
        return name_dict[ADASPA_PROCESSOR]
    # Fallback for unknown processor names: keep behavior safe for new model adapters.
    return "Wan2.1"
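
The lookup above reduces to a dict get with a safe default; a dependency-free sketch (a standalone re-implementation for illustration, not the module's actual global-based signature):

```python
# Standalone sketch of the get_model_name mapping above: known AdaSpa
# processor names map to their model family, and any unknown name falls
# back to "Wan2.1" so new adapters do not crash the handler.
NAME_DICT = {
    "HunyuanVideoAdaSpaAttnProcessor": "HunyuanVideo",
    "CogVideoXAdaSpaAttnProcessor": "CogVideoX",
    "WanAdaSpaAttnProcessor": "Wan2.1",
    "Wan2.1": "Wan2.1",
}

def get_model_name(processor_name):
    # dict.get with a default implements the registered-or-fallback rule.
    return NAME_DICT.get(processor_name, "Wan2.1")
```

For example, `get_model_name("CogVideoXAdaSpaAttnProcessor")` returns `"CogVideoX"`, while an unregistered name falls back to `"Wan2.1"`.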

class AdaSpaHandler(nn.Module):
    """
    A handler class that provides a generic way to replace attention processors with their sparse versions.
    It maintains a registry of original attention processors and their corresponding sparse versions.
    """

    _processor_registry: Dict[Type, Type] = {}
    adaspa_processor = None

    @classmethod
    def register_processor(cls, original_processor: Type, sparse_processor: Type) -> None:
        """
        Register a mapping between an original attention processor and its sparse version.

        Args:
            original_processor: The original attention processor class
            sparse_processor: The corresponding sparse attention processor class
        """
        cls._processor_registry[original_processor] = sparse_processor

    def __init__(self, model: nn.Module, **kwargs: Any):
        """
        Initialize the AdaSpaHandler with a model and optional shape information.

        Args:
            model: The model whose attention processors will be replaced
            **kwargs: Additional arguments to pass to the sparse processor constructor
        """
        super().__init__()
        self.model = model
        self.kwargs = kwargs
        self._original_processors = {}
        self._sparse_processors = {}
        self._replace_processors()

    def _replace_processors(self) -> None:
        """Replace attention processors in Attention modules."""

        def _replace_module(module: nn.Module) -> None:
            if isinstance(module, Attention):
                if hasattr(module, 'processor'):
                    original_processor = module.processor
                    original_type = type(original_processor)

                    if original_type in self._processor_registry:
                        sparse_type = self._processor_registry[original_type]

                        global ADASPA_PROCESSOR
                        ADASPA_PROCESSOR = sparse_type.__name__

                        # Store original processor
                        key = f"{module.__class__.__name__}.processor"
Comment on lines +74 to +75
Copilot AI · Mar 20, 2026

The key used to track original/sparse processors is not unique per Attention module (multiple Attention instances share the same class name), so later entries will overwrite earlier ones. Use a unique identifier (e.g., module path name from model.named_modules(), or id(module)) to avoid collisions and make restoration/debugging reliable.

Suggested change:

```python
# Store original processor using a unique key per module instance
key = f"{module.__class__.__name__}-{id(module)}.processor"
```
                        self._original_processors[key] = original_processor

                        # Create sparse processor wrapper around original processor
                        sparse_processor = sparse_type()
Copilot AI · Mar 20, 2026

WanAdaSpaAttnProcessor requires a base_processor instance (it raises at call time when base_processor is None), but the handler constructs it with no arguments. Construct the sparse processor with the original processor wired in (e.g., `sparse_type(base_processor=original_processor)`), and consider forwarding any handler kwargs through as well.

Suggested change:

```python
sparse_processor = sparse_type(base_processor=original_processor, **self.kwargs)
```
                        self._sparse_processors[key] = sparse_processor
                        module.processor = sparse_processor

            for child in module.children():
                _replace_module(child)

        _replace_module(self.model)

    def forward(self, *args, **kwargs):
        """Forward pass through the wrapped model."""
        return self.model(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Delegate all other attributes to the wrapped model."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.model, name)

# Register Wan processor mapping for the in-tree AdaSpa package.
from .processor import WanAdaSpaAttnProcessor

# Optional Wan registration (depends on diffusers version naming).
try:
    from diffusers.models.attention_processor import WanAttnProcessor2_0

    AdaSpaHandler.register_processor(WanAttnProcessor2_0, WanAdaSpaAttnProcessor)
except Exception:
    pass
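
Putting the registry pattern together with both review suggestions (unique per-instance keys, and wiring the original processor in as `base_processor`), the replacement loop can be sketched without torch or diffusers; all class names below are toy stand-ins:

```python
# Dependency-free sketch of AdaSpaHandler's replace-and-record pattern.
# BaseProcessor/SparseProcessor/Attention are stand-ins for the diffusers
# classes; the unique key and base_processor wiring follow the review
# suggestions above rather than the code as committed.
class BaseProcessor:
    def __call__(self, x):
        return x + 1

class SparseProcessor:
    def __init__(self, base_processor=None):
        self.base_processor = base_processor

    def __call__(self, x):
        # Delegate to the wrapped dense processor (a real sparse processor
        # would branch between sparse and dense paths here).
        return self.base_processor(x)

class Attention:
    def __init__(self):
        self.processor = BaseProcessor()

REGISTRY = {BaseProcessor: SparseProcessor}

def replace_processors(modules):
    originals = {}
    for module in modules:
        sparse_type = REGISTRY.get(type(module.processor))
        if sparse_type is not None:
            # id(module) makes the key unique per Attention instance.
            key = f"{type(module).__name__}-{id(module)}.processor"
            originals[key] = module.processor
            module.processor = sparse_type(base_processor=module.processor)
    return originals

attn_layers = [Attention(), Attention()]
saved = replace_processors(attn_layers)
```

After the call, both layers run through SparseProcessor instances that still hold their original processors, so restoration and debugging stay possible.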
@@ -0,0 +1,3 @@
from .attn_func import AdaptiveSparseAttention, adaptive_sparse_attn

__all__ = ["AdaptiveSparseAttention", "adaptive_sparse_attn"]