
Commit 91e0310

Manually duplicate both tool and non-tool requests. (#978)
* uses add_request
* ran linter
* Clean up
* Fixed bug
* Added duplication
* Added prompt_tokens to metadata.
* Added missing key to metadata
* Fixed bug where we weren't returning properly.
* Fix script
* Added logging
* fix bug
* use clone for SamplingParams
* Fixes to duplication
* Removed logging.
* Cleaned up PR.
* Clean PR
* Removed whitespace
* Cleaned up PR
* Added comment for cleaner PR.
* Cleaning up PR
* Revert "load pretokenized user query (v0) (#965)". This reverts commit fa7e608.
* Bug fix.
* Fixed issue where we weren't setting params right in tools.
* Updated descriptions.
* Fix ordering.
* Updated tool script with description.
* Fixed use of wrong vllm.SamplingParams.
* Now, tool use should run.
* Reapply "load pretokenized user query (v0) (#965)". This reverts commit 341a77b.
* minor clean up.
1 parent 674c706 commit 91e0310
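
The central change in this commit is that each incoming batch is manually fanned out into one engine request per sample, rather than relying on a single request with SamplingParams.n > 1. The sketch below is a simplified, hypothetical illustration of that scheme, not the repository's actual code: the SamplingParams and PromptRequest dataclasses and the fan_out helper are stand-ins, while the ID format {prefix}_{training_step}_{batch_idx}_{j}, the n=1 clone of the sampling params, and the per-sample seed offset mirror what the diff below does.

# Hypothetical, self-contained sketch of the fan-out scheme; SamplingParams and
# PromptRequest are simplified stand-ins, not the real vLLM / open_instruct classes.
from dataclasses import dataclass, field, replace
from typing import List, Optional, Tuple


@dataclass
class SamplingParams:
    n: int = 1
    seed: Optional[int] = None

    def clone(self) -> "SamplingParams":
        return replace(self)


@dataclass
class PromptRequest:
    prompts: List[List[int]]  # one list of prompt token ids per example
    training_step: int
    is_eval: bool = False
    generation_config: SamplingParams = field(default_factory=SamplingParams)


def fan_out(request: PromptRequest) -> List[Tuple[str, List[int], SamplingParams]]:
    """Duplicate each prompt into n independent sub-requests, each with n=1."""
    prefix = "eval" if request.is_eval else "train"
    sub_requests = []
    for batch_idx, prompt in enumerate(request.prompts):
        base_id = f"{prefix}_{request.training_step}_{batch_idx}"
        for j in range(request.generation_config.n):
            params = request.generation_config.clone()
            params.n = 1  # each sub-request produces exactly one completion
            if request.generation_config.seed is not None:
                params.seed = request.generation_config.seed + j  # distinct seed per sample
            sub_requests.append((f"{base_id}_{j}", prompt, params))
    return sub_requests


if __name__ == "__main__":
    req = PromptRequest(prompts=[[1, 2, 3], [4, 5]], training_step=7,
                        generation_config=SamplingParams(n=2, seed=42))
    for request_id, _prompt, params in fan_out(req):
        print(request_id, params.seed)  # train_7_0_0 42, train_7_0_1 43, train_7_1_0 42, ...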

File tree: 6 files changed, +108 -76 lines changed


open_instruct/grpo_fast.py

Lines changed: 2 additions & 0 deletions
@@ -2919,6 +2919,8 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, num_eval_sa
             actor_manager,
             checkpoint_state,
         )
+    except Exception as e:
+        logger.error(f"Error in run_training: {e}", exc_info=True)
     finally:
         cleanup_training_resources(
             stop_event, executor, [inference_results_Q, param_prompt_Q, evaluation_inference_results_Q], actor_manager
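
The added lines follow a common pattern: log the failure with its traceback before the finally block tears everything down, so an exception in run_training is recorded rather than silently swallowed by cleanup. A minimal, generic sketch of the same pattern follows; run_with_cleanup and its arguments are placeholders, not the repository's API.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def run_with_cleanup(train_fn, cleanup_fn):
    """Placeholder sketch: log training errors with a traceback, then always clean up."""
    try:
        train_fn()
    except Exception as e:
        # exc_info=True attaches the full traceback to the log record.
        logger.error(f"Error in run_training: {e}", exc_info=True)
    finally:
        cleanup_fn()


if __name__ == "__main__":
    run_with_cleanup(lambda: 1 / 0, lambda: print("cleanup ran"))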

open_instruct/vllm_utils3.py

Lines changed: 101 additions & 74 deletions
@@ -15,7 +15,6 @@
 
 """This file is copied from https://github.com/OpenRLHF/OpenRLHF"""
 
-import copy
 import os
 import queue
 import time
@@ -43,7 +42,7 @@
 from vllm.v1.core import kv_cache_utils
 
 from open_instruct import logger_utils
-from open_instruct.queue_types import GenerationResult, RequestInfo, TokenStatistics
+from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics
 from open_instruct.tool_utils.tool_vllm import MaxCallsExceededTool, Tool
 from open_instruct.utils import ray_get_with_progress
 
@@ -93,7 +92,7 @@ def _handle_output(output, tools, tracking, sampling_params, max_tool_calls, exe
     if not tools:
         return output
 
-    assert len(output.outputs) <= 1  # In tool mode, sampling_params.n == 1
+    assert len(output.outputs) <= 1, f"{len(output.outputs)=}"  # In tool mode, sampling_params.n == 1
     o = output.outputs[0]
 
     # Update concatenated outputs
@@ -203,7 +202,6 @@ def _process_outputs_with_tools(
 def _finalize_outputs(outputs, tracking, dataset_index, tools, token_statistics=None, start_time=None):
     """Prepare final outputs based on whether tools were used."""
     if not tools:
-        outputs.sort(key=lambda x: int(x.request_id.split("_")[-1]))
         return _process_outputs(
             outputs, dataset_index=dataset_index, token_statistics=token_statistics, start_time=start_time
         )
@@ -223,14 +221,14 @@ def _finalize_outputs(outputs, tracking, dataset_index, tools, token_statistics=
     # Merge n completions into the same outputs
     merged_outputs = {}
     for req_id in tracking["concat_outputs"]:
-        real_req_id, _ = req_id.split("-")
+        real_req_id = "_".join(req_id.split("_")[:-1])
         if real_req_id not in merged_outputs:
             merged_outputs[real_req_id] = tracking["concat_outputs"][req_id]
         else:
             merged_outputs[real_req_id].outputs.append(tracking["concat_outputs"][req_id].outputs[0])
 
     final_outputs = sorted(
-        merged_outputs.values(), key=lambda x: (int(x.request_id.split("-")[0]), int(x.request_id.split("-")[1]))
+        merged_outputs.values(), key=lambda x: (int(x.request_id.split("_")[1]), int(x.request_id.split("_")[2]))
     )
 
     return _process_outputs_with_tools(
@@ -317,6 +315,32 @@ def init_process_group(
     return pg
 
 
+def add_request(request: PromptRequest, llm_engine: vllm.LLMEngine, tools, request_metadata: dict):
+    """Add a request to the LLM engine."""
+    prefix = "eval" if request.is_eval else "train"
+
+    for batch_idx, prompt in enumerate(request.prompts):
+        request_id = f"{prefix}_{request.training_step}_{batch_idx}"
+        sampling_params = request.generation_config.clone()
+        sampling_params.n = 1  # Use n=1 for tool processing
+        request_metadata[request_id] = {
+            "is_eval": request.is_eval,
+            "dataset_index": request.dataset_index[batch_idx],
+            "training_step": request.training_step,
+            "sampling_params": sampling_params,
+            "prompt_tokens": len(prompt),
+            "start_time": time.perf_counter(),
+        }
+
+        tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=request_id)
+
+        for j in range(request.generation_config.n):
+            sub_sampling_params = sampling_params.clone()  # Already has n=1
+            if request.generation_config.seed is not None:
+                sub_sampling_params.seed = request.generation_config.seed + j
+            llm_engine.add_request(f"{request_id}_{j}", tokens_prompt, sub_sampling_params)
+
+
 class LLMRayActor:
     """Ray actor for LLM generation with optional tool support."""
 
@@ -384,6 +408,15 @@ def _should_stop(self) -> bool:
             ray.cancel(should_stop_ref)
         return self._should_stop_value
 
+    def _insert_result_to_queue(self, result, is_eval: bool):
+        """Insert result into the appropriate queue with error handling."""
+        try:
+            results_queue = self.eval_results_queue if is_eval else self.results_queue
+            results_queue.put(result, timeout=10)
+        except queue.Full:
+            queue_name = "eval" if is_eval else "train"
+            self.logger.warning(f"{queue_name} results queue is full, discarding result.")
+
     def process_from_queue(self, timeout: float = 60.0):
         """Run generation loop using LLMEngine directly, with optional tool support.
 
@@ -401,37 +434,20 @@ def process_from_queue(self, timeout: float = 60.0):
 
         result = self._process_request(request)
 
-        try:
-            if request.is_eval:
-                self.eval_results_queue.put(result, timeout=10)
-            else:
-                self.results_queue.put(result, timeout=10)
-            return 1  # Successfully processed one request
-        except queue.Full:
-            self.logger.warning("Results queue is full, discarding result.")
-            return 0
+        self._insert_result_to_queue(result, is_eval=request.is_eval)
+        return 1
 
     def _process_request(self, request):
         """Unified processing for both tool and non-tool generation."""
-        prompts = request.prompts
-        sampling_params = request.generation_config
-        start_time = request.start_time
 
-        self.logger.info(f"[LLMRayActor] Processing request with {len(prompts)} prompts, tools={bool(self.tools)}")
+        self.logger.info(
+            f"[LLMRayActor] Processing request with {len(request.prompts)} prompts, tools={bool(self.tools)}"
        )
 
-        if self.tools:
-            # Need n=1 for individual tool tracking
-            sampling_params = copy.deepcopy(sampling_params)
-            original_n = request.generation_config.n
-            sampling_params.n = 1
-            tracking = _init_tool_tracking()
-            tokenizer = self.llm_engine.tokenizer
-        else:
-            original_n = 1
-            tracking = None
-            tokenizer = None
+        tracking = _init_tool_tracking() if self.tools else None
+        tokenizer = self.llm_engine.tokenizer
 
-        self._add_initial_requests(prompts, sampling_params, original_n, request.training_step)
+        add_request(request, self.llm_engine, self.tools, request_metadata=self.request_metadata)
 
         outputs = []
         iteration = 0
@@ -441,18 +457,19 @@ def _process_request(self, request):
 
             # Poll tool futures first (matching ToolUseLLM order)
             if tracking and tracking.get("pending_tool_futures"):
-                self._poll_tool_futures(tracking, sampling_params, tokenizer)
+                outputs.extend(self._poll_tool_futures(tracking, tokenizer))
 
             # Process engine steps - ONLY if there are unfinished requests (matching ToolUseLLM)
             if self.llm_engine.has_unfinished_requests():
-                step_outputs = list(self.llm_engine.step())
+                step_outputs = [o for o in self.llm_engine.step() if o.finished]
                 for output in step_outputs:
-                    if output.finished:
-                        result = _handle_output(
-                            output, self.tools, tracking, sampling_params, self.max_tool_calls, self.executor
-                        )
-                        if result is not None:
-                            outputs.append(result)
+                    self.logger.info(f"{len(output.outputs)=}")
+                    result = _handle_output(
+                        output, self.tools, tracking, request.generation_config, self.max_tool_calls, self.executor
+                    )
+                    # Result is None when we do more tool processing.
+                    if result is not None:
+                        outputs.append(result)
 
             # Check termination condition (matching ToolUseLLM exactly)
             pending_count = len(tracking["pending_tool_futures"]) if tracking else 0
@@ -465,23 +482,40 @@ def _process_request(self, request):
         total_generation_tokens = 0
         earliest_start_time = float("inf")
 
+        # Now, we combine outputs:
+        combined_outputs = defaultdict(list)
         for output in outputs:
-            request_id = output.request_id
-            if request_id in self.request_metadata:
-                metadata = self.request_metadata[request_id]
-                total_prompt_tokens += metadata["prompt_tokens"]
-                earliest_start_time = min(earliest_start_time, metadata["start_time"])
-
+            # Remove the sub_idx.
+            request_id = "_".join(output.request_id.split("_")[:-1])
+            combined_outputs[request_id].append(output)
+        # Preserve original order from request.dataset_index
+        prefix = "eval" if request.is_eval else "train"
+        # request_id is batch_num _ training_step _ within_batch_idx _ repetition_idx.
+        # we order by within_batch_idx.
+        ordered_ids = [f"{prefix}_{request.training_step}_{batch_idx}" for batch_idx in range(len(request.prompts))]
+        final_outputs = []
+        for request_id in ordered_ids:
+            outs = combined_outputs[request_id]
+            assert len(outs) == request.generation_config.n, f"{len(outs)=} != {request.generation_config.n=}"
+            final_outputs.append(
+                vllm.RequestOutput(
+                    request_id=request_id,
+                    prompt=outs[0].prompt,
+                    prompt_token_ids=outs[0].prompt_token_ids,
+                    prompt_logprobs=outs[0].prompt_logprobs,
+                    outputs=[completion for out in outs for completion in out.outputs],
+                    finished=outs[0].finished,
+                )
+            )
+            metadata = self.request_metadata.pop(request_id)
+            total_prompt_tokens += metadata["prompt_tokens"]
+            earliest_start_time = min(earliest_start_time, metadata["start_time"])
+            for output in outs:
                 for completion in output.outputs:
                     total_generation_tokens += len(completion.token_ids)
-
         generation_time = end_time - earliest_start_time
-
-        for output in outputs:
-            self.request_metadata.pop(output.request_id, None)
-
         result = _finalize_outputs(
-            outputs,
+            final_outputs,
             tracking,
             request.dataset_index,
             self.tools,
@@ -490,33 +524,17 @@ def _process_request(self, request):
                 num_response_tokens=total_generation_tokens,
                 generation_time=generation_time,
            ),
-            start_time=start_time,
+            start_time=request.start_time,
         )
         return result
 
-    def _add_initial_requests(self, prompts, sampling_params, n_samples, training_step):
-        """Add initial requests to the engine."""
-        for i, prompt in enumerate(prompts):
-            if self.tools:
-                # Create individual requests for each sample when using tools
-                for j in range(n_samples):
-                    request_id = f"{training_step}_{i}-{j}"
-                    self.request_metadata[request_id] = {"start_time": time.time(), "prompt_tokens": len(prompt)}
-                    tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=f"{training_step}_{i}")
-                    self.llm_engine.add_request(request_id, tokens_prompt, sampling_params)
-            else:
-                # Standard request format for non-tool mode
-                request_id = f"batch_{training_step}_{i}"
-                self.request_metadata[request_id] = {"start_time": time.time(), "prompt_tokens": len(prompt)}
-                tokens_prompt = vllm.TokensPrompt(prompt_token_ids=prompt, cache_salt=request_id)
-                self.llm_engine.add_request(request_id, tokens_prompt, sampling_params)
-
-    def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
+    def _poll_tool_futures(self, tracking, tokenizer):
         """Poll and handle completed tool executions."""
         if not self.tools or not tracking["pending_tool_futures"]:
-            return
+            return []
 
         dict_keys_to_delete = []
+        completed_outputs = []
 
         for req_id, (future, last_o, last_output) in tracking["pending_tool_futures"].items():
             if not future.done():
@@ -525,6 +543,11 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
             # Tool future is done, process it
             tool_result = future.result()  # Get the tool result
 
+            # Get sampling params from request metadata for this request
+            # Extract the base request ID by removing the sub-request suffix
+            base_req_id = "_".join(req_id.split("_")[:-1])
+            sampling_params = self.request_metadata[base_req_id]["sampling_params"]
+
             last_prompt_token_ids = last_output.prompt_token_ids
             last_token_ids = last_o.token_ids
             tool_output_token_ids = tokenizer.encode(
@@ -559,7 +582,7 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
             can_make_new_request = can_make_new_request and new_sample_tokens > 0
 
             if can_make_new_request:
-                new_sampling_params = copy.deepcopy(sampling_params)
+                new_sampling_params = sampling_params.clone()
                 new_sampling_params.max_tokens = new_sample_tokens
 
                 try:
@@ -569,12 +592,16 @@ def _poll_tool_futures(self, tracking, sampling_params, tokenizer):
                 except Exception as e:
                     # Match original ToolUseLLM behavior - just log and continue
                     self.logger.error(f"[_poll_tool_futures] Error adding request {req_id}: {e}")
+            else:
+                # If we can't make a new request, this tool execution is complete
+                completed_outputs.append(tracking["concat_outputs"][req_id])
 
             dict_keys_to_delete.append(req_id)
 
         for req_id in dict_keys_to_delete:
-            if req_id in tracking["pending_tool_futures"]:
-                del tracking["pending_tool_futures"][req_id]
+            tracking["pending_tool_futures"].pop(req_id, None)
+
+        return completed_outputs
 
     def init_process_group(
         self,
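
To make the regrouping logic in _process_request above easier to follow, here is a hedged, standalone sketch; the regroup helper and its toy inputs are hypothetical, not the repository's code. Each sub-request ID ends in _{j}, so dropping the last underscore-separated segment recovers the per-prompt base ID, and the n completions for each prompt are then collected in the original batch order.

# Hypothetical sketch of regrouping sub-request outputs by their base request ID.
from collections import defaultdict
from typing import List, Tuple


def regroup(outputs: List[Tuple[str, str]], n: int, num_prompts: int,
            training_step: int, prefix: str = "train") -> List[List[str]]:
    """outputs holds (request_id, completion_text) pairs, where request_id is
    f"{prefix}_{training_step}_{batch_idx}_{j}". Returns completions grouped
    per prompt, in the original batch order."""
    combined = defaultdict(list)
    for request_id, completion in outputs:
        base_id = "_".join(request_id.split("_")[:-1])  # drop the trailing sample index
        combined[base_id].append(completion)

    ordered_ids = [f"{prefix}_{training_step}_{batch_idx}" for batch_idx in range(num_prompts)]
    grouped = []
    for base_id in ordered_ids:
        assert len(combined[base_id]) == n, f"expected {n} samples for {base_id}"
        grouped.append(combined[base_id])
    return grouped


if __name__ == "__main__":
    outs = [("train_7_1_0", "b0"), ("train_7_0_1", "a1"), ("train_7_0_0", "a0"), ("train_7_1_1", "b1")]
    print(regroup(outs, n=2, num_prompts=2, training_step=7))  # [['a1', 'a0'], ['b0', 'b1']]

Within each group the order reflects arrival order; in the diff above, the real code goes one step further and merges each group back into a single vllm.RequestOutput whose outputs field is the flattened list of completions.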

scripts/train/debug/large_test_script.sh

Lines changed: 2 additions & 2 deletions
@@ -12,7 +12,7 @@ uv run python mason.py \
     --priority urgent \
     --preemptible \
     --num_nodes 2 \
-    --description "rlvr ace fn and og ocr stdio from base with perf penalty" \
+    --description "Large (multi-node) test script." \
     --max_retries 0 \
     --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
     --budget ai2/oe-adapt \
@@ -39,7 +39,7 @@ uv run python mason.py \
     --stop_strings "</answer>" \
     --non_stop_penalty False \
     --temperature 1.0 \
-    --verbose False \
+    --verbose False \
     --ground_truths_key ground_truth \
     --sft_messages_key messages \
     --total_episodes 10_000 \

scripts/train/debug/single_gpu_integration_test.sh

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ uv run python mason.py \
     --cluster ai2/augusta-google-1 \
     --cluster ai2/saturn-cirrascale \
     --image "$BEAKER_IMAGE" \
+    --description "Single GPU on Beaker integration test." \
     --pure_docker_mode \
     --workspace ai2/open-instruct-dev \
     --priority high \

scripts/train/debug/single_gpu_on_beaker.sh

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ uv run python mason.py \
     --cluster ai2/saturn-cirrascale \
     --cluster ai2/ceres-cirrascale \
     --image "$BEAKER_IMAGE" \
+    --description "Single GPU on Beaker test script." \
     --pure_docker_mode \
     --workspace ai2/open-instruct-dev \
     --priority urgent \

scripts/train/debug/tool_grpo_fast.sh

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ uv run python mason.py \
     --cluster ai2/augusta-google-1 \
     --cluster ai2/saturn-cirrascale \
     --image "$BEAKER_IMAGE" \
+    --description "Single GPU on Beaker with tool use test script." \
     --pure_docker_mode \
     --workspace ai2/tulu-thinker \
     --priority high \
