Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions atom/benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,13 @@ def sample_random_requests(
use_chat_template: bool = False,
apply_chat_template_fn: Callable = lambda x: x,
) -> List[Tuple[str, int, int]]:
# EXPERIMENT: ATOM_BENCH_IDENTICAL_PROMPTS=1 -> generate exactly ONE prompt
# via the single-prompt (gen_n=1) RNG path and replicate it num_prompts times.
# Reproduces the SAME fixed request used elsewhere (run_bench_1req, same seed)
# so the rebalance map matches the load exactly.
identical = os.getenv("ATOM_BENCH_IDENTICAL_PROMPTS", "") == "1"
gen_n = 1 if identical else num_prompts

prefix_token_ids = np.random.randint(
0, tokenizer.vocab_size, size=prefix_len
).tolist()
Expand All @@ -117,16 +124,16 @@ def sample_random_requests(
def sample_uniform(seq_len):
lower = int(seq_len * range_ratio)
upper = seq_len
seq_lens = np.random.randint(lower, upper + 1, size=num_prompts).tolist()
seq_lens = np.random.randint(lower, upper + 1, size=gen_n).tolist()
return seq_lens

input_lens = sample_uniform(input_len)
output_lens = sample_uniform(output_len)
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
offsets = np.random.randint(0, tokenizer.vocab_size, size=gen_n)

input_requests = []
mismatches = []
for i in range(num_prompts):
for i in range(gen_n):
tgt_prompt_len = prefix_len + input_lens[i]
prompt_token_ids = prefix_token_ids + [
(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])
Expand Down Expand Up @@ -156,6 +163,10 @@ def sample_uniform(seq_len):
mismatches.append(prompt_len - tgt_prompt_len)
input_requests.append((prompt, prompt_len, output_lens[i], None))

if identical:
# replicate the single generated prompt to the requested count
input_requests = [input_requests[0]] * num_prompts

header_str = f'{"-"*19} Input/Output Length Statistics {"-"*19}'
print(header_str)
print(
Expand Down
9 changes: 9 additions & 0 deletions atom/model_engine/engine_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class EngineUtilityHandler:
"stop_profile": "_handle_stop_profile",
"get_mtp_stats": "_handle_get_mtp_stats",
"get_mtp_statistics": "_handle_get_mtp_statistics",
"offline_eplb_rebalance": "_handle_offline_eplb_rebalance",
}

def __init__(
Expand Down Expand Up @@ -268,3 +269,11 @@ def _handle_get_mtp_statistics(self, args: dict):
self.output_queue.put_nowait(
("UTILITY_RESPONSE", {"cmd": "get_mtp_statistics", "result": result})
)

# ------------------------------------------------------------------
# EPLB offline rebalance (plan derived from collected expert-load statistics)
# ------------------------------------------------------------------

def _handle_offline_eplb_rebalance(self, args: dict) -> None:
    """Forward a one-shot offline EPLB rebalance request to the model runners.

    Invokes ``trigger_offline_eplb_rebalance`` through ``self.runner_mgr``.
    This handler puts no ``UTILITY_RESPONSE`` on the output queue, so the
    sender of the command receives no reply ("fire-and-forget" from the
    sender's point of view).

    Args:
        args: Utility-command arguments. Not read by this handler.
    """
    # NOTE(review): wait_out=True presumably makes this call wait for the
    # runners to finish -- confirm against runner_mgr.call_func. The returned
    # value, if any, is discarded here.
    self.runner_mgr.call_func("trigger_offline_eplb_rebalance", wait_out=True)
3 changes: 3 additions & 0 deletions atom/model_engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,9 @@ def get_mtp_statistics(self, timeout: float = 30.0) -> Dict[str, Any]:
},
}

def offline_eplb_rebalance(self) -> None:
    """Send the ``offline_eplb_rebalance`` utility command via ``self.core_mgr``.

    Whatever ``send_utility_command`` returns is not propagated; this method
    itself returns ``None``.
    """
    # NOTE(review): whether send_utility_command blocks until the engine has
    # processed the command is not visible here -- confirm before relying on
    # the rebalance being complete when this returns.
    self.core_mgr.send_utility_command("offline_eplb_rebalance")


class InputOutputProcessor:

Expand Down
7 changes: 7 additions & 0 deletions atom/model_engine/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from atom.model_loader.loader import load_model
from atom.model_ops.rejection_sampler import RejectionSampler
from atom.model_ops.sampler import SAMPLER_EPS, Sampler
from atom.model_ops.eplb_stats import get_expert_load_monitor
from atom.spec_decode.eagle import EagleProposer
from atom.utils import (
CpuGpuBuffer,
Expand Down Expand Up @@ -995,6 +996,11 @@ def stop_profiler(self):
)
return {"trace_dir": self.profiler_dir, "elapsed": elapsed}

def trigger_offline_eplb_rebalance(self):
    """Request a one-shot offline EPLB rebalance from the expert-load monitor.

    The request is tagged with ``reason="utility_command"`` so the monitor can
    tell it was raised through the utility-command path.

    Returns:
        Always ``True``.
    """
    # Look up the monitor first, then issue the rebalance request on it.
    monitor = get_expert_load_monitor()
    monitor.trigger_offline_rebalance(reason="utility_command")
    return True

def debug(self, *args: Any):
if self.rank == 0:
logger.info(*args)
Expand Down Expand Up @@ -2166,6 +2172,7 @@ def forward(self, batch: ScheduledBatch) -> ScheduledBatchOutput:
hidden_states,
needs_independent_noise=needs_independent_noise,
)
get_expert_load_monitor().step(is_dummy_run=batch.is_dummy_run)
reset_forward_context()
return fwd_output

Expand Down
Loading
Loading