-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathtest_fp8_kernel_e2e.py
More file actions
executable file
·96 lines (76 loc) · 2.95 KB
/
test_fp8_kernel_e2e.py
File metadata and controls
executable file
·96 lines (76 loc) · 2.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python3
"""End-to-end test: use the FP8 kernel (FP8 -> BF16 conversion happens inside the kernel)."""
import os
import time
from transformers import AutoTokenizer
from diffulex import Diffulex, SamplingParams
def main():
    """Run an end-to-end generation test with an FP8 KV cache.

    Initializes a Diffulex model configured for FP8 KV-cache storage
    (the FP8 kernel performs the FP8 -> BF16 conversion internally),
    generates completions for a few fixed prompts, and prints aggregate
    throughput statistics plus per-prompt outputs.
    """
    # Model configuration
    model = "/data1/ckpts/Dream-org/Dream-v0-Base-7B"
    print("=" * 60)
    print("初始化 Diffulex 模型 (FP8 KV Cache with FP8 Kernel)...")
    print("=" * 60)
    llm = Diffulex(
        model,
        lora_path="/data1/ckpts/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora",
        use_lora=True,
        model_name="dream",
        enforce_eager=True,
        data_parallel_size=1,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.25,
        max_num_batched_tokens=2048,
        max_num_seqs=10,
        max_model_len=2048,
        accept_threshold=0.95,
        complete_threshold=0.9,
        add_new_block_threshold=0.1,
        kv_cache_layout="unified",   # the FP8 kernel only supports the unified layout
        kv_cache_dtype="fp8_e4m3",   # store the KV cache in FP8
        decoding_strategy="d2f"
    )
    print("✓ 模型初始化完成 (FP8 KV Cache with FP8 Kernel)\n")
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    sampling_params = SamplingParams(temperature=0.0, max_tokens=128)
    # A few test prompts
    test_prompts = [
        "The capital of France is",
        "1 + 1 equals",
        "Python is a programming language that",
    ]
    # Prepend the BOS token. Some tokenizers define no BOS token
    # (tokenizer.bos_token is None); `None + str` would raise TypeError,
    # so fall back to an empty prefix in that case.
    bos = tokenizer.bos_token or ""
    prompts = [bos + p for p in test_prompts]
    print("=" * 60)
    print(f"运行生成测试 ({len(prompts)} 个 prompt)...")
    print("使用FP8 KV cache,FP8 kernel在内部进行转换")
    print("=" * 60)
    start_time = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end_time = time.time()
    print("\n" + "=" * 60)
    print("生成结果:")
    print("=" * 60)
    # Aggregate statistics; guard against zero elapsed time / empty output list.
    total_tokens = sum(len(o['token_ids']) for o in outputs)
    total_time = end_time - start_time
    avg_tps = total_tokens / total_time if total_time > 0 else 0
    avg_diff_steps = sum(o['n_diff_steps'] for o in outputs) / len(outputs) if outputs else 0
    print(f"\n总计:")
    print(f"  - 生成输出数: {len(outputs)}")
    print(f"  - 总 token 数: {total_tokens}")
    print(f"  - 总时间: {total_time:.2f} 秒")
    print(f"  - 平均 TPS: {avg_tps:.2f} tok/s")
    print(f"  - 平均扩散步数: {avg_diff_steps:.2f}")
    print("\n" + "=" * 60)
    print("详细输出:")
    print("=" * 60)
    for idx, (prompt, output) in enumerate(zip(test_prompts, outputs)):
        print(f"\n[Prompt {idx + 1}]")
        print(f"输入: {prompt}")
        print(f"输出: {output['text']}")
        print(f"Token IDs 长度: {len(output['token_ids'])}")
        print(f"扩散步数: {output['n_diff_steps']}")
        print("-" * 60)
    print("\n✓ FP8 Kernel 端到端测试完成!")
if __name__ == "__main__":
main()