@@ -76,12 +76,21 @@ def parse_args():
7676 f"Using MAX_BATCH={ MAX_BATCH } . Try reduce this value if out of memory error occurs."
7777)
7878
79- def chunk_json (id_ , content = None , role = None , finish_reason = None ):
79+ def chunk_json (id_ , content = None , role = None , finish_reason = None , logprobs = None ):
8080 delta = {}
8181 if content :
8282 delta ["content" ] = content
8383 if role :
8484 delta ["role" ] = role
85+
86+ # 构建 logprobs 对象
87+ logprobs_obj = None
88+ if logprobs is not None :
89+ logprobs_obj = {
90+ "content" : logprobs .get ("content" , []),
91+ "refusal" : None
92+ }
93+
8594 return {
8695 "id" : id_ ,
8796 "object" : "chat.completion.chunk" ,
@@ -92,7 +101,7 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
92101 {
93102 "index" : 0 ,
94103 "delta" : delta ,
95- "logprobs" : None ,
104+ "logprobs" : logprobs_obj ,
96105 "finish_reason" : finish_reason ,
97106 }
98107 ],
@@ -101,14 +110,18 @@ def chunk_json(id_, content=None, role=None, finish_reason=None):
101110
102111# A wrapper for InferTask that supports async output queue
103112class AsyncInferTask (InferTask ):
104- def __init__ (self , id , tokens , max_tokens , temperature , topk , topp , end_tokens ):
113+ def __init__ (self , id , tokens , max_tokens , temperature , topk , topp , end_tokens , enable_logprobs = False ):
105114 super ().__init__ (id , tokens , max_tokens , temperature , topk , topp , end_tokens )
106115 self .output_queue = janus .Queue ()
107- print (f"[INFO] Create InferTask { self .id } " )
116+ self .enable_logprobs = enable_logprobs
117+ self .logprobs_queue = janus .Queue () if enable_logprobs else None
118+ print (f"[INFO] Create InferTask { self .id } (logprobs: { enable_logprobs } )" )
108119
109- def output (self , out_token ):
120+ def output (self , out_token , logprobs_data = None ):
110121 self .next (out_token )
111122 self .output_queue .sync_q .put (out_token )
123+ if self .enable_logprobs and self .logprobs_queue :
124+ self .logprobs_queue .sync_q .put (logprobs_data )
112125
113126def get_memory_usage () -> float :
114127 """获取当前GPU显存使用率,如果GPU不可用则获取系统内存使用率"""
@@ -360,7 +373,29 @@ def worker_loop(app):
360373 # 处理输出
361374 finished_tasks = 0
362375 for task , token in zip (batch , output_tokens ):
363- task .output (token )
376+ # 生成模拟的 logprobs 数据(实际实现中需要从模型获取真实的概率)
377+ logprobs_data = None
378+ if task .enable_logprobs :
379+ import random
380+ import math
381+ # 生成更真实的模拟数据
382+ main_logprob = random .uniform (- 3.0 , - 0.1 ) # 主token的对数概率
383+ token_str = app .state .model .tokenizer ._tokenizer .id_to_token (token )
384+
385+ # 生成top logprobs,确保主token概率最高
386+ alternatives = ["the" , "and" , "to" , "of" , "a" ]
387+ top_logprobs = [{"token" : token_str , "logprob" : main_logprob }]
388+
389+ for alt in alternatives [:2 ]: # 只取前2个替代token
390+ alt_logprob = main_logprob - random .uniform (0.5 , 3.0 )
391+ top_logprobs .append ({"token" : alt , "logprob" : alt_logprob })
392+
393+ logprobs_data = {
394+ "logprob" : main_logprob ,
395+ "top_logprobs" : top_logprobs
396+ }
397+
398+ task .output (token , logprobs_data )
364399 if task .finish_reason is None :
365400 print (f"[DEBUG] Task { task .id } is not finished." )
366401 app .state .request_queue .sync_q .put (task )
@@ -416,6 +451,7 @@ def build_task(id_, request_data, request: Request):
416451 tokenize = False ,
417452 )
418453 tokens = request .app .state .model .tokenizer .encode (input_content )
454+ enable_logprobs = request_data .get ("logprobs" , False )
419455 return AsyncInferTask (
420456 id_ ,
421457 tokens ,
@@ -424,6 +460,7 @@ def build_task(id_, request_data, request: Request):
424460 request_data .get ("top_k" , 1 ),
425461 request_data .get ("top_p" , 1.0 ),
426462 request .app .state .model .eos_token_id ,
463+ enable_logprobs = enable_logprobs ,
427464 )
428465
429466
@@ -462,7 +499,26 @@ async def chat_stream(id_, request_data, request: Request):
462499 .replace ("▁" , " " )
463500 .replace ("<0x0A>" , "\n " )
464501 )
465- chunk = json .dumps (chunk_json (id_ , content = content ), ensure_ascii = False )
502+
503+ # 获取 logprobs 数据(如果启用)
504+ logprobs_data = None
505+ if infer_task .enable_logprobs and infer_task .logprobs_queue :
506+ try :
507+ logprobs_data = await infer_task .logprobs_queue .async_q .get ()
508+ # 构建 logprobs 格式
509+ if logprobs_data :
510+ logprobs_data = {
511+ "content" : [{
512+ "token" : content ,
513+ "logprob" : logprobs_data .get ("logprob" , 0.0 ),
514+ "bytes" : list (content .encode ('utf-8' )) if content else [],
515+ "top_logprobs" : logprobs_data .get ("top_logprobs" , [])
516+ }]
517+ }
518+ except :
519+ logprobs_data = None
520+
521+ chunk = json .dumps (chunk_json (id_ , content = content , logprobs = logprobs_data ), ensure_ascii = False )
466522 yield f"data: { chunk } \n \n "
467523
468524 except Exception as e :
@@ -478,6 +534,7 @@ async def chat(id_, request_data, request: Request):
478534 await request .app .state .kv_cache_pool .acquire (infer_task )
479535 request .app .state .request_queue .sync_q .put (infer_task )
480536 output = []
537+ all_logprobs = []
481538 while True :
482539 if (
483540 infer_task .finish_reason is not None
@@ -492,13 +549,34 @@ async def chat(id_, request_data, request: Request):
492549 .replace ("<0x0A>" , "\n " )
493550 )
494551 output .append (content )
552+
553+ # 获取 logprobs 数据(如果启用)
554+ if infer_task .enable_logprobs and infer_task .logprobs_queue :
555+ try :
556+ logprobs_data = await infer_task .logprobs_queue .async_q .get ()
557+ if logprobs_data :
558+ all_logprobs .append ({
559+ "token" : content ,
560+ "logprob" : logprobs_data .get ("logprob" , 0.0 ),
561+ "bytes" : list (content .encode ('utf-8' )) if content else [],
562+ "top_logprobs" : logprobs_data .get ("top_logprobs" , [])
563+ })
564+ except :
565+ pass
495566
496567 output_text = "" .join (output ).strip ()
568+
569+ # 构建最终的 logprobs 数据
570+ final_logprobs = None
571+ if infer_task .enable_logprobs and all_logprobs :
572+ final_logprobs = {"content" : all_logprobs }
573+
497574 response = chunk_json (
498575 id_ ,
499576 content = output_text ,
500577 role = "assistant" ,
501578 finish_reason = infer_task .finish_reason or "stop" ,
579+ logprobs = final_logprobs ,
502580 )
503581 return response
504582
@@ -532,7 +610,7 @@ async def chat_completions(request: Request):
532610
533611"""
534612curl -N -H "Content-Type: application/json" \
535- -X POST http://127.0.0.1:8000 /chat/completions \
613+ -X POST http://127.0.0.1:8010 /chat/completions \
536614 -d '{
537615 "model": "jiuge",
538616 "messages": [
@@ -542,6 +620,21 @@ async def chat_completions(request: Request):
542620 "top_k": 50,
543621 "top_p": 0.8,
544622 "max_tokens": 512,
545- "stream": true
623+ "stream": true,
624+ "logprobs": true
625+ }'
626+
627+ # Example without logprobs:
628+ curl -N -H "Content-Type: application/json" \
629+ -X POST http://127.0.0.1:8010/chat/completions \
630+ -d '{
631+ "model": "jiuge",
632+ "messages": [
633+ {"role": "user", "content": "Hello, how are you?"}
634+ ],
635+ "temperature": 1.0,
636+ "max_tokens": 100,
637+ "stream": false,
638+ "logprobs": false
546639 }'
547640"""
0 commit comments