
Commit f455ef0

finbarrtimbers and root authored
Adds a dashboard to ActorManager that makes it easier to track what's going on. (#944)
* First version of dashboard.
* Clean up logging.
* Now, our dashboard shows the queue sizes.
* Dashboard runs and imports cache utils.
* Moved code.
* Fixed imports.
* Now, code runs.
* Added inference batch size to dashboard.
* Added logging.
* Fixed inference batch size calculation.
* Cleaned up PR.
* Added actor manager.
* Cleaned up code.
* Cleaned up PR.
* Update code.
* Added dashboard.
* Cleaned up PR.
* Now, tests pass.
* Cleaned up PR.
* Fixed typo.
* Fixed tests.
* Ran linter.
* Now, we use the fqdn, not the short hostname, in the URL.
* Updated reporting of batch size.
* Added a cleanup method and passes args through to ActorManager.
* Minor bugfix.
* Lint.
* Added port arg.

---------

Co-authored-by: root <[email protected]>
1 parent 5301917 commit f455ef0

File tree

11 files changed: +832 -48 lines changed


open_instruct/actor_manager.py

Lines changed: 244 additions & 0 deletions
@@ -0,0 +1,244 @@
# Copyright 2024 The AllenAI Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ActorManager for controlling evaluation and weight updates across all LLMRayActors."""

import collections
import socket
import threading
import time
from datetime import datetime
from pathlib import Path

import uvicorn
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles

from open_instruct import logger_utils


def find_free_port():
    """Find and return a free port number."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        s.listen(1)
        port = s.getsockname()[1]
    return port


class ActorManager:
    """Centralized manager for controlling evaluation and weight updates across all LLMRayActors."""

    def __init__(self, queues: dict, args):
        self._should_stop = False
        self._last_updated = datetime.now()
        self._dashboard_port = None
        self._queues = queues or {}
        self._queue_sizes = {}
        self._queue_info = {}
        self._sample_window = 100
        self._token_history = collections.deque(maxlen=self._sample_window)
        self._total_prefill_tokens = 0
        self._total_decode_tokens = 0
        self._training_step_history = collections.deque(maxlen=self._sample_window)
        self._generation_batch_history = collections.deque(maxlen=self._sample_window)
        self._kv_cache_max_concurrency = None
        self._args = args
        if self._args.enable_queue_dashboard:
            self._setup_queue_monitoring()
            self._start_dashboard()

    def _setup_queue_monitoring(self):
        """Setup queue monitoring with background polling thread."""
        for queue_name, q in self._queues.items():
            self._queue_info[queue_name] = {"maxsize": q.maxsize if hasattr(q, "maxsize") else 0, "queue": q}
            self._queue_sizes[queue_name] = 0

        self._polling_active = True
        self._poll_thread = threading.Thread(target=self._poll_queue_sizes, daemon=True)
        self._poll_thread.start()

    def _poll_queue_sizes(self):
        """Background thread to poll queue sizes."""
        while self._polling_active:
            for queue_name, info in self._queue_info.items():
                current_size = info["queue"].size()
                self._queue_sizes[queue_name] = current_size
            time.sleep(0.5)

    def _start_dashboard(self):
        """Start the FastAPI dashboard server in a background thread."""
        if self._args.queue_dashboard_port is None:
            self._dashboard_port = find_free_port()
        else:
            self._dashboard_port = self._args.queue_dashboard_port
        app = FastAPI(title="ActorManager Dashboard")

        static_dir = Path(__file__).parent / "static"
        app.mount("/static", StaticFiles(directory=str(static_dir)), name="static")

        @app.get("/", response_class=HTMLResponse)
        async def dashboard():
            """Serve the HTML dashboard."""
            html_path = Path(__file__).parent / "static" / "dashboard.html"
            with open(html_path, "r") as f:
                return f.read()

        @app.get("/api/status")
        async def api_status():
            """Return the current status as JSON."""
            queues_data = {
                queue_name: {"current": self._queue_sizes.get(queue_name, 0), "maxsize": info["maxsize"]}
                for queue_name, info in self._queue_info.items()
            }

            return {
                "should_stop": self._should_stop,
                "last_updated": self._last_updated.isoformat(),
                "queues": queues_data,
                "token_stats": self.get_token_stats(),
                "timing_stats": self.get_timing_stats(),
                "kv_cache_max_concurrency": self._kv_cache_max_concurrency,
                # This is less confusing to users.
                "inference_batch_size": self._args.inference_batch_size * self._args.num_samples_per_prompt_rollout,
            }

        def run_server():
            uvicorn.run(app, host="0.0.0.0", port=self._dashboard_port, log_level="error")

        self._server_thread = threading.Thread(target=run_server, daemon=True)
        self._server_thread.start()

        hostname = socket.getfqdn()

        logger = logger_utils.setup_logger(__name__)
        logger.info(f"Dashboard server started at http://{hostname}:{self._dashboard_port}")

    def set_should_stop(self, should_stop: bool):
        """Set whether actors should stop processing."""
        self._should_stop = should_stop
        self._last_updated = datetime.now()

    def should_stop(self) -> bool:
        """Check if actors should stop processing."""
        return self._should_stop

    def report_token_stats(self, prompt_tokens: int, generation_tokens: int):
        """Report token statistics from main thread."""
        current_time = time.time()

        self._total_prefill_tokens += prompt_tokens
        self._total_decode_tokens += generation_tokens

        self._token_history.append(
            {"timestamp": current_time, "prompt_tokens": prompt_tokens, "generation_tokens": generation_tokens}
        )

    def report_token_statistics(self, token_stats):
        """Report token statistics using TokenStatistics object."""
        current_time = time.time()

        self._total_prefill_tokens += token_stats.num_prompt_tokens
        self._total_decode_tokens += token_stats.num_response_tokens

        self._token_history.append(
            {
                "timestamp": current_time,
                "prompt_tokens": token_stats.num_prompt_tokens,
                "generation_tokens": token_stats.num_response_tokens,
            }
        )

        self._generation_batch_history.append(token_stats.generation_time)

    def report_training_step_time(self, duration: float):
        """Report the time taken for a training step."""
        self._training_step_history.append(duration)

    def report_batch_generation_time(self, duration: float):
        """Report the time taken to generate a batch of data."""
        self._generation_batch_history.append(duration)

    def set_kv_cache_max_concurrency(self, max_concurrency: int):
        """Set the KV cache max concurrency value."""
        self._kv_cache_max_concurrency = max_concurrency

    def get_token_stats(self):
        """Calculate and return current token statistics."""
        if not self._token_history:
            return {
                "total_prefill_tokens": self._total_prefill_tokens,
                "total_decode_tokens": self._total_decode_tokens,
                "prefill_tokens_per_sec": 0,
                "decode_tokens_per_sec": 0,
                "sample_count": 0,
            }

        current_time = time.time()

        window_prompt_tokens = 0
        window_generation_tokens = 0
        oldest_timestamp = self._token_history[0]["timestamp"]

        for entry in self._token_history:
            window_prompt_tokens += entry["prompt_tokens"]
            window_generation_tokens += entry["generation_tokens"]

        time_span = current_time - oldest_timestamp if len(self._token_history) > 1 else 1

        prompt_tokens_per_sec = window_prompt_tokens / time_span if time_span > 0 else 0
        generation_tokens_per_sec = window_generation_tokens / time_span if time_span > 0 else 0

        return {
            "total_prefill_tokens": self._total_prefill_tokens,
            "total_decode_tokens": self._total_decode_tokens,
            "prefill_tokens_per_sec": prompt_tokens_per_sec,
            "decode_tokens_per_sec": generation_tokens_per_sec,
            "sample_count": len(self._token_history),
        }

    def get_timing_stats(self):
        """Calculate and return current timing statistics."""
        avg_training_step_time = (
            sum(self._training_step_history) / len(self._training_step_history) if self._training_step_history else 0
        )

        avg_batch_generation_time = (
            sum(self._generation_batch_history) / len(self._generation_batch_history)
            if self._generation_batch_history
            else 0
        )

        return {
            "avg_training_step_time": avg_training_step_time,
            "avg_batch_generation_time": avg_batch_generation_time,
            "training_step_count": len(self._training_step_history),
            "batch_generation_count": len(self._generation_batch_history),
        }

    def get_dashboard_port(self):
        """Get the port number where the dashboard is running."""
        return self._dashboard_port

    def cleanup(self):
        """Clean up resources including stopping the polling thread."""
        logger = logger_utils.setup_logger(__name__)

        # Stop the polling thread if dashboard was enabled
        if self._args.enable_queue_dashboard:
            logger.info("Stopping queue polling thread...")
            self._polling_active = False
            # Wait for the thread to finish with a timeout
            self._poll_thread.join(timeout=2.0)
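
For reference, a minimal sketch (not part of the commit) of how the /api/status endpoint above could be polled from another process. The host and port here are placeholders; the real values are logged at startup ("Dashboard server started at ...") or available via get_dashboard_port().

    import json
    import urllib.request

    # Placeholder URL; substitute the host/port logged by ActorManager at startup.
    status_url = "http://localhost:8000/api/status"

    with urllib.request.urlopen(status_url) as response:
        status = json.load(response)

    # Keys mirror the payload assembled in api_status() above.
    print("should_stop:", status["should_stop"])
    print("decode tokens/sec:", status["token_stats"]["decode_tokens_per_sec"])
    for name, info in status["queues"].items():
        print(f"{name}: {info['current']}/{info['maxsize']}")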

open_instruct/benchmark_generators.py

Lines changed: 3 additions & 1 deletion
@@ -28,6 +28,7 @@
 from ray.util import queue as ray_queue

 from open_instruct import dataset_transformation, grpo_fast, logger_utils, model_utils, utils, vllm_utils3
+from open_instruct.actor_manager import ActorManager
 from open_instruct.queue_types import PromptRequest

 # For FLOPS, we assume bf16 and ignore sparsity.
@@ -607,7 +608,8 @@ def setup_vllm_engines(
     param_prompt_Q = ray_queue.Queue(maxsize=10)
     inference_results_Q = ray_queue.Queue(maxsize=10)

-    actor_manager = vllm_utils3.ActorManager.remote()
+    queues_to_monitor = {"Param Prompt Queue": param_prompt_Q, "Inference Results Queue": inference_results_Q}
+    actor_manager = ray.remote(ActorManager).remote(queues_to_monitor, args)

     vllm_engines = vllm_utils3.create_vllm_engines(
         num_engines=args.vllm_num_engines,
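
As a rough usage sketch (not part of the commit), the remote handle created in setup_vllm_engines above can be driven like any Ray actor. The args object below is a stand-in namespace with only the fields ActorManager reads; the real training args dataclass is what the callers above pass through.

    import ray
    from types import SimpleNamespace

    from ray.util import queue as ray_queue

    from open_instruct.actor_manager import ActorManager

    ray.init()

    # Stand-in args; the real training args carry these fields (and more).
    args = SimpleNamespace(
        enable_queue_dashboard=False,  # set True to start the FastAPI dashboard
        queue_dashboard_port=None,     # None lets find_free_port() pick a port
        inference_batch_size=8,
        num_samples_per_prompt_rollout=4,
    )

    queues_to_monitor = {"Inference Results Queue": ray_queue.Queue(maxsize=10)}
    actor_manager = ray.remote(ActorManager).remote(queues_to_monitor, args)

    # Methods are invoked like any other Ray actor method.
    ray.get(actor_manager.set_should_stop.remote(True))
    assert ray.get(actor_manager.should_stop.remote())
    ray.get(actor_manager.report_training_step_time.remote(12.5))
    ray.get(actor_manager.cleanup.remote())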
