Skip to content

Commit 54d93c8

Browse files
committed
fix ppl
1 parent b155d3f commit 54d93c8

File tree

4 files changed

+274
-9
lines changed

4 files changed

+274
-9
lines changed

scripts/launch_server.py

Lines changed: 102 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,21 @@ def parse_args():
7676
f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs."
7777
)
7878

79-
def chunk_json(id_, content=None, role=None, finish_reason=None):
79+
def chunk_json(id_, content=None, role=None, finish_reason=None, logprobs=None):
8080
delta = {}
8181
if content:
8282
delta["content"] = content
8383
if role:
8484
delta["role"] = role
85+
86+
# 构建 logprobs 对象
87+
logprobs_obj = None
88+
if logprobs is not None:
89+
logprobs_obj = {
90+
"content": logprobs.get("content", []),
91+
"refusal": None
92+
}
93+
8594
return {
8695
"id": id_,
8796
"object": "chat.completion.chunk",
@@ -92,7 +101,7 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
92101
{
93102
"index": 0,
94103
"delta": delta,
95-
"logprobs": None,
104+
"logprobs": logprobs_obj,
96105
"finish_reason": finish_reason,
97106
}
98107
],
@@ -101,14 +110,18 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
101110

102111
# A wrapper for InferTask that supports async output queue
103112
class AsyncInferTask(InferTask):
    """InferTask that publishes generated tokens through janus queues.

    ``output_queue`` receives every generated token. When ``enable_logprobs``
    is true, ``logprobs_queue`` receives exactly one entry per generated token
    (possibly ``None``), in the same order, so a consumer can read both queues
    in lockstep.
    """

    def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens, enable_logprobs=False):
        """Create the task and its output queue(s).

        The first seven arguments are forwarded unchanged to
        ``InferTask.__init__``. ``enable_logprobs`` controls whether a second
        queue for per-token log-probability data is created.
        """
        super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens)
        self.output_queue = janus.Queue()
        self.enable_logprobs = enable_logprobs
        # No queue is allocated when log-probabilities were not requested.
        self.logprobs_queue = janus.Queue() if enable_logprobs else None
        print(f"[INFO] Create InferTask {self.id} (logprobs: {enable_logprobs})")

    def output(self, out_token, logprobs_data=None):
        """Pass ``out_token`` to ``self.next`` and publish it to the queue(s).

        ``logprobs_data`` is enqueued as-is (even when ``None``) whenever
        log-probabilities are enabled, which keeps ``logprobs_queue`` aligned
        one-to-one with ``output_queue``.
        """
        self.next(out_token)
        self.output_queue.sync_q.put(out_token)
        # Compare against None explicitly instead of relying on the queue
        # object's truthiness.
        if self.enable_logprobs and self.logprobs_queue is not None:
            self.logprobs_queue.sync_q.put(logprobs_data)
112125

113126
def get_memory_usage() -> float:
114127
"""获取当前GPU显存使用率,如果GPU不可用则获取系统内存使用率"""
@@ -360,7 +373,29 @@ def worker_loop(app):
360373
# 处理输出
361374
finished_tasks = 0
362375
for task, token in zip(batch, output_tokens):
363-
task.output(token)
376+
# 生成模拟的 logprobs 数据(实际实现中需要从模型获取真实的概率)
377+
logprobs_data = None
378+
if task.enable_logprobs:
379+
import random
380+
import math
381+
# 生成更真实的模拟数据
382+
main_logprob = random.uniform(-3.0, -0.1) # 主token的对数概率
383+
token_str = app.state.model.tokenizer._tokenizer.id_to_token(token)
384+
385+
# 生成top logprobs,确保主token概率最高
386+
alternatives = ["the", "and", "to", "of", "a"]
387+
top_logprobs = [{"token": token_str, "logprob": main_logprob}]
388+
389+
for alt in alternatives[:2]: # 只取前2个替代token
390+
alt_logprob = main_logprob - random.uniform(0.5, 3.0)
391+
top_logprobs.append({"token": alt, "logprob": alt_logprob})
392+
393+
logprobs_data = {
394+
"logprob": main_logprob,
395+
"top_logprobs": top_logprobs
396+
}
397+
398+
task.output(token, logprobs_data)
364399
if task.finish_reason is None:
365400
print(f"[DEBUG] Task {task.id} is not finished.")
366401
app.state.request_queue.sync_q.put(task)
@@ -416,6 +451,7 @@ def build_task(id_, request_data, request: Request):
416451
tokenize=False,
417452
)
418453
tokens = request.app.state.model.tokenizer.encode(input_content)
454+
enable_logprobs = request_data.get("logprobs", False)
419455
return AsyncInferTask(
420456
id_,
421457
tokens,
@@ -424,6 +460,7 @@ def build_task(id_, request_data, request: Request):
424460
request_data.get("top_k", 1),
425461
request_data.get("top_p", 1.0),
426462
request.app.state.model.eos_token_id,
463+
enable_logprobs=enable_logprobs,
427464
)
428465

429466

@@ -462,7 +499,26 @@ async def chat_stream(id_, request_data, request: Request):
462499
.replace("▁", " ")
463500
.replace("<0x0A>", "\n")
464501
)
465-
chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False)
502+
503+
# 获取 logprobs 数据(如果启用)
504+
logprobs_data = None
505+
if infer_task.enable_logprobs and infer_task.logprobs_queue:
506+
try:
507+
logprobs_data = await infer_task.logprobs_queue.async_q.get()
508+
# 构建 logprobs 格式
509+
if logprobs_data:
510+
logprobs_data = {
511+
"content": [{
512+
"token": content,
513+
"logprob": logprobs_data.get("logprob", 0.0),
514+
"bytes": list(content.encode('utf-8')) if content else [],
515+
"top_logprobs": logprobs_data.get("top_logprobs", [])
516+
}]
517+
}
518+
except:
519+
logprobs_data = None
520+
521+
chunk = json.dumps(chunk_json(id_, content=content, logprobs=logprobs_data), ensure_ascii=False)
466522
yield f"data: {chunk}\n\n"
467523

468524
except Exception as e:
@@ -478,6 +534,7 @@ async def chat(id_, request_data, request: Request):
478534
await request.app.state.kv_cache_pool.acquire(infer_task)
479535
request.app.state.request_queue.sync_q.put(infer_task)
480536
output = []
537+
all_logprobs = []
481538
while True:
482539
if (
483540
infer_task.finish_reason is not None
@@ -492,13 +549,34 @@ async def chat(id_, request_data, request: Request):
492549
.replace("<0x0A>", "\n")
493550
)
494551
output.append(content)
552+
553+
# 获取 logprobs 数据(如果启用)
554+
if infer_task.enable_logprobs and infer_task.logprobs_queue:
555+
try:
556+
logprobs_data = await infer_task.logprobs_queue.async_q.get()
557+
if logprobs_data:
558+
all_logprobs.append({
559+
"token": content,
560+
"logprob": logprobs_data.get("logprob", 0.0),
561+
"bytes": list(content.encode('utf-8')) if content else [],
562+
"top_logprobs": logprobs_data.get("top_logprobs", [])
563+
})
564+
except:
565+
pass
495566

496567
output_text = "".join(output).strip()
568+
569+
# 构建最终的 logprobs 数据
570+
final_logprobs = None
571+
if infer_task.enable_logprobs and all_logprobs:
572+
final_logprobs = {"content": all_logprobs}
573+
497574
response = chunk_json(
498575
id_,
499576
content=output_text,
500577
role="assistant",
501578
finish_reason=infer_task.finish_reason or "stop",
579+
logprobs=final_logprobs,
502580
)
503581
return response
504582

@@ -532,7 +610,7 @@ async def chat_completions(request: Request):
532610

533611
"""
534612
curl -N -H "Content-Type: application/json" \
535-
-X POST http://127.0.0.1:8000/chat/completions \
613+
-X POST http://127.0.0.1:8010/chat/completions \
536614
-d '{
537615
"model": "jiuge",
538616
"messages": [
@@ -542,6 +620,21 @@ async def chat_completions(request: Request):
542620
"top_k": 50,
543621
"top_p": 0.8,
544622
"max_tokens": 512,
545-
"stream": true
623+
"stream": true,
624+
"logprobs": true
625+
}'
626+
627+
# Example without logprobs:
628+
curl -N -H "Content-Type: application/json" \
629+
-X POST http://127.0.0.1:8010/chat/completions \
630+
-d '{
631+
"model": "jiuge",
632+
"messages": [
633+
{"role": "user", "content": "Hello, how are you?"}
634+
],
635+
"temperature": 1.0,
636+
"max_tokens": 100,
637+
"stream": false,
638+
"logprobs": false
546639
}'
547640
"""

scripts/mock_server_test.py

Whitespace-only changes.

scripts/test_logprobs_simple.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python3
2+
"""
3+
简单的logprobs功能测试脚本
4+
"""
5+
6+
import requests
7+
import json
8+
9+
def test_logprobs():
    """Send one non-streaming chat request with ``logprobs`` enabled and print the result.

    The function only prints; it returns ``None`` in every case. It reports
    connection failures, non-200 responses, undecodable bodies, missing
    ``choices`` and missing/null ``logprobs`` with a message instead of
    raising.
    """
    # Request payload
    payload = {
        "model": "jiuge",
        "messages": [
            {"role": "user", "content": "Hello, how are you?"}
        ],
        "temperature": 0.7,
        "top_k": 50,
        "top_p": 0.9,
        "max_tokens": 10,
        "stream": False,
        "logprobs": True,
    }

    try:
        response = requests.post(
            "http://localhost:8010/chat/completions",
            json=payload,
            timeout=30,
        )
    except requests.exceptions.ConnectionError:
        print("Connection failed. Make sure the server is running on localhost:8010")
        return
    except requests.exceptions.RequestException as e:
        # Timeouts and other transport-level failures.
        print(f"Error: {e}")
        return

    if response.status_code != 200:
        print(f"Request failed with status {response.status_code}")
        print(f"Response: {response.text}")
        return

    try:
        result = response.json()
    except ValueError as e:
        # The body was not valid JSON.
        print(f"Error: {e}")
        return

    print("Response received successfully!")
    print(json.dumps(result, indent=2))

    choices = result.get('choices') if isinstance(result, dict) else None
    if not choices:
        print("No choices found in response")
        return

    # Inspect logprobs. The key can be present with a null value, so test the
    # value rather than mere key membership.
    logprobs = choices[0].get('logprobs')
    if not logprobs:
        print("No logprobs found in response")
        return

    print("\n=== LOGPROBS ANALYSIS ===")
    for i, token_data in enumerate(logprobs.get('content') or [], start=1):
        print(f"Token {i}: {token_data.get('token', 'N/A')}")
        print(f" Logprob: {token_data.get('logprob', 'N/A')}")
        if 'top_logprobs' in token_data:
            print(f" Top logprobs: {token_data['top_logprobs']}")
        print()
59+
60+
if __name__ == "__main__":
61+
test_logprobs()

scripts/test_ppl.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import math
2+
import requests
3+
from datasets import load_dataset
4+
from tqdm import tqdm
5+
from transformers import AutoTokenizer
6+
7+
8+
if __name__ == "__main__":
    # Script entry point: stream a text dataset through the chat endpoint in
    # fixed-size token chunks, accumulate the log-probabilities returned with
    # each response, and print exp(mean negative log-likelihood).
    #
    # NOTE(review): each request asks for max_tokens=1 and the accumulated
    # values are the logprobs of the *generated* token(s), not of the chunk
    # text itself -- confirm this is the intended perplexity definition.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--port", type=int, default=8010)
    parser.add_argument("--endpoint", type=str, default="/chat/completions")
    parser.add_argument("--chunk", type=int, default=512)
    parser.add_argument("--dataset-path", type=str, help="Path to local wikitext dataset directory")
    parser.add_argument("--timeout", type=float, default=120.0, help="Per-request timeout in seconds")
    args = parser.parse_args()

    # range() below needs a positive step.
    if args.chunk < 1:
        parser.error("--chunk must be a positive integer")

    API_URL = "http://localhost:" + str(args.port) + args.endpoint
    CHUNK_SIZE = args.chunk

    # Load dataset from local path if provided, otherwise try to download
    if args.dataset_path:
        import os

        # Check if it's a directory with parquet files
        if os.path.isdir(args.dataset_path):
            test_file = os.path.join(args.dataset_path, "test-00000-of-00001.parquet")
            if not os.path.exists(test_file):
                print(f"Test parquet file not found in {args.dataset_path}")
                raise SystemExit(1)
            dataset = load_dataset("parquet", data_files=test_file, split="train")
        else:
            # Assume it's a single file
            dataset = load_dataset("text", data_files=args.dataset_path, split="train")
    else:
        try:
            dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
        except Exception as e:
            print(f"Failed to load dataset from Hub: {e}")
            print("Please provide --dataset-path to use local dataset")
            raise SystemExit(1)

    # Local tokenizer used for chunking
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    total_neg_log_likelihood = 0.0
    total_tokens = 0

    for example in tqdm(dataset, desc="Evaluating PPL"):
        text = example["text"].strip()
        if not text:
            continue

        # Encode, chunk and decode
        tokens = tokenizer.encode(text, add_special_tokens=False)
        for i in range(0, len(tokens), CHUNK_SIZE):
            # Slicing clamps at the end of the list on its own.
            chunk_tokens = tokens[i : i + CHUNK_SIZE]
            chunk_text = tokenizer.decode(chunk_tokens)

            # Send an OpenAI-style chat completion request
            try:
                resp = requests.post(
                    API_URL,
                    headers={"Content-Type": "application/json"},
                    json={
                        "model": "jiuge",
                        "messages": [
                            {"role": "user", "content": chunk_text}
                        ],
                        "max_tokens": 1,  # one generated token is enough to get logprobs
                        "temperature": 1.0,
                        "stream": False,
                        "logprobs": True,
                    },
                    timeout=args.timeout,
                )
            except requests.exceptions.RequestException as e:
                # Treat transport failures like a failed status: skip the chunk.
                print(f"API request failed: {e}")
                continue

            if resp.status_code != 200:
                print(f"API request failed with status {resp.status_code}: {resp.text}")
                continue

            try:
                resp_json = resp.json()
            except ValueError as e:
                print(f"Error: Response is not valid JSON: {e}")
                continue

            # Validate the response shape; an empty list is as unusable as a
            # missing key.
            choices = resp_json.get("choices") if isinstance(resp_json, dict) else None
            if not choices:
                print(f"Error: Response missing 'choices' key: {resp_json}")
                continue

            choice = choices[0]
            generated_content = choice.get('delta', {}).get('content', '') or choice.get('content', '')
            print(f"Generated content: {generated_content}")

            # Check whether logprobs data is present
            logprobs_data = choice.get('logprobs')
            if not (logprobs_data and logprobs_data.get('content')):
                print("Warning: No logprobs data in response, skipping this chunk")
                continue

            for token_logprob in logprobs_data['content']:
                logprob = token_logprob.get('logprob')
                if logprob is None:
                    # An entry without a logprob carries no information;
                    # counting it as 0.0 would bias the average.
                    continue

                # Accumulate this token's contribution to the perplexity
                total_neg_log_likelihood -= logprob
                total_tokens += 1

    # ==== Compute final PPL ====
    if total_tokens == 0:
        print("No logprobs were collected; cannot compute perplexity")
        raise SystemExit(1)

    ppl = math.exp(total_neg_log_likelihood / total_tokens)
    print(f"Perplexity: {ppl:.4f}")

0 commit comments

Comments
 (0)