 
 """This file is copied from https://github.com/OpenRLHF/OpenRLHF"""
 
-import copy
 import os
 import queue
 import time
@@ -43,7 +42,7 @@
 from vllm.v1.core import kv_cache_utils
 
 from open_instruct import logger_utils
-from open_instruct.queue_types import GenerationResult, RequestInfo, TokenStatistics
+from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics
 from open_instruct.tool_utils.tool_vllm import MaxCallsExceededTool, Tool
 from open_instruct.utils import ray_get_with_progress
 
@@ -93,7 +92,7 @@ def _handle_output(output, tools, tracking, sampling_params, max_tool_calls, exe
     if not tools:
         return output
 
-    assert len(output.outputs) <= 1  # In tool mode, sampling_params.n == 1
+    assert len(output.outputs) <= 1, f"{len(output.outputs)=}"  # In tool mode, sampling_params.n == 1
     o = output.outputs[0]
 
     # Update concatenated outputs
@@ -203,7 +202,6 @@ def _process_outputs_with_tools(
 def _finalize_outputs(outputs, tracking, dataset_index, tools, token_statistics=None, start_time=None):
     """Prepare final outputs based on whether tools were used."""
     if not tools:
-        outputs.sort(key=lambda x: int(x.request_id.split("_")[-1]))
         return _process_outputs(
             outputs, dataset_index=dataset_index, token_statistics=token_statistics, start_time=start_time
         )
@@ -223,14 +221,14 @@ def _finalize_outputs(outputs, tracking, dataset_index, tools, token_statistics=
     # Merge n completions into the same outputs
     merged_outputs = {}
     for req_id in tracking["concat_outputs"]:
-        real_req_id, _ = req_id.split("-")
+        real_req_id = "_".join(req_id.split("_")[:-1])
         if real_req_id not in merged_outputs:
             merged_outputs[real_req_id] = tracking["concat_outputs"][req_id]
         else:
             merged_outputs[real_req_id].outputs.append(tracking["concat_outputs"][req_id].outputs[0])
 
     final_outputs = sorted(
-        merged_outputs.values(), key=lambda x: (int(x.request_id.split("-")[0]), int(x.request_id.split("-")[1]))
+        merged_outputs.values(), key=lambda x: (int(x.request_id.split("_")[1]), int(x.request_id.split("_")[2]))
     )
 
     return _process_outputs_with_tools(
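
A note on the new sort key (an illustration, not part of the diff): the merged request IDs now have the underscore-delimited shape {prefix}_{training_step}_{within_batch_idx}, so indices 1 and 2 of the split give the training step and the position within the batch. A minimal sketch, assuming IDs of that shape:

request_ids = ["train_3_10", "train_3_2", "train_3_0"]
ordered = sorted(request_ids, key=lambda rid: (int(rid.split("_")[1]), int(rid.split("_")[2])))
assert ordered == ["train_3_0", "train_3_2", "train_3_10"]
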
@@ -317,6 +315,32 @@ def init_process_group(
     return pg
 
 
+def add_request(request: PromptRequest, llm_engine: vllm.LLMEngine, tools, request_metadata: dict):
+    """Add a request to the LLM engine."""
+    prefix = "eval" if request.is_eval else "train"
+
+    for batch_idx, prompt in enumerate(request.prompts):
+        request_id = f"{prefix}_{request.training_step}_{batch_idx}"
+        sampling_params = request.generation_config.clone()
+        sampling_params.n = 1  # Use n=1 for tool processing
+        request_metadata[request_id] = {
+            "is_eval": request.is_eval,
+            "dataset_index": request.dataset_index[batch_idx],
+            "training_step": request.training_step,
+            "sampling_params": sampling_params,
+            "prompt_tokens": len(prompt),
+            "start_time": time.perf_counter(),
+        }
+
+        tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=request_id)
+
+        for j in range(request.generation_config.n):
+            sub_sampling_params = sampling_params.clone()  # Already has n=1
+            if request.generation_config.seed is not None:
+                sub_sampling_params.seed = request.generation_config.seed + j
+            llm_engine.add_request(f"{request_id}_{j}", tokens_prompt, sub_sampling_params)
+
+
 class LLMRayActor:
     """Ray actor for LLM generation with optional tool support."""
 
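
For orientation (an illustration, not part of the diff): add_request fans each prompt out into generation_config.n engine sub-requests named {prefix}_{training_step}_{batch_idx}_{j}, and later code recovers the base ID by dropping the trailing _{j}. A standalone sketch of that naming round trip; the helper names here are hypothetical:

def sub_request_ids(prefix: str, training_step: int, batch_idx: int, n: int) -> list[str]:
    # Hypothetical helper mirroring the f"{request_id}_{j}" naming above.
    base_id = f"{prefix}_{training_step}_{batch_idx}"
    return [f"{base_id}_{j}" for j in range(n)]

def base_request_id(sub_id: str) -> str:
    # Drop the trailing sample index, as the merge and tool-polling code does.
    return "_".join(sub_id.split("_")[:-1])

ids = sub_request_ids("train", 3, 7, n=2)
assert ids == ["train_3_7_0", "train_3_7_1"]
assert all(base_request_id(i) == "train_3_7" for i in ids)
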
@@ -384,6 +408,15 @@ def _should_stop(self) -> bool:
                 ray.cancel(should_stop_ref)
         return self._should_stop_value
 
+    def _insert_result_to_queue(self, result, is_eval: bool):
+        """Insert result into the appropriate queue with error handling."""
+        try:
+            results_queue = self.eval_results_queue if is_eval else self.results_queue
+            results_queue.put(result, timeout=10)
+        except queue.Full:
+            queue_name = "eval" if is_eval else "train"
+            self.logger.warning(f"{queue_name} results queue is full, discarding result.")
+
     def process_from_queue(self, timeout: float = 60.0):
         """Run generation loop using LLMEngine directly, with optional tool support.
 
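
Aside (an illustration, not part of the diff): the timeout=10 / queue.Full handling above follows standard-library queue semantics, where put blocks for up to the timeout and then raises queue.Full. A minimal sketch with a plain queue.Queue standing in for the Ray results queue:

import queue

q = queue.Queue(maxsize=1)
q.put("first result")  # fills the queue

try:
    q.put("second result", timeout=0.1)  # blocks briefly, then raises
except queue.Full:
    print("results queue is full, discarding result")
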
@@ -401,37 +434,20 @@ def process_from_queue(self, timeout: float = 60.0):
 
         result = self._process_request(request)
 
-        try:
-            if request.is_eval:
-                self.eval_results_queue.put(result, timeout=10)
-            else:
-                self.results_queue.put(result, timeout=10)
-            return 1  # Successfully processed one request
-        except queue.Full:
-            self.logger.warning("Results queue is full, discarding result.")
-            return 0
+        self._insert_result_to_queue(result, is_eval=request.is_eval)
+        return 1
 
     def _process_request(self, request):
         """Unified processing for both tool and non-tool generation."""
-        prompts = request.prompts
-        sampling_params = request.generation_config
-        start_time = request.start_time
 
-        self.logger.info(f"[LLMRayActor] Processing request with {len(prompts)} prompts, tools={bool(self.tools)}")
+        self.logger.info(
+            f"[LLMRayActor] Processing request with {len(request.prompts)} prompts, tools={bool(self.tools)}"
+        )
 
-        if self.tools:
-            # Need n=1 for individual tool tracking
-            sampling_params = copy.deepcopy(sampling_params)
-            original_n = request.generation_config.n
-            sampling_params.n = 1
-            tracking = _init_tool_tracking()
-            tokenizer = self.llm_engine.tokenizer
-        else:
-            original_n = 1
-            tracking = None
-            tokenizer = None
+        tracking = _init_tool_tracking() if self.tools else None
+        tokenizer = self.llm_engine.tokenizer
 
-        self._add_initial_requests(prompts, sampling_params, original_n, request.training_step)
+        add_request(request, self.llm_engine, self.tools, request_metadata=self.request_metadata)
 
         outputs = []
         iteration = 0
@@ -441,18 +457,19 @@ def _process_request(self, request):
 
             # Poll tool futures first (matching ToolUseLLM order)
             if tracking and tracking.get("pending_tool_futures"):
-                self._poll_tool_futures(tracking, sampling_params, tokenizer)
+                outputs.extend(self._poll_tool_futures(tracking, tokenizer))
 
             # Process engine steps - ONLY if there are unfinished requests (matching ToolUseLLM)
             if self.llm_engine.has_unfinished_requests():
-                step_outputs = list(self.llm_engine.step())
+                step_outputs = [o for o in self.llm_engine.step() if o.finished]
                 for output in step_outputs:
-                    if output.finished:
-                        result = _handle_output(
-                            output, self.tools, tracking, sampling_params, self.max_tool_calls, self.executor
-                        )
-                        if result is not None:
-                            outputs.append(result)
+                    self.logger.info(f"{len(output.outputs)=}")
+                    result = _handle_output(
+                        output, self.tools, tracking, request.generation_config, self.max_tool_calls, self.executor
+                    )
+                    # Result is None when we do more tool processing.
+                    if result is not None:
+                        outputs.append(result)
 
             # Check termination condition (matching ToolUseLLM exactly)
             pending_count = len(tracking["pending_tool_futures"]) if tracking else 0
@@ -465,23 +482,40 @@ def _process_request(self, request):
         total_generation_tokens = 0
         earliest_start_time = float("inf")
 
+        # Now, we combine outputs:
+        combined_outputs = defaultdict(list)
         for output in outputs:
-            request_id = output.request_id
-            if request_id in self.request_metadata:
-                metadata = self.request_metadata[request_id]
-                total_prompt_tokens += metadata["prompt_tokens"]
-                earliest_start_time = min(earliest_start_time, metadata["start_time"])
-
+            # Remove the sub_idx.
+            request_id = "_".join(output.request_id.split("_")[:-1])
+            combined_outputs[request_id].append(output)
+        # Preserve original order from request.dataset_index
+        prefix = "eval" if request.is_eval else "train"
+        # request_id is prefix _ training_step _ within_batch_idx _ repetition_idx;
+        # we order by within_batch_idx.
+        ordered_ids = [f"{prefix}_{request.training_step}_{batch_idx}" for batch_idx in range(len(request.prompts))]
+        final_outputs = []
+        for request_id in ordered_ids:
+            outs = combined_outputs[request_id]
+            assert len(outs) == request.generation_config.n, f"{len(outs)=} != {request.generation_config.n=}"
+            final_outputs.append(
+                vllm.RequestOutput(
+                    request_id=request_id,
+                    prompt=outs[0].prompt,
+                    prompt_token_ids=outs[0].prompt_token_ids,
+                    prompt_logprobs=outs[0].prompt_logprobs,
+                    outputs=[completion for out in outs for completion in out.outputs],
+                    finished=outs[0].finished,
+                )
+            )
+            metadata = self.request_metadata.pop(request_id)
+            total_prompt_tokens += metadata["prompt_tokens"]
+            earliest_start_time = min(earliest_start_time, metadata["start_time"])
+            for output in outs:
                 for completion in output.outputs:
                     total_generation_tokens += len(completion.token_ids)
-
         generation_time = end_time - earliest_start_time
-
-        for output in outputs:
-            self.request_metadata.pop(output.request_id, None)
-
         result = _finalize_outputs(
-            outputs,
+            final_outputs,
             tracking,
             request.dataset_index,
             self.tools,
@@ -490,33 +524,17 @@ def _process_request(self, request):
                 num_response_tokens=total_generation_tokens,
                 generation_time=generation_time,
             ),
-            start_time=start_time,
+            start_time=request.start_time,
         )
         return result
 
-    def _add_initial_requests(self, prompts, sampling_params, n_samples, training_step):
-        """Add initial requests to the engine."""
-        for i, prompt in enumerate(prompts):
-            if self.tools:
-                # Create individual requests for each sample when using tools
-                for j in range(n_samples):
-                    request_id = f"{training_step}_{i}-{j}"
-                    self.request_metadata[request_id] = {"start_time": time.time(), "prompt_tokens": len(prompt)}
-                    tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=f"{training_step}_{i}")
-                    self.llm_engine.add_request(request_id, tokens_prompt, sampling_params)
-            else:
-                # Standard request format for non-tool mode
-                request_id = f"batch_{training_step}_{i}"
-                self.request_metadata[request_id] = {"start_time": time.time(), "prompt_tokens": len(prompt)}
-                tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=request_id)
-                self.llm_engine.add_request(request_id, tokens_prompt, sampling_params)
-
-    def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
+    def _poll_tool_futures(self, tracking, tokenizer):
         """Poll and handle completed tool executions."""
         if not self.tools or not tracking["pending_tool_futures"]:
-            return
+            return []
 
         dict_keys_to_delete = []
+        completed_outputs = []
 
         for req_id, (future, last_o, last_output) in tracking["pending_tool_futures"].items():
             if not future.done():
@@ -525,6 +543,11 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
             # Tool future is done, process it
             tool_result = future.result()  # Get the tool result
 
+            # Get sampling params from request metadata for this request
+            # Extract the base request ID by removing the sub-request suffix
+            base_req_id = "_".join(req_id.split("_")[:-1])
+            sampling_params = self.request_metadata[base_req_id]["sampling_params"]
+
             last_prompt_token_ids = last_output.prompt_token_ids
             last_token_ids = last_o.token_ids
             tool_output_token_ids = tokenizer.encode(
@@ -559,7 +582,7 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
             can_make_new_request = can_make_new_request and new_sample_tokens > 0
 
             if can_make_new_request:
-                new_sampling_params = copy.deepcopy(sampling_params)
+                new_sampling_params = sampling_params.clone()
                 new_sampling_params.max_tokens = new_sample_tokens
 
                 try:
@@ -569,12 +592,16 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
                 except Exception as e:
                     # Match original ToolUseLLM behavior - just log and continue
                     self.logger.error(f"[_poll_tool_futures] Error adding request {req_id}: {e}")
+            else:
+                # If we can't make a new request, this tool execution is complete
+                completed_outputs.append(tracking["concat_outputs"][req_id])
 
             dict_keys_to_delete.append(req_id)
 
         for req_id in dict_keys_to_delete:
-            if req_id in tracking["pending_tool_futures"]:
-                del tracking["pending_tool_futures"][req_id]
+            tracking["pending_tool_futures"].pop(req_id, None)
+
+        return completed_outputs
 
     def init_process_group(
         self,