Skip to content

Commit 376b206

Browse files
committed
add initial support for sparse attention
Signed-off-by: Kai Xu <[email protected]>
1 parent 90e6638 commit 376b206

File tree

14 files changed

+2250
-0
lines changed

14 files changed

+2250
-0
lines changed
Lines changed: 368 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,368 @@
1+
#!/usr/bin/env python3
2+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Example script for applying sparse attention to HuggingFace models."""
18+
19+
import argparse
20+
import random
21+
from pathlib import Path
22+
23+
import numpy as np
24+
import torch
25+
import torch.nn as nn
26+
from datasets import load_dataset
27+
from transformers import AutoModelForCausalLM, AutoTokenizer
28+
29+
import modelopt.torch.sparsity.attention_sparsity as mtsa
30+
from modelopt.torch.export import export_hf_checkpoint
31+
from modelopt.torch.sparsity.attention_sparsity import SparseAttentionConfig
32+
from modelopt.torch.sparsity.attention_sparsity.config import (
33+
SKIP_SOFTMAX_CALIB,
34+
SKIP_SOFTMAX_DEFAULT,
35+
)
36+
from modelopt.torch.sparsity.attention_sparsity.nn.sparse_attention import SparseAttentionModule
37+
from modelopt.torch.utils.memory_monitor import launch_memory_monitor
38+
39+
# Seed handed to the random number generators at the start of main().
RAND_SEED = 1234

# Sparse attention presets selectable through --sparse_attn.
# Custom configurations can be added here next to the library defaults.
SPARSE_ATTN_CFG_CHOICES = dict(
    skip_softmax=SKIP_SOFTMAX_DEFAULT,
    skip_softmax_calib=SKIP_SOFTMAX_CALIB,
)
46+
47+
48+
def print_sparsity_stats(model: nn.Module):
    """Print the average sparsity reported by the model's submodules.

    Every submodule exposing ``get_stats()`` is queried; only results that
    contain an ``"average_sparsity"`` entry are used.

    Args:
        model: Model whose ``named_modules()`` are inspected.

    Output is a single summary line when all modules report the same value,
    otherwise the mean followed by one line per module. Nothing is returned.
    """
    collected = []
    for module_name, submodule in model.named_modules():
        if not hasattr(submodule, "get_stats"):
            continue
        report = submodule.get_stats()
        if report and "average_sparsity" in report:
            collected.append((module_name, report["average_sparsity"]))

    if not collected:
        print("No sparsity statistics available")
        return

    values = [value for _, value in collected]

    if len(set(values)) != 1:
        # Modules disagree: print the mean first, then each module's own value.
        mean_value = sum(values) / len(values)
        print(f"Average sparsity: {mean_value:.2%}")
        print("Per-module breakdown:")
        for module_name, value in collected:
            print(f" {module_name}: {value:.2%} sparse")
        return

    # Every module reported the same number, so one summary line is enough.
    print(f"Average sparsity across all {len(collected)} modules: {values[0]:.2%}")
73+
74+
75+
def get_narrativeqa_samples(num_samples=3):
    """Build test prompts from the streamed NarrativeQA test split.

    Each usable record becomes a "Context / Question / Answer:" prompt that
    embeds the full document text as-is.

    Args:
        num_samples: Number of prompts to collect.

    Returns:
        List of prompt strings, at most ``num_samples`` long.

    Raises:
        ValueError: If no prompt could be collected.
    """
    dataset = load_dataset("narrativeqa", split="test", streaming=True)

    samples = []
    for item in dataset:
        # Stop on prompts collected, not records seen, so a record with a
        # missing document or question does not shrink the result.
        if len(samples) >= num_samples:
            break

        # `or {}` keeps a field that is present but None from breaking the lookup.
        context = (item.get("document") or {}).get("text", "")
        question = (item.get("question") or {}).get("text", "")
        if not (context and question):
            continue

        samples.append(f"Context: {context}\n\nQuestion: {question}\n\nAnswer:")

    if not samples:
        raise ValueError("Could not load NarrativeQA samples")

    print(f"Loaded {len(samples)} NarrativeQA samples")
    return samples
103+
104+
105+
def truncate_text(text: str, tokenizer, max_length: int):
    """Truncate text from the middle to preserve beginning and end.

    Args:
        text: Input text to truncate
        tokenizer: Tokenizer to use for encoding; must provide ``encode`` and ``decode``
        max_length: Maximum number of tokens

    Returns:
        ``text`` unchanged when it encodes to at most ``max_length`` tokens.
        Otherwise the decoded head and tail of the token sequence joined by
        " [...] ". The marker itself is not counted against ``max_length``,
        so callers that need a hard limit should still tokenize the result
        with truncation enabled.
    """
    tokens = tokenizer.encode(text, add_special_tokens=True)
    if len(tokens) <= max_length:
        return text

    # Reserve two slots for special tokens, but never let the budget go negative:
    # a negative count would turn the slices below into "all but the last k".
    available_tokens = max(max_length - 2, 0)

    # Split the budget roughly in half; the tail gets the odd token.
    begin_tokens = available_tokens // 2
    end_tokens = available_tokens - begin_tokens

    begin_text = tokenizer.decode(tokens[:begin_tokens], skip_special_tokens=True)
    # Use an explicit start index: tokens[-0:] would return every token
    # instead of none when the tail budget is zero.
    tail_start = len(tokens) - end_tokens
    end_text = tokenizer.decode(tokens[tail_start:], skip_special_tokens=True)

    return begin_text + " [...] " + end_text
136+
137+
138+
def verify_outputs(model, tokenizer, args):
    """Generate from one NarrativeQA prompt with sparse attention off and on.

    The prompt is middle-truncated to ``args.seq_len`` tokens, the model
    generates once with every ``SparseAttentionModule`` disabled and once with
    them enabled, and both continuations are printed along with whether they
    are identical.

    Args:
        model: Model whose sparse attention modules are toggled for the run.
        tokenizer: Tokenizer used to encode the prompt and decode the output.
        args: Parsed CLI namespace. Reads ``sparse_attn``, ``seq_len``,
            ``max_new_tokens``, ``do_sample`` and ``temperature``.

    Side effects:
        ``args.seq_len`` is overwritten with the preset's calibration
        ``max_seqlen`` when the preset defines one and the two differ.
        The sparse attention modules are left enabled on return.
    """
    # Align the test length with the calibration length of the chosen preset.
    preset = SPARSE_ATTN_CFG_CHOICES.get(args.sparse_attn, {})
    if "calibration" in preset and "max_seqlen" in preset["calibration"]:
        target_len = preset["calibration"]["max_seqlen"]
        if args.seq_len != target_len:
            note = (
                f"\nNote: Updating test seq_len from {args.seq_len} to {target_len} "
                "to match calibration config"
            )
            print(note)
            args.seq_len = target_len

    def preview(text, limit):
        # Shorten long text for console display only.
        return text[:limit] + "..." if len(text) > limit else text

    # Fetch a single prompt and fit it to the token budget.
    print(f"\nLoading test sample (will be tokenized up to {args.seq_len} tokens)")
    prompt = get_narrativeqa_samples(num_samples=1)[0]
    truncated_prompt = truncate_text(prompt, tokenizer, args.seq_len)

    inputs = tokenizer(
        truncated_prompt,
        return_tensors="pt",
        max_length=args.seq_len,
        truncation=True,
        padding=False,
    )
    if torch.cuda.is_available():
        inputs = {key: tensor.cuda() for key, tensor in inputs.items()}

    heavy_rule = "=" * 60
    light_rule = "-" * 60

    print("\n" + heavy_rule)
    print("BASELINE vs SPARSE ATTENTION COMPARISON")
    print(heavy_rule)
    print(f"\nTest prompt: {preview(truncated_prompt, 150)}")
    print(f"Input tokens: {inputs['input_ids'].shape[1]} (max: {args.seq_len})")
    if "[...]" in truncated_prompt:
        print("Note: Text was middle-truncated to fit token limit")

    def generate_continuation():
        # Run generation and return only the newly produced text.
        generation_kwargs = dict(
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            temperature=args.temperature if args.do_sample else 1.0,
            pad_token_id=tokenizer.pad_token_id,
        )
        with torch.no_grad():
            sequences = model.generate(**inputs, **generation_kwargs)
        prompt_len = inputs["input_ids"].shape[1]
        return tokenizer.decode(sequences[0][prompt_len:], skip_special_tokens=True)

    sparse_modules = [
        submodule for submodule in model.modules() if isinstance(submodule, SparseAttentionModule)
    ]

    # Baseline pass: sparse attention switched off everywhere.
    print("\n" + light_rule)
    print("Generating baseline (sparse attention disabled)...")
    for sparse_module in sparse_modules:
        sparse_module.disable()
    baseline_text = generate_continuation()

    # Sparse pass: switch it back on and generate again from the same inputs.
    print("\nGenerating with sparse attention (calibrated thresholds)...")
    for sparse_module in sparse_modules:
        sparse_module.enable()
    sparse_text = generate_continuation()

    print("\n" + light_rule)
    print("RESULTS:")
    print(f"\nBaseline: {preview(baseline_text, 300)}")
    print(f"With Sparse: {preview(sparse_text, 300)}")

    verdict = "\nOutputs are identical" if baseline_text == sparse_text else "\nOutputs differ"
    print(verdict)
223+
224+
225+
def sparsify_model(model, args):
    """Apply the selected sparse attention preset to the model.

    The preset named by ``args.sparse_attn`` is copied pattern by pattern with
    its ``"backend"`` entry set to ``args.backend``, wrapped in a
    ``SparseAttentionConfig`` with stats collection turned on, and passed to
    ``mtsa.sparsify``. Sparsity statistics are printed afterwards.

    Args:
        model: Model to sparsify.
        args: Parsed CLI namespace; reads ``sparse_attn`` and ``backend``.

    Returns:
        The model returned by ``mtsa.sparsify``.
    """
    print(f"\nApplying sparse attention: {args.sparse_attn} with backend: {args.backend}")
    preset = SPARSE_ATTN_CFG_CHOICES[args.sparse_attn]

    def with_backend(pattern_cfg):
        # Work on a copy so the shared preset is not modified in place.
        updated = pattern_cfg.copy()
        updated["backend"] = args.backend
        return updated

    per_pattern_cfg = {
        pattern: with_backend(pattern_cfg)
        for pattern, pattern_cfg in preset["sparse_cfg"].items()
    }

    sparse_config = SparseAttentionConfig(
        method=preset["method"],
        sparse_cfg=per_pattern_cfg,
        collect_stats=True,  # Enable stats collection for monitoring
    )

    # Any calibration is left to the framework inside sparsify().
    model = mtsa.sparsify(model, config=sparse_config)
    print("Sparse attention applied successfully!")

    separator = "=" * 60
    print("\n" + separator)
    print("Sparsity Statistics")
    print(separator)
    print_sparsity_stats(model)

    return model
256+
257+
258+
def main(args):
    """Run the example: load, sparsify, optionally verify and export.

    Args:
        args: Parsed CLI namespace. Reads ``pyt_ckpt_path``, ``verify_output``
            and ``export_dir`` here; the remaining options are consumed by
            ``sparsify_model`` and ``verify_outputs``.

    Raises:
        OSError: If no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise OSError("GPU is required for inference.")

    # Seed every RNG the script may draw from. torch is seeded as well because
    # generation with --do_sample samples through torch's global generator.
    random.seed(RAND_SEED)
    np.random.seed(RAND_SEED)
    torch.manual_seed(RAND_SEED)
    launch_memory_monitor()

    print(f"Loading model: {args.pyt_ckpt_path}")

    # Note: attn_implementation="eager" is required for calibration to work properly
    # (flash_attention_2 or sdpa would bypass the softmax patching needed for stats collection)
    model = AutoModelForCausalLM.from_pretrained(
        args.pyt_ckpt_path,
        attn_implementation="eager",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.pyt_ckpt_path)

    # generate() is given pad_token_id, so fall back to EOS when none is defined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # CUDA availability was already enforced above, so no second check is needed.
    model = model.cuda()
    print("Model moved to CUDA")

    # Apply sparse attention to the model (with calibration if configured)
    model = sparsify_model(model, args)

    # Compare baseline and sparse generations if requested
    if args.verify_output:
        verify_outputs(model, tokenizer, args)

    # Export model and tokenizer side by side if requested
    if args.export_dir:
        print(f"\nExporting model to: {args.export_dir}")
        export_dir = Path(args.export_dir)
        export_dir.mkdir(parents=True, exist_ok=True)

        with torch.inference_mode():
            export_hf_checkpoint(model, export_dir=export_dir)

        tokenizer.save_pretrained(export_dir)
        print(f"Model exported successfully to: {export_dir}")
306+
307+
308+
def _build_arg_parser():
    """Create the command-line parser for the sparse attention example."""
    cli = argparse.ArgumentParser(description=__doc__)

    # One (flag, add_argument keyword arguments) entry per option, in help order.
    option_table = [
        # Model arguments
        (
            "--pyt_ckpt_path",
            dict(type=str, required=True, help="Specify where the PyTorch checkpoint path is"),
        ),
        (
            "--sparse_attn",
            dict(
                type=str,
                default="skip_softmax",
                choices=list(SPARSE_ATTN_CFG_CHOICES),
                help="Sparse attention configuration to apply.",
            ),
        ),
        (
            "--backend",
            dict(
                type=str,
                default="pytorch",
                choices=["pytorch", "triton"],
                help="Backend to use for sparse attention computation (default: pytorch)",
            ),
        ),
        # Sequence length arguments
        (
            "--seq_len",
            dict(
                type=int,
                default=2048,
                help="Maximum sequence length for input prompts (will be truncated if longer)",
            ),
        ),
        (
            "--num_samples",
            dict(type=int, default=3, help="Number of samples to use from NarrativeQA dataset"),
        ),
        # Generation arguments
        ("--max_new_tokens", dict(type=int, default=50, help="Maximum new tokens to generate")),
        ("--do_sample", dict(action="store_true", help="Use sampling for generation")),
        ("--temperature", dict(type=float, default=0.7, help="Temperature for sampling")),
        # Operation arguments
        (
            "--verify_output",
            dict(action="store_true", help="Verify that sparse attention outputs match baseline"),
        ),
        (
            "--export_dir",
            dict(
                type=str,
                default=None,
                help="Directory to export the model with sparse attention applied",
            ),
        ),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


if __name__ == "__main__":
    main(_build_arg_parser().parse_args())
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Extensible sparse attention optimization for transformer models."""
17+
18+
# Initialize mode
19+
from . import mode
20+
21+
# Add methods to namespace
22+
from .config import *
23+
from .conversion import *
24+
from .model_sparsify import *

0 commit comments

Comments
 (0)