
Commit e7b397d

Merge branch 'main' into update-vllm
2 parents: 5d429f8 + 3a21f22

File tree

9 files changed: +664 / -829 lines

mason.py

Lines changed: 229 additions & 327 deletions
Large diffs are not rendered by default.

open_instruct/grpo_fast.py

Lines changed: 10 additions & 40 deletions

```diff
@@ -625,15 +625,6 @@ def load(self, path: str, map_location=None):
         np.random.seed(worker_seed)
         random.seed(worker_seed)

-        torch.distributed.init_process_group(
-            backend="nccl",
-            init_method="env://",
-            world_size=self.world_size,
-            rank=self.rank,
-            timeout=timedelta(minutes=args.backend_timeout),
-            device_id=torch.device("cuda", self.local_rank),
-        )
-
         deepspeed.init_distributed(timeout=timedelta(minutes=args.backend_timeout))

         ds_config = get_train_ds_config(offload=False, adam_offload=False, stage=args.deepspeed_stage, bf16=True)
```
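
The explicit `torch.distributed.init_process_group` call is redundant here: `deepspeed.init_distributed` creates the default process group itself when one does not already exist, reading `RANK`, `WORLD_SIZE`, and `MASTER_ADDR`/`MASTER_PORT` from the environment. A minimal sketch of the reduced setup (illustrative only; the 60-minute timeout and `LOCAL_RANK` handling are assumptions, not the repo's exact code):

```python
# Minimal sketch (illustrative, not the repo's exact setup): when no default
# process group exists, deepspeed.init_distributed creates it from the usual
# env vars (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), so a separate
# torch.distributed.init_process_group call is unnecessary.
import os
from datetime import timedelta

import deepspeed
import torch

deepspeed.init_distributed(dist_backend="nccl", timeout=timedelta(minutes=60))

assert torch.distributed.is_initialized()
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))
```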
```diff
@@ -2312,28 +2303,6 @@ def weight_sync_thread(
     logger.info("[Weight Sync Thread] 🛑 Stopping weight sync thread")


-def generate_thread(args, vllm_engines, resume_training_step, stop_event, generate_metrics_Q):
-    """Thread function that repeatedly calls process_from_queue on vllm engines."""
-    logger.info("[Generate Thread] 🚀 Starting generation thread")
-    while not stop_event.is_set():
-        with Timer("🔥 Generation time") as timer:
-            processed_results, _ = ray_get_with_progress(
-                [engine.process_from_queue.remote(timeout=20) for engine in vllm_engines],
-                desc="[Generate Thread] Waiting for vLLM engines to process",
-                enable=args.verbose,
-            )
-            num_processed = sum(int(result) for result in processed_results)
-            # Suppress timing output if nothing was processed
-            if num_processed == 0:
-                timer.noop = True
-        if num_processed > 0:
-            try:
-                generate_metrics_Q.put_nowait({"time/generation": timer.duration})
-            except Full:
-                logger.warning("[Generate Thread] generate metrics queue full, skipping metric")
-    logger.info("[Generate Thread] 🛑 Stopping generation thread")
-
-
 def one_training_step(
     args: Args,
     policy_group: ModelGroup,
```
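
With this removal the trainer no longer drives generation from a dedicated thread: the polling loop that used to block on `ray.get` for every pass now lives inside each vLLM engine actor (see `vllm_utils3.py` below). For contrast, a stripped-down sketch of the removed driver pattern (names simplified, not the exact code):

```python
# Stripped-down sketch of the removed driver pattern (simplified names):
# a thread in the trainer process blocks on ray.get() once per pass to make
# every engine actor process its queue, then records how much work was done.
import threading

import ray


def drive_engines(engines, stop_event: threading.Event, pass_timeout_s: float = 20.0):
    while not stop_event.is_set():
        # One synchronous round-trip over all engines per pass.
        processed = ray.get([e.process_from_queue.remote(timeout=pass_timeout_s) for e in engines])
        if sum(int(n) for n in processed) == 0:
            continue  # nothing happened this pass; poll again
        # ...record generation timing metrics here...
```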
```diff
@@ -2682,7 +2651,6 @@ def cleanup_training_resources(
     actor_manager: ActorManager,
 ) -> None:
     """Clean up all training resources including threads and Ray queues."""
-    # Signal generate_thread to stop
     stop_event.set()

     logger.info("Signaling all actors to stop...")
```

```diff
@@ -2791,14 +2759,13 @@ def run_training(
         model_dims,
     )

-    logger.info("======== ✅ generation thread starts =========")
-    generation_future = executor.submit(
-        generate_thread, args, vllm_engines, resume_training_step, stop_event, generate_metrics_Q
-    )
-
-    # setup health check function to check that everything is still alive
     def health_check_fn():
-        [f.result() for f in [packing_future, generation_future, weight_sync_thread_future] if f.done()]
+        [f.result() for f in [packing_future, weight_sync_thread_future] if f.done()]
+        ray_get_with_progress(
+            [engine.check_background_threads.remote() for engine in vllm_engines],
+            desc="Checking vLLM engine health",
+            enable=False,
+        )

     # Send initial data to ensure we have a N-step offset.
     for _ in range(args.async_steps):
```
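
The reworked `health_check_fn` surfaces failures from two places: calling `.result()` on any local future that has already finished re-raises the exception captured in that thread, and `check_background_threads` (added to the engine actor below) does the same inside each vLLM actor, so the error travels back through `ray.get`. A self-contained sketch of the local half of that mechanism:

```python
# Self-contained sketch: calling .result() on a *finished* future re-raises
# the exception that killed the worker, turning a silent background-thread
# crash into a visible error on the thread doing the health check.
import time
from concurrent import futures


def worker():
    raise RuntimeError("background thread died")


with futures.ThreadPoolExecutor(max_workers=1) as executor:
    future = executor.submit(worker)
    time.sleep(0.1)  # give the worker time to fail

    # Like health_check_fn, only poke futures that are already done so the
    # check never blocks the caller.
    if future.done():
        try:
            future.result()  # re-raises RuntimeError from the worker
        except RuntimeError as err:
            print(f"health check caught: {err}")
```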
```diff
@@ -2835,7 +2802,9 @@ def health_check_fn():
         )

         # Check if any of the threads have raised an exception.
+        health_check_start = time.perf_counter()
         health_check_fn()
+        health_check_time = time.perf_counter() - health_check_start

         logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}")
         weight_sync_trigger_event.set()
```
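
The `perf_counter` pair added around `health_check_fn` is later reported as `time/health_check` in `data_thread_metrics` (last hunk of this file). The same measurement could be packaged as a small context manager; a hypothetical helper, not something in the repo:

```python
# Hypothetical helper (not from the repo) showing the same measurement the
# diff adds inline: time a callable with perf_counter and stash the result
# in a metrics dict under a named key.
import time
from contextlib import contextmanager


@contextmanager
def timed(metrics: dict, key: str):
    start = time.perf_counter()
    try:
        yield
    finally:
        metrics[key] = time.perf_counter() - start


metrics = {}
with timed(metrics, "time/health_check"):
    pass  # health_check_fn() would run here
print(metrics)
```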
```diff
@@ -2859,7 +2828,6 @@ def health_check_fn():
             is_eval=True,
         )

-        # The generate_thread is now handling vLLM processing asynchronously
         collated_data, data_thread_metrics, num_total_tokens, num_step_tokens, prompt_lengths, response_lengths = (
             load_data_from_packing_thread(packed_sequences_Q, num_total_tokens, stop_event, health_check_fn)
         )
```

```diff
@@ -2872,6 +2840,8 @@ def health_check_fn():
         except Empty:
             logger.info("[Main Thread] didn't get train generation metrics")

+        data_thread_metrics["time/health_check"] = health_check_time
+
         one_training_step(
             args,
             policy_group,
```

open_instruct/search_utils/gpqa_eval.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -48,7 +48,6 @@
         max_output_len=args.model_len,  # Explicitly set a custom max context length
         gpu_memory_utilization=0.95,
         num_gpus=1,
-        enable_sleep_mode=False,
         noset_visible_devices=ray_noset_visible_devices(),
     )

```
open_instruct/vllm_utils3.py

Lines changed: 34 additions & 137 deletions

```diff
@@ -18,6 +18,7 @@
 import dataclasses
 import os
 import queue
+import threading
 import time
 from collections import defaultdict
 from concurrent import futures
```

```diff
@@ -395,10 +396,15 @@ def __init__(
         self._should_stop_value = False
         self._should_stop_timeout_s = 5

-        self._executor = futures.ThreadPoolExecutor(max_workers=1)
-        self._prefetch_future = self._executor.submit(self._prefetch_worker)
+        # Initialize instance variables before starting threads
         self.tracking = _init_tool_tracking()
         self.request_outputs = {}
+        self._threads_started = threading.Event()
+
+        # Start background threads
+        self._executor = futures.ThreadPoolExecutor(max_workers=2)
+        self._prefetch_future = self._executor.submit(self._prefetch_worker)
+        self._process_future = self._executor.submit(self._process_from_queue)

     def get_model_dims_dict(self):
         """Get only the model dimensions as a simple dict without loading weights."""
```
```diff
@@ -431,8 +437,9 @@ def _should_stop(self) -> bool:

     def _prefetch_worker(self, sleep_length_s: int = 1):
         """Background worker that prefetches requests until we have enough buffered."""
+        self._threads_started.set()
         while True:
-            if self._should_stop():
+            if not self.inflight_updates and self._should_stop():
                 time.sleep(sleep_length_s)
                 continue
             current_unfinished = self.llm_engine.get_num_unfinished_requests()
```

```diff
@@ -456,58 +463,18 @@ def _insert_result_to_queue(self, result, is_eval: bool):
         results_queue = self.eval_results_queue if is_eval else self.results_queue
         results_queue.put(result)

-    def _should_exit(self) -> bool:
-        """Determine if the processing loop should exit.
-
-        Returns:
-            bool: True if the loop should exit, False otherwise.
-        """
-        # Check stop condition first (cheapest check)
-        stop_requested = self._should_stop()
-
-        # Case 1: inflight_updates enabled and stop requested - exit immediately
-        if self.inflight_updates and stop_requested:
-            return True
-
-        # Now check for pending work (only if needed)
-        if stop_requested:
-            # Need to check if we have pending work
-            pending_tools = len(self.tracking["pending_tool_futures"])
-            unfinished = self.llm_engine.get_num_unfinished_requests()
-
-            # Case 2: stop requested and no pending work - exit
-            if pending_tools == 0 and unfinished == 0:
-                return True
-            # Otherwise, we have pending work and should continue
-            return False
-
-        # No stop requested - check if there's any work to do
-        pending_tools = len(self.tracking["pending_tool_futures"])
-        unfinished = self.llm_engine.get_num_unfinished_requests()
-
-        # Case 3: no work left at all - exit
-        if pending_tools == 0 and unfinished == 0:
-            return True
-
-        # Otherwise, continue processing
-        return False
-
-    def process_from_queue(self, timeout: float = 60.0):
+    def _process_from_queue(self, timeout: float = 60.0):
         """Run generation loop using LLMEngine directly, with optional tool support.

-        Runs continuously until should_stop is set, periodically adding new requests
-        and yielding control to allow weight synchronization.
+        Runs continuously in a background thread, processing requests from the engine.

         Returns:
             int: Number of requests processed
         """
-
-        # Use persistent instance variables for tracking and outputs
-        # This ensures state is maintained across multiple calls
         total_processed = 0
         iteration_count = 0

-        while not self._should_exit():
+        while True:
             iteration_count += 1

             # Health check: ensure prefetch worker is alive. This will raise if it has crashed.
@@ -558,17 +525,7 @@ def process_from_queue(self, timeout: float = 60.0):
                 total_processed += self._finalize_sub_request(
                     output.request_id, output, complete_output, current_time
                 )
-
-            if self.verbose and iteration_count % 100 == 0:
-                final_unfinished = self.llm_engine.get_num_unfinished_requests()
-                pending_tools = len(self.tracking["pending_tool_futures"])
-                self.logger.info(
-                    f"process_from_queue iteration {iteration_count}: unfinished={final_unfinished}, pending_tools={pending_tools}"
-                )
-
-            # If we have only pending tools but no unfinished requests, sleep briefly
-            # to let pending tools complete before the next iteration
-            if self.llm_engine.get_num_unfinished_requests() == 0 and len(self.tracking["pending_tool_futures"]) > 0:
+            if self.llm_engine.get_num_unfinished_requests() == 0:
                 time.sleep(1)

         return total_processed
```

```diff
@@ -870,10 +827,22 @@ def init_process_group(
             args=(master_address, master_port, rank_offset, world_size, group_name, backend, use_ray, timeout_minutes),
         )

+    def _maybe_drain_requests(self, sleep_s: float = 0.1):
+        while not self.inflight_updates:
+            pending_tools = len(self.tracking["pending_tool_futures"])
+            unfinished = self.llm_engine.get_num_unfinished_requests()
+
+            if pending_tools == 0 and unfinished == 0:
+                break
+
+            time.sleep(sleep_s)
+
     def update_weight(self, name, dtype, shape, empty_cache=False):
+        self._maybe_drain_requests()
         return self.llm_engine.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))

     def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False):
+        self._maybe_drain_requests()
         return self.llm_engine.collective_rpc(
             "update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)
         )
```
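
Draining now happens on the weight-update path rather than in the processing loop: unless `inflight_updates` is enabled, `_maybe_drain_requests` spins until the engine has no unfinished requests and no pending tool calls before the `collective_rpc` that rewrites weights. A minimal standalone sketch of that wait-until-idle guard (simplified names, illustrative only):

```python
# Minimal standalone sketch of the wait-until-idle guard (simplified names):
# block the weight update until all in-flight work has finished, unless
# in-flight updates are explicitly allowed.
import time


class _EngineState:
    # Toy stand-in for the real engine's counters.
    pending_tool_calls = 0
    unfinished_requests = 0


def drain_before_update(state: _EngineState, inflight_updates: bool, sleep_s: float = 0.1):
    while not inflight_updates:
        if state.pending_tool_calls == 0 and state.unfinished_requests == 0:
            break
        time.sleep(sleep_s)


drain_before_update(_EngineState(), inflight_updates=False)
print("safe to update weights")
```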
```diff
@@ -888,8 +857,15 @@ def wake_up(self, tags: Optional[list[str]] = None):
         self.llm_engine.wake_up(tags)

     def ready(self):
+        self._threads_started.wait(timeout=30)
         return True

+    def check_background_threads(self):
+        if self._prefetch_future.done():
+            self._prefetch_future.result()
+        if self._process_future.done():
+            self._process_future.result()
+
     def get_kv_cache_info(self):
         """Get KV cache max concurrency from the vLLM engine."""
         kv_cache_specs = self.llm_engine.model_executor.get_kv_cache_specs()
```
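
`check_background_threads` is what the trainer's `health_check_fn` calls over Ray: if either engine-side future has finished (which for these forever-loops means it crashed), `.result()` re-raises inside the actor method, and Ray then re-raises that error in the caller of `ray.get`. A small sketch of the remote half of that propagation, assuming only Ray itself:

```python
# Small sketch: an exception raised inside a Ray actor method is re-raised
# (wrapped in RayTaskError) by ray.get() in the caller, which is how a dead
# engine-side thread becomes a visible failure in the trainer's health check.
import ray

ray.init(ignore_reinit_error=True)


@ray.remote
class EngineSketch:
    def check_background_threads(self):
        # Stand-in for: if self._process_future.done(): self._process_future.result()
        raise RuntimeError("processing thread crashed")


engine = EngineSketch.remote()
try:
    ray.get(engine.check_background_threads.remote())
except ray.exceptions.RayTaskError as err:
    print(f"health check caught remote failure: {err}")
```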
```diff
@@ -954,7 +930,6 @@ def create_vllm_engines(
     vllm_gpu_memory_utilization: float = 0.9,
     single_gpu_mode: bool = False,
     pg: Optional[ray.util.placement_group] = None,
-    vllm_enable_sleep=False,
     tools: Optional[Dict[str, Tool]] = None,
     max_tool_calls: List[int] = [5],
     prompt_queue=None,
```

```diff
@@ -1037,7 +1012,6 @@ def create_vllm_engines(
                 gpu_memory_utilization=vllm_gpu_memory_utilization,
                 bundle_indices=bundle_indices,
                 num_gpus=0.2 if use_hybrid_engine else 1,
-                enable_sleep_mode=vllm_enable_sleep,
                 noset_visible_devices=ray_noset_visible_devices(),
                 prompt_queue=prompt_queue,
                 results_queue=results_queue,
```

10531027
)
10541028
)
10551029

1056-
# Verify engines initialized successfully
1057-
try:
1058-
ray_get_with_progress(
1059-
[engine.ready.remote() for engine in vllm_engines], "Initializing vLLM engines", timeout=300
1060-
)
1061-
except TimeoutError as e:
1062-
logger.error(f"vLLM engines failed to initialize: {e}")
1063-
# Kill partially initialized actors before raising
1064-
for engine in vllm_engines:
1065-
ray.kill(engine)
1066-
raise RuntimeError(f"vLLM engine initialization timed out: {e}")
1067-
1068-
if vllm_enable_sleep:
1069-
batch_vllm_engine_call(vllm_engines, "sleep", rank_0_only=False)
1030+
ray_get_with_progress([engine.ready.remote() for engine in vllm_engines], "Initializing vLLM engines", timeout=300)
10701031

10711032
return vllm_engines
1072-
1073-
1074-
def batch_vllm_engine_call(engines: List[Any], method_name: str, *args, rank_0_only: bool = True, **kwargs):
1075-
"""
1076-
Batch call a method on multiple vLLM engines.
1077-
Args:
1078-
engines: List of vLLM engine instances
1079-
method_name: Name of the method to call
1080-
rank_0_only: Only execute on rank 0 if True
1081-
*args: Positional arguments to pass to the method
1082-
**kwargs: Keyword arguments to pass to the method
1083-
Returns:
1084-
List of results from ray.get() if on rank 0, None otherwise
1085-
"""
1086-
import torch
1087-
1088-
if rank_0_only and torch.distributed.get_rank() != 0:
1089-
return None
1090-
1091-
refs = []
1092-
for engine in engines:
1093-
method = getattr(engine, method_name)
1094-
refs.append(method.remote(*args, **kwargs))
1095-
1096-
return ray.get(refs)
1097-
1098-
1099-
if __name__ == "__main__":
1100-
num_engines = 1
1101-
tensor_parallel_size = 1
1102-
world_size = num_engines * tensor_parallel_size + 1
1103-
vllm_engines = create_vllm_engines(
1104-
num_engines=num_engines,
1105-
tensor_parallel_size=tensor_parallel_size,
1106-
enforce_eager=True,
1107-
pretrain="facebook/opt-125m",
1108-
revision="main",
1109-
seed=42,
1110-
enable_prefix_caching=False,
1111-
max_model_len=1024,
1112-
)
1113-
llm = vllm_engines[0]
1114-
from vllm.utils import get_ip, get_open_port
1115-
1116-
master_address = get_ip()
1117-
master_port = get_open_port()
1118-
backend = "gloo"
1119-
1120-
refs = [
1121-
engine.init_process_group.remote(
1122-
master_address, master_port, i * tensor_parallel_size + 1, world_size, "openrlhf", backend=backend
1123-
)
1124-
for i, engine in enumerate(vllm_engines)
1125-
]
1126-
model_update_group = init_process_group(
1127-
backend=backend,
1128-
init_method=f"tcp://{master_address}:{master_port}",
1129-
world_size=world_size,
1130-
rank=0,
1131-
group_name="openrlhf",
1132-
)
1133-
ray.get(refs)
1134-
output = ray.get(llm.generate.remote("San Franciso is a"))
1135-
logger.info(f"output: {output}")
