From 68a3c1dc902a71071b439a7640342d605475c664 Mon Sep 17 00:00:00 2001 From: CaoE Date: Mon, 18 Aug 2025 20:57:36 +0000 Subject: [PATCH 01/13] add graph runner support with torch compile on CPU --- docs/platforms/cpu_server.md | 7 +- python/sglang/srt/disaggregation/prefill.py | 2 +- .../sglang/srt/distributed/parallel_state.py | 7 +- .../srt/layers/attention/intel_amx_backend.py | 3 + python/sglang/srt/layers/quantization/fp8.py | 3 + .../srt/layers/quantization/w8a8_int8.py | 12 +- python/sglang/srt/managers/scheduler.py | 21 +- .../srt/managers/scheduler_metrics_mixin.py | 4 +- .../scheduler_output_processor_mixin.py | 8 +- python/sglang/srt/managers/tp_worker.py | 8 +- .../srt/managers/tp_worker_overlap_thread.py | 8 +- .../srt/model_executor/cpu_graph_runner.py | 702 ++++++++++++++++++ .../srt/model_executor/forward_batch_info.py | 3 + .../sglang/srt/model_executor/model_runner.py | 56 +- python/sglang/srt/speculative/eagle_worker.py | 12 +- python/sglang/srt/utils.py | 10 +- sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 10 +- test/srt/test_intel_amx_attention_backend.py | 71 +- 18 files changed, 879 insertions(+), 68 deletions(-) create mode 100644 python/sglang/srt/model_executor/cpu_graph_runner.py diff --git a/docs/platforms/cpu_server.md b/docs/platforms/cpu_server.md index 348bf893695b..4e91e7b8839f 100644 --- a/docs/platforms/cpu_server.md +++ b/docs/platforms/cpu_server.md @@ -134,7 +134,12 @@ Notes: export SGLANG_CPU_OMP_THREADS_BIND="0-39|43-82|86-125|128-167|171-210|214-253" ``` -3. A warmup step is automatically triggered when the service is started. +3. For optimizing decoding with torch.compile, please add the flag `--enable-torch-compile`. + To specify the maximum batch size when using torch compile, set the flag `--torch-compile-max-bs`. + For example, `--enable-torch-compile --torch-compile-max-bs 4` means using torch compile and setting the + maximum batch size to 4. + +4. A warmup step is automatically triggered when the service is started. The server is ready when you see the log `The server is fired up and ready to roll!`. ## Benchmarking with Requests diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 5f5d0ebc6abd..c06f4176a610 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -774,7 +774,7 @@ def event_loop_pp_disagg_prefill(self: Scheduler): extend_input_len_per_req=None, extend_logprob_start_len_per_req=None, bid=bids[next_mb_id], - can_run_cuda_graph=result.can_run_cuda_graph, + can_run_graph=result.can_run_graph, ) self.process_batch_result_disagg_prefill( mbs[next_mb_id], output_result diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index a8a8d20f667d..39e40ca0a31c 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -60,6 +60,9 @@ class GraphCaptureContext: TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) +# use int value instead of ReduceOp.SUM to support torch compile +REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM) + def _split_tensor_dict( tensor_dict: Dict[str, Union[torch.Tensor, Any]] @@ -483,9 +486,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if input_.is_cpu: if is_shm_available(input_.dtype, self.world_size, self.local_size): - torch.ops.sgl_kernel.shm_allreduce( - input_, torch.distributed.ReduceOp.SUM - ) + torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM) else: torch.distributed.all_reduce(input_, group=self.device_group) return input_ diff --git a/python/sglang/srt/layers/attention/intel_amx_backend.py b/python/sglang/srt/layers/attention/intel_amx_backend.py index 9f2f7ece4d87..39e5c7428adc 100644 --- a/python/sglang/srt/layers/attention/intel_amx_backend.py +++ b/python/sglang/srt/layers/attention/intel_amx_backend.py @@ -49,6 +49,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): max_extend_len = torch.max(forward_batch.extend_seq_lens).item() self.forward_metadata = (attn_logits, max_extend_len) + def get_graph_seq_len_fill_value(self): + return 1 + def forward_extend( self, q, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index f2e07b515a54..9c35add23542 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -349,6 +349,9 @@ def process_weights_after_loading(self, layer: Module) -> None: _is_cpu_amx_available ), "Fp8LinearMethod on CPU requires that CPU has AMX support" _amx_process_weight_after_loading(layer, ["weight"]) + layer.weight_scale_inv = torch.nn.Parameter( + layer.weight_scale_inv.data, requires_grad=False + ) return else: weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index abcf334e00eb..4dc20e8864b4 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -339,9 +339,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: _is_cpu_amx_available ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support" _amx_process_weight_after_loading(layer, ["weight"]) - return - - layer.weight = Parameter(layer.weight.t(), requires_grad=False) + else: + layer.weight = Parameter(layer.weight.t(), requires_grad=False) layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) def create_weights( @@ -472,10 +471,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: _is_cpu_amx_available ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support" _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) - return - - layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False) - layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False) + else: + layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False) + layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False) layer.w13_weight_scale = Parameter( layer.w13_weight_scale.data, requires_grad=False ) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 91e02b08e795..15659cd4158b 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -177,7 +177,7 @@ class GenerationBatchResult: extend_input_len_per_req: List[int] extend_logprob_start_len_per_req: List[int] bid: int - can_run_cuda_graph: bool + can_run_graph: bool @dataclass @@ -394,7 +394,7 @@ def __init__( f"max_prefill_tokens={self.max_prefill_tokens}, " f"max_running_requests={self.max_running_requests}, " f"context_len={self.model_config.context_len}, " - f"available_gpu_mem={avail_mem:.2f} GB" + f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB" ) # Init memory pool and cache @@ -922,7 +922,7 @@ def event_loop_pp(self): "extend_logprob_start_len_per_req", None ), bid=bids[next_mb_id], - can_run_cuda_graph=result.can_run_cuda_graph, + can_run_graph=result.can_run_graph, ) self.process_batch_result(mbs[next_mb_id], output_result) last_mbs[next_mb_id] = mbs[next_mb_id] @@ -1735,11 +1735,11 @@ def run_batch( model_worker_batch.hicache_consumer_index ) if self.pp_group.is_last_rank: - logits_output, next_token_ids, can_run_cuda_graph = ( + logits_output, next_token_ids, can_run_graph = ( self.tp_worker.forward_batch_generation(model_worker_batch) ) else: - pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = ( + pp_hidden_states_proxy_tensors, _, can_run_graph = ( self.tp_worker.forward_batch_generation(model_worker_batch) ) bid = model_worker_batch.bid @@ -1749,7 +1749,7 @@ def run_batch( next_token_ids, bid, num_accepted_tokens, - can_run_cuda_graph, + can_run_graph, ) = self.draft_worker.forward_batch_speculative_generation(batch) bs = batch.batch_size() self.spec_num_total_accepted_tokens += num_accepted_tokens + bs @@ -1784,7 +1784,7 @@ def run_batch( extend_input_len_per_req=extend_input_len_per_req, extend_logprob_start_len_per_req=extend_logprob_start_len_per_req, bid=bid, - can_run_cuda_graph=can_run_cuda_graph, + can_run_graph=can_run_graph, ) else: # embedding or reward model model_worker_batch = batch.get_model_worker_batch() @@ -2245,10 +2245,9 @@ def get_internal_state(self, recv_req: GetInternalStateReq): "token_capacity": int(self.max_total_num_tokens), } - if not _is_cpu: - ret["memory_usage"]["cuda_graph"] = round( - self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2 - ) + ret["memory_usage"]["graph"] = round( + self.tp_worker.worker.model_runner.graph_mem_usage, 2 + ) if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0: ret["avg_spec_accept_length"] = ( diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index a6497ffde5c1..34728362e977 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -130,7 +130,7 @@ def log_prefill_stats( self._publish_kv_events() def log_decode_stats( - self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None + self, can_run_graph: bool, running_batch: ScheduleBatch = None ): batch = running_batch or self.running_batch @@ -185,7 +185,7 @@ def log_decode_stats( msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, " msg += ( - f"cuda graph: {can_run_cuda_graph}, " + f"{'cpu graph of torch.compile' if self.device == 'cpu' else 'cuda graph'}: {can_run_graph}, " f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " f"#queue-req: {len(self.waiting_queue)}, " ) diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index a86899f6e79b..758654cdfc69 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -197,15 +197,15 @@ def process_batch_result_decode( result: GenerationBatchResult, launch_done: Optional[threading.Event] = None, ): - logits_output, next_token_ids, can_run_cuda_graph = ( + logits_output, next_token_ids, can_run_graph = ( result.logits_output, result.next_token_ids, - result.can_run_cuda_graph, + result.can_run_graph, ) self.num_generated_tokens += len(batch.reqs) if self.enable_overlap: - logits_output, next_token_ids, can_run_cuda_graph = ( + logits_output, next_token_ids, can_run_graph = ( self.tp_worker.resolve_last_batch_result(launch_done) ) next_token_logprobs = logits_output.next_token_logprobs @@ -293,7 +293,7 @@ def process_batch_result_decode( self.current_scheduler_metrics_enabled() and self.forward_ct_decode % self.server_args.decode_log_interval == 0 ): - self.log_decode_stats(can_run_cuda_graph, running_batch=batch) + self.log_decode_stats(can_run_graph, running_batch=batch) def add_input_logprob_return_values( self: Scheduler, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 77dac1ea6c68..fb4d0b115252 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -235,7 +235,7 @@ def forward_batch_generation( ) if self.pp_group.is_last_rank: - logits_output, can_run_cuda_graph = self.model_runner.forward( + logits_output, can_run_graph = self.model_runner.forward( forward_batch, pp_proxy_tensors=pp_proxy_tensors ) if launch_done is not None: @@ -248,13 +248,13 @@ def forward_batch_generation( logits_output, model_worker_batch ) - return logits_output, next_token_ids, can_run_cuda_graph + return logits_output, next_token_ids, can_run_graph else: - pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward( + pp_proxy_tensors, can_run_graph = self.model_runner.forward( forward_batch, pp_proxy_tensors=pp_proxy_tensors, ) - return pp_proxy_tensors.tensors, None, can_run_cuda_graph + return pp_proxy_tensors.tensors, None, can_run_graph def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 674a941955cd..c4ed1bcc1993 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -172,7 +172,7 @@ def forward_thread_func_(self): # update the consumer index of hicache to the running batch self.set_hicache_consumer(model_worker_batch.hicache_consumer_index) # Run forward - logits_output, next_token_ids, can_run_cuda_graph = ( + logits_output, next_token_ids, can_run_graph = ( self.worker.forward_batch_generation( model_worker_batch, model_worker_batch.launch_done ) @@ -201,7 +201,7 @@ def forward_thread_func_(self): copy_done.record() self.output_queue.put( - (copy_done, logits_output, next_token_ids, can_run_cuda_graph) + (copy_done, logits_output, next_token_ids, can_run_graph) ) def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None): @@ -209,7 +209,7 @@ def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = Non This function is called to resolve the last batch result and wait for the current batch to be launched. Used in overlap mode. """ - copy_done, logits_output, next_token_ids, can_run_cuda_graph = ( + copy_done, logits_output, next_token_ids, can_run_graph = ( self.output_queue.get() ) @@ -226,7 +226,7 @@ def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = Non logits_output.input_token_logprobs.tolist() ) next_token_ids = next_token_ids.tolist() - return logits_output, next_token_ids, can_run_cuda_graph + return logits_output, next_token_ids, can_run_graph def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch diff --git a/python/sglang/srt/model_executor/cpu_graph_runner.py b/python/sglang/srt/model_executor/cpu_graph_runner.py new file mode 100644 index 000000000000..14dc9de89d11 --- /dev/null +++ b/python/sglang/srt/model_executor/cpu_graph_runner.py @@ -0,0 +1,702 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Run the model with cpu torch compile.""" + +# The implementation of CPUGraphRunner follows the CudaGraphRunner + +from __future__ import annotations + +import bisect +import logging +from contextlib import contextmanager +from typing import TYPE_CHECKING, Callable, Optional, Union + +import psutil +import torch +import tqdm + +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.distributed.parallel_state import GroupCoordinator +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, + PPProxyTensors, + enable_num_token_non_padded, +) +from sglang.srt.patch_torch import monkey_patch_torch_compile +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.utils import ( + rank0_log, + require_attn_tp_gather, + require_gathered_buffer, + require_mlp_sync, + require_mlp_tp_gather, +) + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + + +@contextmanager +def patch_model( + model: torch.nn.Module, + enable_compile: bool, + num_tokens: int, + tp_group: GroupCoordinator, +): + """Patch the model to make it compatible with torch.compile""" + backup_ca_comm = None + + try: + if enable_compile: + backup_ca_comm = tp_group.ca_comm + # Use custom-allreduce here. + # We found the custom allreduce is much faster than the built-in allreduce in torch, + # even with ENABLE_INTRA_NODE_COMM=1. + # tp_group.ca_comm = None + yield torch.compile( + torch.no_grad()(model.forward), + dynamic=False, + ) + else: + yield model.forward + finally: + if enable_compile: + tp_group.ca_comm = backup_ca_comm + + +def set_torch_compile_config(): + import torch._dynamo.config + import torch._inductor.config + + torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + torch._inductor.config.freezing = True + torch._dynamo.config.accumulated_cache_size_limit = 1024 + if hasattr(torch._dynamo.config, "cache_size_limit"): + torch._dynamo.config.cache_size_limit = 1024 + monkey_patch_torch_compile() + + +def get_batch_sizes_to_capture(model_runner: ModelRunner): + server_args = model_runner.server_args + # cpu torch compile only speeds up decoding by + # reducing python overhead when bs is small + capture_bs = list(range(1, 9)) + list(range(10, 17, 2)) + capture_bs = [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] + capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size] + capture_bs = list(sorted(set(capture_bs))) + assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}" + return capture_bs + + +def register_fake_ops(): + """ + Registers fake/meta implementations for all custom sgl_kernel CPU operators + using torch.library.register_fake to support torch.compile + """ + + none_return_ops = [ + "shm_allreduce", + "bmm_cpu", + "fused_add_rmsnorm_cpu", + "decode_attention_cpu", + "extend_attention_cpu", + ] + for op in none_return_ops: + + @torch.library.register_fake(f"sgl_kernel::{op}") + def _(*args, **kwargs): + return + + for op in [ + "rmsnorm_cpu", + "l2norm_cpu", + "fused_experts_cpu", + "shared_expert_cpu", + ]: + + @torch.library.register_fake(f"sgl_kernel::{op}") + def _(input, *args, **kwargs): + return torch.empty_like(input) + + @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope") + def _( + hidden_states, + q_a_proj_weight, + q_b_proj_weight, + kv_a_proj_weight, + w_kc, + q_a_layernorm_weight, + kv_a_layernorm_weight, + positions, + cos_sin_cache, + eps, + use_int8_w8a8, + use_fp8_w8a16, + q_a_proj_scale, + q_b_proj_scale, + kv_a_proj_scale, + is_vnni, + block_size, + ): + num_seqs = hidden_states.shape[0] + num_heads = w_kc.shape[0] + kv_lora_rank = w_kc.shape[1] + qk_rope_head_dim = kv_a_proj_weight.shape[0] - kv_lora_rank + q_input = torch.empty( + num_seqs, + num_heads, + kv_lora_rank + qk_rope_head_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + k_input = torch.empty( + num_seqs, + 1, + kv_lora_rank + qk_rope_head_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + v_input = k_input.narrow(-1, 0, kv_lora_rank) + return q_input, k_input, v_input + + @torch.library.register_fake("sgl_kernel::rotary_embedding_cpu") + def _(positions, query, key, head_size, cos_sin_cache, is_neox): + if query.ndim == 2: + return query, key + else: + return torch.empty_like(query), torch.empty_like(key) + + @torch.library.register_fake("sgl_kernel::qkv_proj_with_rope_fused_weight") + def _( + hidden_states, + q_a_proj_weight, + q_b_proj_weight, + w_kc, + q_a_layernorm_weight, + kv_a_layernorm_weight, + positions, + cos_sin_cache, + eps, + use_int8_w8a8, + use_fp8_w8a16, + qkv_a_proj_scale, + q_b_proj_scale, + is_vnni, + block_size, + q_lora_rank, + kv_lora_rank, + qk_rope_head_dim, + ): + num_seqs = hidden_states.shape[0] + num_heads = w_kc.shape[0] + kv_lora_rank = w_kc.shape[1] + weight_chunks = torch.split( + q_a_proj_weight, [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=0 + ) + qk_rope_head_dim = weight_chunks[1].shape[0] - kv_lora_rank + q_input = torch.empty( + num_seqs, + num_heads, + kv_lora_rank + qk_rope_head_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + k_input = torch.empty( + num_seqs, + 1, + kv_lora_rank + qk_rope_head_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + v_input = k_input.narrow(-1, 0, kv_lora_rank) + return q_input, k_input, v_input + + @torch.library.register_fake("sgl_kernel::weight_packed_linear") + def _(x, weight, bias, is_vnni): + return x.new_empty(x.shape[0], weight.shape[0]) + + @torch.library.register_fake("sgl_kernel::per_token_quant_int8_cpu") + def _(input): + M = input.shape[0] + K = input.shape[1] + Aq = input.new_empty(M, K, dtype=torch.int8) + As = input.new_empty(M, dtype=torch.float32) + return Aq, As + + @torch.library.register_fake("sgl_kernel::int8_scaled_mm_cpu") + def _(mat1, mat2, scales1, scales2, bias, out_dtype, is_vnni): + M = mat1.shape[0] + N = mat2.shape[0] + out = mat1.new_empty(M, N, dtype=out_dtype) + return out + + @torch.library.register_fake("sgl_kernel::grouped_topk_cpu") + def _( + hidden_states, + gating_output, + topk, + renormalize, + num_expert_group, + topk_group, + num_fused_shared_experts, + routed_scaling_factor, + num_token_non_padded, + ): + num_tokens = hidden_states.shape[0] + shape = (num_tokens, topk) + device = hidden_states.device + topk_weights = torch.empty(shape, device=device, dtype=torch.float32) + topk_ids = torch.empty(shape, device=device, dtype=torch.int) + return topk_weights, topk_ids + + @torch.library.register_fake("sgl_kernel::biased_grouped_topk_cpu") + def _( + hidden_states, + gating_output, + correction_bias, + topk, + renormalize, + num_expert_group, + topk_group, + num_fused_shared_experts, + routed_scaling_factor, + num_token_non_padded, + ): + num_tokens = hidden_states.shape[0] + shape = (num_tokens, topk) + device = hidden_states.device + topk_weights = torch.empty(shape, device=device, dtype=torch.float32) + topk_ids = torch.empty(shape, device=device, dtype=torch.int) + return topk_weights, topk_ids + + @torch.library.register_fake("sgl_kernel::topk_sigmoid_cpu") + def _(hidden_states, gating_output, topk, renormalize): + num_tokens = hidden_states.shape[0] + shape = (num_tokens, topk) + return ( + torch.empty(shape, device=hidden_states.device, dtype=torch.float), + torch.empty(shape, device=hidden_states.device, dtype=torch.int), + ) + + @torch.library.register_fake("sgl_kernel::topk_softmax_cpu") + def _( + hidden_states, + gating_output, + topk, + renormalize, + ): + num_tokens = hidden_states.shape[0] + shape = (num_tokens, topk) + return ( + torch.empty(shape, device=hidden_states.device, dtype=torch.float), + torch.empty(shape, device=hidden_states.device, dtype=torch.int), + ) + + @torch.library.register_fake("sgl_kernel::silu_and_mul_cpu") + def _(input): + return input.new_empty(input.shape[0], input.shape[1] // 2) + + @torch.library.register_fake("sgl_kernel::int8_scaled_mm_with_quant") + def _( + mat1, + mat2, + scales2, + bias, + out_dtype, + is_vnni, + ): + M = mat1.shape[0] + N = mat2.shape[0] + return mat1.new_empty(M, N, dtype=out_dtype) + + @torch.library.register_fake("sgl_kernel::fp8_scaled_mm_cpu") + def _( + mat1, + mat2, + scales2, + block_size, + bias, + out_dtype, + is_vnni, + ): + M = mat1.shape[0] + N = mat2.shape[0] + return mat1.new_empty(M, N, dtype=out_dtype) + + +# TODO Remove unnecessary settings for CPUGraphRunner. +# Re-abstract the graph runner and restructure CPUGraphRunner to reuse the same logic. +class CPUGraphRunner: + """A CPUGraphRunner runs the forward pass of a model with cpu torch.compile.""" + + def __init__(self, model_runner: ModelRunner): + # Parse args + self.model_runner = model_runner + self.device = model_runner.device + self.graphs = {} + self.output_buffers = {} + self.enable_torch_compile = model_runner.server_args.enable_torch_compile + self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder + self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args) + self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args) + self.require_mlp_sync = require_mlp_sync(model_runner.server_args) + self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args) + self.enable_two_batch_overlap = ( + model_runner.server_args.enable_two_batch_overlap + ) + self.speculative_algorithm = model_runner.server_args.speculative_algorithm + self.enable_profile_cuda_graph = ( + model_runner.server_args.enable_profile_cuda_graph + ) + self.tp_size = model_runner.server_args.tp_size + self.dp_size = model_runner.server_args.dp_size + self.pp_size = model_runner.server_args.pp_size + + self.capture_forward_mode = ForwardMode.DECODE + self.capture_hidden_mode = CaptureHiddenMode.NULL + self.num_tokens_per_bs = 1 + + # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup + if model_runner.server_args.enable_return_hidden_states: + self.capture_hidden_mode = CaptureHiddenMode.FULL + + assert ( + not self.model_runner.server_args.enable_lora + ), "CPUGraphRunner does not support LoRA yet." + assert ( + not self.enable_two_batch_overlap + ), "CPUGraphRunner does not support two batch overlap yet." + assert ( + not self.require_mlp_tp_gather + ), "CPUGraphRunner does not support MLP TP gather yet." + assert ( + not self.require_mlp_sync + ), "CPUGraphRunner does not support MLP sync yet." + assert ( + not self.require_gathered_buffer + ), "CPUGraphRunner does not support gathered buffer yet." + assert ( + model_runner.spec_algorithm == SpeculativeAlgorithm.NONE + ), "CPUGraphRunner does not support speculative inference yet." + # TODO add compile support for encoder-decoder models + assert ( + not self.is_encoder_decoder + ), "CPUGraphRunner does not support encoder-decoder models yet." + assert self.dp_size == 1, "CPUGraphRunner does not support DP yet." + assert self.pp_size == 1, "CPUGraphRunner does not support PP yet." + + # Batch sizes to capture + self.capture_bs = get_batch_sizes_to_capture(model_runner) + rank0_log(f"Capture cpu graph bs {self.capture_bs}") + # Attention backend + self.max_bs = max(self.capture_bs) + self.max_num_token = self.max_bs * self.num_tokens_per_bs + + self.seq_len_fill_value = ( + self.model_runner.attn_backend.get_graph_seq_len_fill_value() + ) + + if self.enable_torch_compile: + register_fake_ops() + set_torch_compile_config() + + # Graph inputs + with torch.device(self.device): + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int64) + self.seq_lens = torch.full( + (self.max_bs,), self.seq_len_fill_value, dtype=torch.int64 + ) + self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) + self.num_token_non_padded = torch.zeros((1,), dtype=torch.int64) + self.custom_mask = torch.ones( + ( + (self.seq_lens.sum().item() + self.max_num_token) + * self.num_tokens_per_bs + ), + dtype=torch.bool, + device=self.device, + ) + + # Capture + try: + self.capture() + except RuntimeError as e: + raise Exception( + f"Capture CPU graph failed: {e}\n{CPU_GRAPH_CAPTURE_FAILED_MSG}" + ) + + def can_run(self, forward_batch: ForwardBatch): + cpu_graph_bs = forward_batch.batch_size + + is_bs_supported = ( + cpu_graph_bs in self.graphs + if self.disable_padding + else cpu_graph_bs <= self.max_bs + ) + + requested_capture_hidden_mode = max( + forward_batch.capture_hidden_mode, + ( + forward_batch.spec_info.capture_hidden_mode + if getattr(forward_batch.spec_info, "capture_hidden_mode", None) + is not None + else CaptureHiddenMode.NULL + ), + ) + capture_hidden_mode_matches = ( + requested_capture_hidden_mode == CaptureHiddenMode.NULL + or requested_capture_hidden_mode == self.capture_hidden_mode + ) + + return is_bs_supported and capture_hidden_mode_matches + + def capture(self) -> None: + capture_range = ( + tqdm.tqdm(list(reversed(self.capture_bs))) + if get_tensor_model_parallel_rank() == 0 + else reversed(self.capture_bs) + ) + for bs in capture_range: + if get_tensor_model_parallel_rank() == 0: + avail_mem = psutil.virtual_memory().available / (1 << 30) + capture_range.set_description( + f"Capturing batches ({bs=} {avail_mem=:.2f} GB)" + ) + + with patch_model( + self.model_runner.model, + bs in self.capture_bs, + num_tokens=bs * self.num_tokens_per_bs, + tp_group=self.model_runner.tp_group, + ) as forward: + ( + graph, + output_buffers, + ) = self.capture_one_batch_size(bs, forward) + self.graphs[bs] = graph + self.output_buffers[bs] = output_buffers + + def capture_one_batch_size(self, bs: int, forward: Callable): + num_tokens = bs * self.num_tokens_per_bs + + # Graph inputs + input_ids = self.input_ids[:num_tokens] + req_pool_indices = self.req_pool_indices[:bs] + seq_lens = self.seq_lens[:bs] + out_cache_loc = self.out_cache_loc[:num_tokens] + positions = self.positions[:num_tokens] + mrope_positions = self.mrope_positions[:, :bs] + self.num_token_non_padded[...] = num_tokens + + spec_info = self.get_spec_info(num_tokens) + if self.capture_hidden_mode != CaptureHiddenMode.FULL: + self.capture_hidden_mode = ( + spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL + ) + + forward_batch = ForwardBatch( + forward_mode=self.capture_forward_mode, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens.sum().item(), + return_logprob=False, + positions=positions, + mrope_positions=mrope_positions, + spec_algorithm=self.model_runner.spec_algorithm, + spec_info=spec_info, + capture_hidden_mode=self.capture_hidden_mode, + num_token_non_padded=self.num_token_non_padded, + global_forward_mode=self.capture_forward_mode, + ) + + # Attention backend + self.model_runner.attn_backend.init_forward_metadata(forward_batch) + # Do infernence to avoid setting attr at runtime, e.g., + # self.attn_mha.kv_b_proj = self.kv_b_proj for full graph compile on CPU + self.model_runner.model.forward( + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + ) + + # Run and capture + def run_once(): + # Clean intermediate result cache for DP attention + forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None + logits_output_or_pp_proxy_tensors = forward( + input_ids, + forward_batch.positions, + forward_batch, + ) + return logits_output_or_pp_proxy_tensors + + with torch.no_grad(): + for _ in range(2): + self.model_runner.tp_group.barrier() + out = run_once() + return forward, out + + def recapture_if_needed(self, forward_batch: ForwardBatch): + + # If the required capture_hidden_mode changes, we need to recapture the graph + + # These are the different factors that can influence the capture_hidden_mode + capture_hidden_mode_required_by_forward_batch = ( + forward_batch.capture_hidden_mode + ) + capture_hidden_mode_required_by_spec_info = getattr( + forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL + ) + capture_hidden_mode_required_for_returning_hidden_states = ( + CaptureHiddenMode.FULL + if self.model_runner.server_args.enable_return_hidden_states + else CaptureHiddenMode.NULL + ) + + # Determine the highest capture_hidden_mode required + # (If we have FULL, we can emulate LAST or NULL) + # (If we have LAST, we can emulate NULL) + required_capture_hidden_mode = max( + capture_hidden_mode_required_by_forward_batch, + capture_hidden_mode_required_by_spec_info, + capture_hidden_mode_required_for_returning_hidden_states, + ) + + # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture + if self.capture_hidden_mode != required_capture_hidden_mode: + self.capture_hidden_mode = required_capture_hidden_mode + self.capture() + + def replay_prepare( + self, + forward_batch: ForwardBatch, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ): + self.recapture_if_needed(forward_batch) + + raw_bs = forward_batch.batch_size + raw_num_token = raw_bs * self.num_tokens_per_bs + + # Pad + index = bisect.bisect_left(self.capture_bs, raw_bs) + bs = self.capture_bs[index] + if bs != raw_bs: + self.seq_lens.fill_(self.seq_len_fill_value) + self.out_cache_loc.zero_() + + # Common inputs + self.input_ids[:raw_num_token].copy_(forward_batch.input_ids) + self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) + self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) + self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) + self.positions[:raw_num_token].copy_(forward_batch.positions) + + if pp_proxy_tensors: + for key in self.pp_proxy_tensors.keys(): + dim = pp_proxy_tensors[key].shape[0] + self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key]) + + if forward_batch.mrope_positions is not None: + self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) + if enable_num_token_non_padded(self.model_runner.server_args): + self.num_token_non_padded.copy_(forward_batch.num_token_non_padded) + if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None: + forward_batch.spec_info.custom_mask = self.custom_mask + + # Store fields + self.raw_bs = raw_bs + self.raw_num_token = raw_num_token + self.bs = bs + + def replay( + self, + forward_batch: ForwardBatch, + skip_attn_backend_init: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + if not skip_attn_backend_init: + self.replay_prepare(forward_batch, pp_proxy_tensors) + else: + # In speculative decoding, these two fields are still needed. + self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) + self.positions[: self.raw_num_token].copy_(forward_batch.positions) + + self.model_runner.attn_backend.init_forward_metadata(forward_batch) + output = self.graphs[self.bs]( + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + ) + if isinstance(output, LogitsProcessorOutput): + return LogitsProcessorOutput( + next_token_logits=output.next_token_logits[: self.raw_num_token], + hidden_states=( + output.hidden_states[: self.raw_num_token] + if output.hidden_states is not None + else None + ), + ) + else: + assert isinstance(output, PPProxyTensors) + return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()}) + + def get_spec_info(self, num_tokens: int): + spec_info = None + if self.model_runner.spec_algorithm.is_eagle(): + from sglang.srt.speculative.eagle_utils import EagleVerifyInput + + if self.model_runner.is_draft_worker: + raise RuntimeError("This should not happen.") + else: + spec_info = EagleVerifyInput( + draft_token=None, + custom_mask=self.custom_mask, + positions=None, + retrive_index=None, + retrive_next_token=None, + retrive_next_sibling=None, + retrive_cum_len=None, + spec_steps=self.model_runner.server_args.speculative_num_steps, + topk=self.model_runner.server_args.speculative_eagle_topk, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + capture_hidden_mode=CaptureHiddenMode.FULL, + seq_lens_sum=None, + seq_lens_cpu=None, + ) + + return spec_info + + +CPU_GRAPH_CAPTURE_FAILED_MSG = ( + "Possible solutions:\n" + "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n" + "2. set --torch-compile-max-bs to a smaller value (e.g., 8)\n" + "3. disable torch compile by not using --enable-torch-compile\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" +) diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index bceb0759efa8..485796314d44 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -132,6 +132,9 @@ def is_cuda_graph(self): or self == ForwardMode.IDLE ) + def is_cpu_graph(self): + return self == ForwardMode.DECODE + def is_dummy_first(self): return self == ForwardMode.DUMMY_FIRST diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b05973c812be..e3338416bc58 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -89,6 +89,7 @@ ReqToTokenPool, SWAKVPool, ) +from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner @@ -346,9 +347,12 @@ def initialize(self, min_per_gpu_memory: float): elif self.device == "npu": self.init_attention_backend() self.init_device_graphs() + elif self.device == "cpu": + self.init_attention_backend() + self.init_device_graphs("cpu") else: self.graph_runner = None - self.cuda_graph_mem_usage = 0 + self.graph_mem_usage = 0 self.init_attention_backend() # auxiliary hidden capture mode. TODO: expose this to server args? @@ -588,6 +592,11 @@ def init_torch_distributed(self): # Set local size to hint SGLang to use shared memory based AllReduce os.environ["LOCAL_SIZE"] = str(self.tp_size) torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank) + + @torch.library.register_fake("sgl_kernel::shm_allgather") + def _(data, dim): + return torch.cat([data] * self.tp_size, dim=dim) + else: logger.warning( "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available" @@ -1590,31 +1599,38 @@ def init_double_sparsity_channel_config(self, selected_channel): .cuda() ) - def init_device_graphs(self): - """Capture cuda graphs.""" + def init_device_graphs(self, device="cuda"): + """Capture cuda/cpu graphs.""" self.graph_runner = None - self.cuda_graph_mem_usage = 0 + self.graph_mem_usage = 0 if not self.is_generation: # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models return - if self.server_args.disable_cuda_graph: + if device == "cuda" and self.server_args.disable_cuda_graph: + return + + if device == "cpu" and not self.server_args.enable_torch_compile: return tic = time.perf_counter() before_mem = get_available_gpu_memory(self.device, self.gpu_id) logger.info( - f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" - ) - self.graph_runner = ( - CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self) + f"Capture {'cpu graph of torch.compile' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" ) + if device == "cuda": + self.graph_runner = ( + CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self) + ) + else: + assert device == "cpu", "Only cuda and cpu are supported for graph capture." + self.graph_runner = CPUGraphRunner(self) after_mem = get_available_gpu_memory(self.device, self.gpu_id) - self.cuda_graph_mem_usage = before_mem - after_mem + self.graph_mem_usage = before_mem - after_mem logger.info( - f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. " - f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB." + f"Capture {'cpu graph of torch.compile' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. " + f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB." ) def init_threads_binding(self): @@ -1759,18 +1775,24 @@ def _forward_raw( reinit_attn_backend: bool = False, split_forward_count: int = 1, ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]: - can_run_cuda_graph = bool( - forward_batch.forward_mode.is_cuda_graph() + mode_check = ( + forward_batch.forward_mode.is_cpu_graph + if self.device == "cpu" + else forward_batch.forward_mode.is_cuda_graph + ) + can_run_graph = bool( + mode_check() and self.graph_runner and self.graph_runner.can_run(forward_batch) ) - if can_run_cuda_graph: + + if can_run_graph: ret = self.graph_runner.replay( forward_batch, skip_attn_backend_init=skip_attn_backend_init, pp_proxy_tensors=pp_proxy_tensors, ) - return ret, can_run_cuda_graph + return ret, can_run_graph # For MLP sync if forward_batch.global_num_tokens_cpu is not None: @@ -1802,7 +1824,7 @@ def _forward_raw( if forward_batch.global_num_tokens_cpu is not None: forward_batch.post_forward_mlp_sync_batch(ret) - return ret, can_run_cuda_graph + return ret, can_run_graph def _preprocess_logits( self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 972d7182d817..da4f66057c86 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -339,7 +339,7 @@ def forward_batch_speculative_generation( else: with self.draft_tp_context(self.draft_model_runner.tp_group): spec_info = self.draft(batch) - logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( + logits_output, verify_output, model_worker_batch, can_run_graph = ( self.verify(batch, spec_info) ) @@ -358,7 +358,7 @@ def forward_batch_speculative_generation( verify_output.verified_id, model_worker_batch.bid, sum(verify_output.accept_length_per_req_cpu), - can_run_cuda_graph, + can_run_graph, ) def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch): @@ -679,10 +679,8 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): ).cpu() # Forward - logits_output, _, can_run_cuda_graph = ( - self.target_worker.forward_batch_generation( - model_worker_batch, skip_sample=True - ) + logits_output, _, can_run_graph = self.target_worker.forward_batch_generation( + model_worker_batch, skip_sample=True ) vocab_mask = None @@ -731,7 +729,7 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): ) batch.spec_info = res.draft_input - return logits_output, res, model_worker_batch, can_run_cuda_graph + return logits_output, res, model_worker_batch, can_run_graph def add_logprob_values( self, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 0318f3bd4a89..1d5b7bd14f5b 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -216,8 +216,16 @@ def support_triton(backend: str) -> bool: is_intel_amx_backend_available = False +try: + # move torch._C._cpu._is_amx_tile_supported() from cpu_has_amx_support + # to support torch compile + is_amx_tile_supported = torch._C._cpu._is_amx_tile_supported() +except: + is_amx_tile_supported = False + + def cpu_has_amx_support(): - return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available + return is_amx_tile_supported and is_intel_amx_backend_available def use_intel_amx_backend(layer): diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index 44257dec5e0d..872c07628a9e 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -239,7 +239,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu); m.def("l2norm_cpu(Tensor input, float eps) -> Tensor"); m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu); - m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()"); + m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor residual, Tensor weight, float eps) -> ()"); m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu); // topk @@ -262,14 +262,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { // decode m.def( - "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, " + "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor(a!) output, Tensor key, Tensor value, " "Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, " "float logit_cap) -> ()"); m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu); // extend m.def( - "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor o_extend, Tensor k_buffer, " + "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor(a!) o_extend, Tensor k_buffer, " "Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor " "extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()"); m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu); @@ -305,7 +305,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant); // bmm - m.def("bmm_cpu(Tensor out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()"); + m.def("bmm_cpu(Tensor(a!) out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()"); m.impl("bmm_cpu", torch::kCPU, &bmm_cpu); // moe @@ -342,7 +342,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { // all reduce m.def("initialize(int size, int rank) -> ()"); - m.def("shm_allreduce(Tensor data, int reduce_op) -> ()"); + m.def("shm_allreduce(Tensor(a!) data, int reduce_op) -> ()"); m.impl("shm_allreduce", torch::kCPU, &shm_allreduce); m.def("shm_allgather(Tensor data, int dim) -> Tensor"); m.impl("shm_allgather", torch::kCPU, &shm_allgather); diff --git a/test/srt/test_intel_amx_attention_backend.py b/test/srt/test_intel_amx_attention_backend.py index 0b49c8af741d..b2148c633bfd 100644 --- a/test/srt/test_intel_amx_attention_backend.py +++ b/test/srt/test_intel_amx_attention_backend.py @@ -3,10 +3,12 @@ python3 -m unittest test_intel_amx_attention_backend.TestIntelAMXAttnBackend.test_mmlu """ +import copy +import os import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_process_tree +from sglang.srt.utils import get_cpu_ids_by_node, kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -74,6 +76,73 @@ def test_mmlu(self): finally: kill_process_tree(process.pid) + def test_latency_torch_compile_cpu(self): + prefill_latency, decode_throughput, decode_latency = run_bench_one_batch( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + [ + "--attention-backend", + "intel_amx", + "--mem-fraction-static", + "0.05", + "--disable-radix", + "--trust-remote-code", + "--batch-size", + "1", + "--enable-torch-compile", + "--torch-compile-max-bs", + "1", + ], + ) + + print(f"{prefill_latency=}") + print(f"{decode_throughput=}") + print(f"{decode_latency=}") + + if is_in_ci(): + self.assertGreater(decode_throughput, 10) + + def test_mmlu_torch_compile_cpu(self): + model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + cpu_ids_by_node = get_cpu_ids_by_node() + n_numa_node = len(cpu_ids_by_node) + env = copy.deepcopy(os.environ) + env["SGLANG_CPU_OMP_THREADS_BIND"] = "all" + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--attention-backend", + "intel_amx", + "--mem-fraction-static", + "0.05", + "--disable-radix", + "--trust-remote-code", + "--disable-overlap-schedule", + "--enable-torch-compile", + "--torch-compile-max-bs", + "4", + "--tp", + f"{n_numa_node}", + ], + env=env, + ) + + try: + args = SimpleNamespace( + base_url=base_url, + model=model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.5) + finally: + kill_process_tree(process.pid) + if __name__ == "__main__": unittest.main() From f8d5ab2987201ec9c11fdbb140058f14a41898d0 Mon Sep 17 00:00:00 2001 From: CaoE Date: Tue, 19 Aug 2025 11:45:26 +0000 Subject: [PATCH 02/13] modify logs and init_device_graphs --- .../srt/managers/scheduler_metrics_mixin.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 33 +++++++++++-------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index 34728362e977..7d072b0175aa 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -185,7 +185,7 @@ def log_decode_stats( msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, " msg += ( - f"{'cpu graph of torch.compile' if self.device == 'cpu' else 'cuda graph'}: {can_run_graph}, " + f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_graph}, " f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " f"#queue-req: {len(self.waiting_queue)}, " ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index e3338416bc58..36b0a8987f63 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -349,7 +349,7 @@ def initialize(self, min_per_gpu_memory: float): self.init_device_graphs() elif self.device == "cpu": self.init_attention_backend() - self.init_device_graphs("cpu") + self.init_device_graphs() else: self.graph_runner = None self.graph_mem_usage = 0 @@ -1599,8 +1599,8 @@ def init_double_sparsity_channel_config(self, selected_channel): .cuda() ) - def init_device_graphs(self, device="cuda"): - """Capture cuda/cpu graphs.""" + def init_device_graphs(self): + """Capture cuda/npu/cpu graphs.""" self.graph_runner = None self.graph_mem_usage = 0 @@ -1608,28 +1608,33 @@ def init_device_graphs(self, device="cuda"): # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models return - if device == "cuda" and self.server_args.disable_cuda_graph: + if self.device in ["cuda", "npu"] and self.server_args.disable_cuda_graph: return - if device == "cpu" and not self.server_args.enable_torch_compile: + if self.device == "cpu" and not self.server_args.enable_torch_compile: return tic = time.perf_counter() before_mem = get_available_gpu_memory(self.device, self.gpu_id) logger.info( - f"Capture {'cpu graph of torch.compile' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" + f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" ) - if device == "cuda": - self.graph_runner = ( - CudaGraphRunner(self) if not _is_npu else NPUGraphRunner(self) - ) - else: - assert device == "cpu", "Only cuda and cpu are supported for graph capture." - self.graph_runner = CPUGraphRunner(self) + graph_runners = { + "cuda": CudaGraphRunner, + "npu": CudaGraphRunner if not _is_npu else NPUGraphRunner, + "cpu": CPUGraphRunner, + } + assert self.device in [ + "cuda", + "npu", + "cpu", + ], "Only cuda, npu and cpu are supported for graph capture." + self.graph_runner = graph_runners[self.device](self) + after_mem = get_available_gpu_memory(self.device, self.gpu_id) self.graph_mem_usage = before_mem - after_mem logger.info( - f"Capture {'cpu graph of torch.compile' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. " + f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. " f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB." ) From 21a485f0c47a1833f58b2da3c387e2f1805520a8 Mon Sep 17 00:00:00 2001 From: CaoE Date: Tue, 19 Aug 2025 16:29:34 +0000 Subject: [PATCH 03/13] remove padding --- .../srt/model_executor/cpu_graph_runner.py | 80 +++---------------- 1 file changed, 9 insertions(+), 71 deletions(-) diff --git a/python/sglang/srt/model_executor/cpu_graph_runner.py b/python/sglang/srt/model_executor/cpu_graph_runner.py index 14dc9de89d11..121e54d44023 100644 --- a/python/sglang/srt/model_executor/cpu_graph_runner.py +++ b/python/sglang/srt/model_executor/cpu_graph_runner.py @@ -17,7 +17,6 @@ from __future__ import annotations -import bisect import logging from contextlib import contextmanager from typing import TYPE_CHECKING, Callable, Optional, Union @@ -34,7 +33,6 @@ ForwardBatch, ForwardMode, PPProxyTensors, - enable_num_token_non_padded, ) from sglang.srt.patch_torch import monkey_patch_torch_compile from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -96,7 +94,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): server_args = model_runner.server_args # cpu torch compile only speeds up decoding by # reducing python overhead when bs is small - capture_bs = list(range(1, 9)) + list(range(10, 17, 2)) + capture_bs = list(range(1, 17)) capture_bs = [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size] capture_bs = list(sorted(set(capture_bs))) @@ -446,13 +444,7 @@ def __init__(self, model_runner: ModelRunner): ) def can_run(self, forward_batch: ForwardBatch): - cpu_graph_bs = forward_batch.batch_size - - is_bs_supported = ( - cpu_graph_bs in self.graphs - if self.disable_padding - else cpu_graph_bs <= self.max_bs - ) + is_bs_supported = forward_batch.batch_size in self.graphs requested_capture_hidden_mode = max( forward_batch.capture_hidden_mode, @@ -593,78 +585,24 @@ def recapture_if_needed(self, forward_batch: ForwardBatch): self.capture_hidden_mode = required_capture_hidden_mode self.capture() - def replay_prepare( - self, - forward_batch: ForwardBatch, - pp_proxy_tensors: Optional[PPProxyTensors] = None, - ): - self.recapture_if_needed(forward_batch) - - raw_bs = forward_batch.batch_size - raw_num_token = raw_bs * self.num_tokens_per_bs - - # Pad - index = bisect.bisect_left(self.capture_bs, raw_bs) - bs = self.capture_bs[index] - if bs != raw_bs: - self.seq_lens.fill_(self.seq_len_fill_value) - self.out_cache_loc.zero_() - - # Common inputs - self.input_ids[:raw_num_token].copy_(forward_batch.input_ids) - self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) - self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) - self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) - self.positions[:raw_num_token].copy_(forward_batch.positions) - - if pp_proxy_tensors: - for key in self.pp_proxy_tensors.keys(): - dim = pp_proxy_tensors[key].shape[0] - self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key]) - - if forward_batch.mrope_positions is not None: - self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) - if enable_num_token_non_padded(self.model_runner.server_args): - self.num_token_non_padded.copy_(forward_batch.num_token_non_padded) - if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None: - forward_batch.spec_info.custom_mask = self.custom_mask - - # Store fields - self.raw_bs = raw_bs - self.raw_num_token = raw_num_token - self.bs = bs - + # TODO add padding support for CPUGraphRunner def replay( self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False, pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> Union[LogitsProcessorOutput, PPProxyTensors]: - if not skip_attn_backend_init: - self.replay_prepare(forward_batch, pp_proxy_tensors) - else: - # In speculative decoding, these two fields are still needed. - self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) - self.positions[: self.raw_num_token].copy_(forward_batch.positions) - + assert ( + pp_proxy_tensors is None + ), "PPProxyTensors is not supported in CPUGraphRunner yet." + self.recapture_if_needed(forward_batch) self.model_runner.attn_backend.init_forward_metadata(forward_batch) - output = self.graphs[self.bs]( + output = self.graphs[forward_batch.batch_size]( forward_batch.input_ids, forward_batch.positions, forward_batch, ) - if isinstance(output, LogitsProcessorOutput): - return LogitsProcessorOutput( - next_token_logits=output.next_token_logits[: self.raw_num_token], - hidden_states=( - output.hidden_states[: self.raw_num_token] - if output.hidden_states is not None - else None - ), - ) - else: - assert isinstance(output, PPProxyTensors) - return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()}) + return output def get_spec_info(self, num_tokens: int): spec_info = None From 954f5ab3948dbce76ab0b44937c8a49d1fa85ad4 Mon Sep 17 00:00:00 2001 From: CaoE Date: Tue, 19 Aug 2025 21:24:47 +0000 Subject: [PATCH 04/13] use defaultdict --- .../sglang/srt/model_executor/model_runner.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 36b0a8987f63..4022d864ccdf 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -20,6 +20,7 @@ import logging import os import time +from collections import defaultdict from dataclasses import dataclass from typing import List, Optional, Tuple, Union @@ -1619,16 +1620,13 @@ def init_device_graphs(self): logger.info( f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB" ) - graph_runners = { - "cuda": CudaGraphRunner, - "npu": CudaGraphRunner if not _is_npu else NPUGraphRunner, - "cpu": CPUGraphRunner, - } - assert self.device in [ - "cuda", - "npu", - "cpu", - ], "Only cuda, npu and cpu are supported for graph capture." + graph_runners = defaultdict( + lambda: CudaGraphRunner, + { + "npu": CudaGraphRunner if not _is_npu else NPUGraphRunner, + "cpu": CPUGraphRunner, + }, + ) self.graph_runner = graph_runners[self.device](self) after_mem = get_available_gpu_memory(self.device, self.gpu_id) From d7917fef729d05984eb51e5606260a498e9f0eb5 Mon Sep 17 00:00:00 2001 From: JieXin Liang Date: Wed, 20 Aug 2025 14:31:39 +0800 Subject: [PATCH 05/13] Update python/sglang/srt/model_executor/model_runner.py --- python/sglang/srt/model_executor/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4f9e8a4407ab..64efacf9af78 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1626,7 +1626,7 @@ def init_device_graphs(self): lambda: CudaGraphRunner, { "cpu": CPUGraphRunner, - "npu": CudaGraphRunner if not _is_npu else NPUGraphRunner, + "npu": NPUGraphRunner, }, ) self.graph_runner = graph_runners[self.device](self) From d5736f375d88fd83c4329775f13697010c264499 Mon Sep 17 00:00:00 2001 From: CaoE Date: Fri, 22 Aug 2025 13:48:59 +0000 Subject: [PATCH 06/13] increase timeout-minutes --- .github/workflows/pr-test-xeon.yml | 2 +- test/srt/test_intel_amx_attention_backend.py | 28 ++++---------------- 2 files changed, 6 insertions(+), 24 deletions(-) diff --git a/.github/workflows/pr-test-xeon.yml b/.github/workflows/pr-test-xeon.yml index c64452a70cbe..a228771efc63 100644 --- a/.github/workflows/pr-test-xeon.yml +++ b/.github/workflows/pr-test-xeon.yml @@ -70,7 +70,7 @@ jobs: - name: Run unit tests if: steps.check_amx.outcome == 'success' - timeout-minutes: 30 + timeout-minutes: 60 run: | docker exec -w /sglang-checkout/ ci_sglang_xeon \ bash -c "cd ./test/srt && python3 run_suite.py --suite per-commit-cpu" diff --git a/test/srt/test_intel_amx_attention_backend.py b/test/srt/test_intel_amx_attention_backend.py index 3cfd6534bbd4..14b6dfec347d 100644 --- a/test/srt/test_intel_amx_attention_backend.py +++ b/test/srt/test_intel_amx_attention_backend.py @@ -129,30 +129,12 @@ def test_mmlu(self): finally: kill_process_tree(process.pid) + @intel_amx_benchmark( + extra_args=["--enable-torch-compile", "--torch-compile-max-bs", "4"], + min_throughput=40, + ) def test_latency_torch_compile_cpu(self): - prefill_latency, decode_throughput, decode_latency = run_bench_one_batch( - DEFAULT_MLA_MODEL_NAME_FOR_TEST, - [ - "--attention-backend", - "intel_amx", - "--mem-fraction-static", - "0.05", - "--disable-radix", - "--trust-remote-code", - "--batch-size", - "1", - "--enable-torch-compile", - "--torch-compile-max-bs", - "1", - ], - ) - - print(f"{prefill_latency=}") - print(f"{decode_throughput=}") - print(f"{decode_latency=}") - - if is_in_ci(): - self.assertGreater(decode_throughput, 10) + return DEFAULT_MODEL_NAME_FOR_TEST def test_mmlu_torch_compile_cpu(self): model = DEFAULT_MLA_MODEL_NAME_FOR_TEST From ce41f84db1c3e65274ab96d178e68cc6bd735f88 Mon Sep 17 00:00:00 2001 From: CaoE Date: Mon, 25 Aug 2025 09:07:37 +0000 Subject: [PATCH 07/13] modify extra_args for compile test --- test/srt/test_intel_amx_attention_backend.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/srt/test_intel_amx_attention_backend.py b/test/srt/test_intel_amx_attention_backend.py index 14b6dfec347d..273acf25212d 100644 --- a/test/srt/test_intel_amx_attention_backend.py +++ b/test/srt/test_intel_amx_attention_backend.py @@ -130,11 +130,17 @@ def test_mmlu(self): kill_process_tree(process.pid) @intel_amx_benchmark( - extra_args=["--enable-torch-compile", "--torch-compile-max-bs", "4"], - min_throughput=40, + extra_args=[ + "--mem-fraction-static", + "0.05", + "--enable-torch-compile", + "--torch-compile-max-bs", + "4", + ], + min_throughput=10, ) def test_latency_torch_compile_cpu(self): - return DEFAULT_MODEL_NAME_FOR_TEST + return DEFAULT_MLA_MODEL_NAME_FOR_TEST def test_mmlu_torch_compile_cpu(self): model = DEFAULT_MLA_MODEL_NAME_FOR_TEST From 6f7d22eddb53085967ccf358aca0fdaca89b2978 Mon Sep 17 00:00:00 2001 From: CaoE Date: Mon, 25 Aug 2025 14:29:21 +0000 Subject: [PATCH 08/13] move cpu graph tests to test_cpu_graph.py --- test/srt/run_suite.py | 1 + test/srt/test_cpu_graph.py | 84 ++++++++++++++++++++ test/srt/test_intel_amx_attention_backend.py | 59 +------------- 3 files changed, 86 insertions(+), 58 deletions(-) create mode 100644 test/srt/test_cpu_graph.py diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 4c98dc585343..88623a6ec68a 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -262,6 +262,7 @@ class TestFile: TestFile("cpu/test_shared_expert.py"), TestFile("cpu/test_topk.py"), TestFile("test_intel_amx_attention_backend.py"), + TestFile("test_cpu_graph.py"), ], } diff --git a/test/srt/test_cpu_graph.py b/test/srt/test_cpu_graph.py new file mode 100644 index 000000000000..f16284ceec1b --- /dev/null +++ b/test/srt/test_cpu_graph.py @@ -0,0 +1,84 @@ +""" +Usage: +python3 -m unittest test_cpu_graph.TestCPUGraph.test_mmlu +""" + +import copy +import os +import unittest +from types import SimpleNamespace + +from sglang.srt.test_intel_amx_attention_backend import intel_amx_benchmark +from sglang.srt.utils import get_cpu_ids_by_node, kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + is_in_ci, + popen_launch_server, +) + + +class TestCPUGraph(CustomTestCase): + + @intel_amx_benchmark( + extra_args=[ + "--mem-fraction-static", + "0.05", + "--enable-torch-compile", + "--torch-compile-max-bs", + "4", + ], + min_throughput=10, + ) + def test_latency_torch_compile_cpu(self): + return DEFAULT_MLA_MODEL_NAME_FOR_TEST + + def test_mmlu_torch_compile_cpu(self): + model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + base_url = DEFAULT_URL_FOR_TEST + cpu_ids_by_node = get_cpu_ids_by_node() + n_numa_node = len(cpu_ids_by_node) + env = copy.deepcopy(os.environ) + env["SGLANG_CPU_OMP_THREADS_BIND"] = "all" + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--attention-backend", + "intel_amx", + "--mem-fraction-static", + "0.05", + "--disable-radix", + "--trust-remote-code", + "--disable-overlap-schedule", + "--enable-torch-compile", + "--torch-compile-max-bs", + "4", + "--tp", + f"{n_numa_node}", + ], + env=env, + ) + + try: + args = SimpleNamespace( + base_url=base_url, + model=model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + if is_in_ci(): + self.assertGreater(metrics["score"], 0.45) + finally: + kill_process_tree(process.pid) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_intel_amx_attention_backend.py b/test/srt/test_intel_amx_attention_backend.py index 273acf25212d..20b7e0edac82 100644 --- a/test/srt/test_intel_amx_attention_backend.py +++ b/test/srt/test_intel_amx_attention_backend.py @@ -3,13 +3,11 @@ python3 -m unittest test_intel_amx_attention_backend.TestIntelAMXAttnBackend.test_mmlu """ -import copy -import os import unittest from functools import wraps from types import SimpleNamespace -from sglang.srt.utils import get_cpu_ids_by_node, kill_process_tree +from sglang.srt.utils import kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -129,61 +127,6 @@ def test_mmlu(self): finally: kill_process_tree(process.pid) - @intel_amx_benchmark( - extra_args=[ - "--mem-fraction-static", - "0.05", - "--enable-torch-compile", - "--torch-compile-max-bs", - "4", - ], - min_throughput=10, - ) - def test_latency_torch_compile_cpu(self): - return DEFAULT_MLA_MODEL_NAME_FOR_TEST - - def test_mmlu_torch_compile_cpu(self): - model = DEFAULT_MLA_MODEL_NAME_FOR_TEST - base_url = DEFAULT_URL_FOR_TEST - cpu_ids_by_node = get_cpu_ids_by_node() - n_numa_node = len(cpu_ids_by_node) - env = copy.deepcopy(os.environ) - env["SGLANG_CPU_OMP_THREADS_BIND"] = "all" - process = popen_launch_server( - model, - base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--attention-backend", - "intel_amx", - "--mem-fraction-static", - "0.05", - "--disable-radix", - "--trust-remote-code", - "--disable-overlap-schedule", - "--enable-torch-compile", - "--torch-compile-max-bs", - "4", - "--tp", - f"{n_numa_node}", - ], - env=env, - ) - - try: - args = SimpleNamespace( - base_url=base_url, - model=model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - ) - - metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.5) - finally: - kill_process_tree(process.pid) - if __name__ == "__main__": unittest.main() From 159a103a0519d8dad00758ab291737b3e5e37795 Mon Sep 17 00:00:00 2001 From: CaoE Date: Mon, 25 Aug 2025 14:32:58 +0000 Subject: [PATCH 09/13] modify timeout-minutes --- .github/workflows/pr-test-xeon.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-test-xeon.yml b/.github/workflows/pr-test-xeon.yml index a228771efc63..d04af436cc35 100644 --- a/.github/workflows/pr-test-xeon.yml +++ b/.github/workflows/pr-test-xeon.yml @@ -70,7 +70,7 @@ jobs: - name: Run unit tests if: steps.check_amx.outcome == 'success' - timeout-minutes: 60 + timeout-minutes: 45 run: | docker exec -w /sglang-checkout/ ci_sglang_xeon \ bash -c "cd ./test/srt && python3 run_suite.py --suite per-commit-cpu" From d6109c53834c02c25f6cee1c3b913c66e9f16dd8 Mon Sep 17 00:00:00 2001 From: CaoE Date: Mon, 25 Aug 2025 16:35:33 +0000 Subject: [PATCH 10/13] fix merge main --- python/sglang/srt/model_executor/cpu_graph_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/cpu_graph_runner.py b/python/sglang/srt/model_executor/cpu_graph_runner.py index 121e54d44023..bc1e5c5b8774 100644 --- a/python/sglang/srt/model_executor/cpu_graph_runner.py +++ b/python/sglang/srt/model_executor/cpu_graph_runner.py @@ -37,7 +37,7 @@ from sglang.srt.patch_torch import monkey_patch_torch_compile from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( - rank0_log, + log_info_on_rank0, require_attn_tp_gather, require_gathered_buffer, require_mlp_sync, @@ -402,7 +402,7 @@ def __init__(self, model_runner: ModelRunner): # Batch sizes to capture self.capture_bs = get_batch_sizes_to_capture(model_runner) - rank0_log(f"Capture cpu graph bs {self.capture_bs}") + log_info_on_rank0(logger, f"Capture cpu graph bs {self.capture_bs}") # Attention backend self.max_bs = max(self.capture_bs) self.max_num_token = self.max_bs * self.num_tokens_per_bs From a4066c4a3d70b39a88fe033b5fdab31e246897d0 Mon Sep 17 00:00:00 2001 From: CaoE Date: Tue, 26 Aug 2025 11:54:34 +0000 Subject: [PATCH 11/13] fix test_cpu_graph --- test/srt/test_cpu_graph.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/srt/test_cpu_graph.py b/test/srt/test_cpu_graph.py index f16284ceec1b..f8aab49725bb 100644 --- a/test/srt/test_cpu_graph.py +++ b/test/srt/test_cpu_graph.py @@ -8,7 +8,8 @@ import unittest from types import SimpleNamespace -from sglang.srt.test_intel_amx_attention_backend import intel_amx_benchmark +from test_intel_amx_attention_backend import intel_amx_benchmark + from sglang.srt.utils import get_cpu_ids_by_node, kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( From 9dedc87a130b9b3b92dfbb122744bfb3d206de52 Mon Sep 17 00:00:00 2001 From: CaoE Date: Tue, 26 Aug 2025 16:03:03 +0000 Subject: [PATCH 12/13] reduce bs to shorten test time --- .github/workflows/pr-test-xeon.yml | 2 +- test/srt/test_cpu_graph.py | 8 +++++--- test/srt/test_intel_amx_attention_backend.py | 17 ++++++++++------- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/.github/workflows/pr-test-xeon.yml b/.github/workflows/pr-test-xeon.yml index d04af436cc35..70e70c3a377b 100644 --- a/.github/workflows/pr-test-xeon.yml +++ b/.github/workflows/pr-test-xeon.yml @@ -70,7 +70,7 @@ jobs: - name: Run unit tests if: steps.check_amx.outcome == 'success' - timeout-minutes: 45 + timeout-minutes: 36 run: | docker exec -w /sglang-checkout/ ci_sglang_xeon \ bash -c "cd ./test/srt && python3 run_suite.py --suite per-commit-cpu" diff --git a/test/srt/test_cpu_graph.py b/test/srt/test_cpu_graph.py index f8aab49725bb..4e3c405393f2 100644 --- a/test/srt/test_cpu_graph.py +++ b/test/srt/test_cpu_graph.py @@ -1,6 +1,6 @@ """ Usage: -python3 -m unittest test_cpu_graph.TestCPUGraph.test_mmlu +python3 -m unittest test_cpu_graph.TestCPUGraph.test_mmlu_torch_compile_cpu """ import copy @@ -26,11 +26,13 @@ class TestCPUGraph(CustomTestCase): @intel_amx_benchmark( extra_args=[ + "--batch-size", + "1", "--mem-fraction-static", "0.05", "--enable-torch-compile", "--torch-compile-max-bs", - "4", + "1", ], min_throughput=10, ) @@ -58,7 +60,7 @@ def test_mmlu_torch_compile_cpu(self): "--disable-overlap-schedule", "--enable-torch-compile", "--torch-compile-max-bs", - "4", + "1", "--tp", f"{n_numa_node}", ], diff --git a/test/srt/test_intel_amx_attention_backend.py b/test/srt/test_intel_amx_attention_backend.py index 20b7e0edac82..22f7057ce2fc 100644 --- a/test/srt/test_intel_amx_attention_backend.py +++ b/test/srt/test_intel_amx_attention_backend.py @@ -34,8 +34,6 @@ def wrapper(self): "intel_amx", "--disable-radix", "--trust-remote-code", - "--batch-size", - "4", ] full_args = common_args + (extra_args or []) @@ -59,28 +57,33 @@ def wrapper(self): class TestIntelAMXAttnBackend(CustomTestCase): - @intel_amx_benchmark(min_throughput=10) + @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=10) def test_latency_mla_model(self): return DEFAULT_MLA_MODEL_NAME_FOR_TEST - @intel_amx_benchmark(min_throughput=40) + @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=40) def test_latency_default_model(self): return DEFAULT_MODEL_NAME_FOR_TEST - @intel_amx_benchmark(min_throughput=150) + @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=150) def test_latency_fp8_qwen(self): return DEFAULT_MODEL_NAME_FOR_TEST_QWEN_FP8 - @intel_amx_benchmark(min_throughput=50) + @intel_amx_benchmark(extra_args=["--batch-size", "4"], min_throughput=50) def test_latency_fp8_moe_model(self): return DEFAULT_MODEL_NAME_FOR_TEST_FP8_WITH_MOE - @intel_amx_benchmark(extra_args=["--quantization", "w8a8_int8"], min_throughput=100) + @intel_amx_benchmark( + extra_args=["--batch-size", "4", "--quantization", "w8a8_int8"], + min_throughput=100, + ) def test_latency_w8a8_default_model(self): return DEFAULT_MODEL_NAME_FOR_TEST_W8A8 @intel_amx_benchmark( extra_args=[ + "--batch-size", + "4", "--quantization", "w8a8_int8", "--mem-fraction-static", From b99506f6acf46853cd2b25ce75683818c1ac91c1 Mon Sep 17 00:00:00 2001 From: CaoE Date: Fri, 29 Aug 2025 09:45:10 +0000 Subject: [PATCH 13/13] change can_run_graph back to can_run_cuda_graph --- python/sglang/srt/disaggregation/prefill.py | 2 +- python/sglang/srt/managers/scheduler.py | 12 ++++++------ .../sglang/srt/managers/scheduler_metrics_mixin.py | 4 ++-- .../srt/managers/scheduler_output_processor_mixin.py | 8 ++++---- python/sglang/srt/managers/tp_worker.py | 8 ++++---- .../sglang/srt/managers/tp_worker_overlap_thread.py | 8 ++++---- python/sglang/srt/speculative/eagle_worker.py | 12 +++++++----- 7 files changed, 28 insertions(+), 26 deletions(-) diff --git a/python/sglang/srt/disaggregation/prefill.py b/python/sglang/srt/disaggregation/prefill.py index 495aad8d1a3e..0631976183bc 100644 --- a/python/sglang/srt/disaggregation/prefill.py +++ b/python/sglang/srt/disaggregation/prefill.py @@ -778,7 +778,7 @@ def event_loop_pp_disagg_prefill(self: Scheduler): extend_input_len_per_req=None, extend_logprob_start_len_per_req=None, bid=bids[next_mb_id], - can_run_graph=result.can_run_graph, + can_run_cuda_graph=result.can_run_cuda_graph, ) self.process_batch_result_disagg_prefill( mbs[next_mb_id], output_result diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 017ac8d76986..3e631f1ab0ee 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -181,7 +181,7 @@ class GenerationBatchResult: extend_input_len_per_req: List[int] extend_logprob_start_len_per_req: List[int] bid: int - can_run_graph: bool + can_run_cuda_graph: bool @dataclass @@ -929,7 +929,7 @@ def event_loop_pp(self): "extend_logprob_start_len_per_req", None ), bid=bids[next_mb_id], - can_run_graph=result.can_run_graph, + can_run_cuda_graph=result.can_run_cuda_graph, ) self.process_batch_result(mbs[next_mb_id], output_result) last_mbs[next_mb_id] = mbs[next_mb_id] @@ -1778,11 +1778,11 @@ def run_batch( model_worker_batch.hicache_consumer_index ) if self.pp_group.is_last_rank: - logits_output, next_token_ids, can_run_graph = ( + logits_output, next_token_ids, can_run_cuda_graph = ( self.tp_worker.forward_batch_generation(model_worker_batch) ) else: - pp_hidden_states_proxy_tensors, _, can_run_graph = ( + pp_hidden_states_proxy_tensors, _, can_run_cuda_graph = ( self.tp_worker.forward_batch_generation(model_worker_batch) ) bid = model_worker_batch.bid @@ -1792,7 +1792,7 @@ def run_batch( next_token_ids, bid, num_accepted_tokens, - can_run_graph, + can_run_cuda_graph, ) = self.draft_worker.forward_batch_speculative_generation(batch) bs = batch.batch_size() self.spec_num_total_accepted_tokens += num_accepted_tokens + bs @@ -1827,7 +1827,7 @@ def run_batch( extend_input_len_per_req=extend_input_len_per_req, extend_logprob_start_len_per_req=extend_logprob_start_len_per_req, bid=bid, - can_run_graph=can_run_graph, + can_run_cuda_graph=can_run_cuda_graph, ) else: # embedding or reward model model_worker_batch = batch.get_model_worker_batch() diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py index 6c2432a6a6d7..83c0e61b32c0 100644 --- a/python/sglang/srt/managers/scheduler_metrics_mixin.py +++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py @@ -138,7 +138,7 @@ def log_prefill_stats( self._publish_kv_events() def log_decode_stats( - self, can_run_graph: bool, running_batch: ScheduleBatch = None + self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None ): batch = running_batch or self.running_batch @@ -193,7 +193,7 @@ def log_decode_stats( msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, " msg += ( - f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_graph}, " + f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, " f"gen throughput (token/s): {self.last_gen_throughput:.2f}, " f"#queue-req: {len(self.waiting_queue)}, " ) diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index 758654cdfc69..a86899f6e79b 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -197,15 +197,15 @@ def process_batch_result_decode( result: GenerationBatchResult, launch_done: Optional[threading.Event] = None, ): - logits_output, next_token_ids, can_run_graph = ( + logits_output, next_token_ids, can_run_cuda_graph = ( result.logits_output, result.next_token_ids, - result.can_run_graph, + result.can_run_cuda_graph, ) self.num_generated_tokens += len(batch.reqs) if self.enable_overlap: - logits_output, next_token_ids, can_run_graph = ( + logits_output, next_token_ids, can_run_cuda_graph = ( self.tp_worker.resolve_last_batch_result(launch_done) ) next_token_logprobs = logits_output.next_token_logprobs @@ -293,7 +293,7 @@ def process_batch_result_decode( self.current_scheduler_metrics_enabled() and self.forward_ct_decode % self.server_args.decode_log_interval == 0 ): - self.log_decode_stats(can_run_graph, running_batch=batch) + self.log_decode_stats(can_run_cuda_graph, running_batch=batch) def add_input_logprob_return_values( self: Scheduler, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index b4a0fb5c77f3..968be171dd62 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -236,7 +236,7 @@ def forward_batch_generation( ) if self.pp_group.is_last_rank: - logits_output, can_run_graph = self.model_runner.forward( + logits_output, can_run_cuda_graph = self.model_runner.forward( forward_batch, pp_proxy_tensors=pp_proxy_tensors ) if launch_done is not None: @@ -249,13 +249,13 @@ def forward_batch_generation( logits_output, model_worker_batch ) - return logits_output, next_token_ids, can_run_graph + return logits_output, next_token_ids, can_run_cuda_graph else: - pp_proxy_tensors, can_run_graph = self.model_runner.forward( + pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward( forward_batch, pp_proxy_tensors=pp_proxy_tensors, ) - return pp_proxy_tensors.tensors, None, can_run_graph + return pp_proxy_tensors.tensors, None, can_run_cuda_graph def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index c4ed1bcc1993..674a941955cd 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -172,7 +172,7 @@ def forward_thread_func_(self): # update the consumer index of hicache to the running batch self.set_hicache_consumer(model_worker_batch.hicache_consumer_index) # Run forward - logits_output, next_token_ids, can_run_graph = ( + logits_output, next_token_ids, can_run_cuda_graph = ( self.worker.forward_batch_generation( model_worker_batch, model_worker_batch.launch_done ) @@ -201,7 +201,7 @@ def forward_thread_func_(self): copy_done.record() self.output_queue.put( - (copy_done, logits_output, next_token_ids, can_run_graph) + (copy_done, logits_output, next_token_ids, can_run_cuda_graph) ) def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None): @@ -209,7 +209,7 @@ def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = Non This function is called to resolve the last batch result and wait for the current batch to be launched. Used in overlap mode. """ - copy_done, logits_output, next_token_ids, can_run_graph = ( + copy_done, logits_output, next_token_ids, can_run_cuda_graph = ( self.output_queue.get() ) @@ -226,7 +226,7 @@ def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = Non logits_output.input_token_logprobs.tolist() ) next_token_ids = next_token_ids.tolist() - return logits_output, next_token_ids, can_run_graph + return logits_output, next_token_ids, can_run_cuda_graph def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 9a8dd298445f..4829fc83ede8 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -375,7 +375,7 @@ def forward_batch_speculative_generation( else: with self.draft_tp_context(self.draft_model_runner.tp_group): spec_info = self.draft(batch) - logits_output, verify_output, model_worker_batch, can_run_graph = ( + logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( self.verify(batch, spec_info) ) @@ -394,7 +394,7 @@ def forward_batch_speculative_generation( verify_output.verified_id, model_worker_batch.bid, sum(verify_output.accept_length_per_req_cpu), - can_run_graph, + can_run_cuda_graph, ) def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch): @@ -715,8 +715,10 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): ).cpu() # Forward - logits_output, _, can_run_graph = self.target_worker.forward_batch_generation( - model_worker_batch, skip_sample=True + logits_output, _, can_run_cuda_graph = ( + self.target_worker.forward_batch_generation( + model_worker_batch, skip_sample=True + ) ) vocab_mask = None @@ -765,7 +767,7 @@ def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): ) batch.spec_info = res.draft_input - return logits_output, res, model_worker_batch, can_run_graph + return logits_output, res, model_worker_batch, can_run_cuda_graph def add_logprob_values( self,