Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
68a3c1d
add graph runner support with torch compile on CPU
CaoE Aug 18, 2025
f8d5ab2
modify logs and init_device_graphs
CaoE Aug 19, 2025
21a485f
remove padding
CaoE Aug 19, 2025
954f5ab
use defaultdict
CaoE Aug 19, 2025
7449778
Merge branch 'main' into cpu_compile
CaoE Aug 20, 2025
2e0e9a7
Merge branch 'main' into cpu_compile
CaoE Aug 20, 2025
d7917fe
Update python/sglang/srt/model_executor/model_runner.py
Alcanderian Aug 20, 2025
f317bf7
Merge branch 'main' into cpu_compile
CaoE Aug 21, 2025
513a2ef
Merge branch 'main' into cpu_compile
CaoE Aug 21, 2025
20f3fbc
Merge branch 'main' into cpu_compile
CaoE Aug 21, 2025
1d23804
Merge branch 'main' into cpu_compile
CaoE Aug 21, 2025
d7264d8
Merge branch 'main' into cpu_compile
CaoE Aug 21, 2025
d3643ed
Merge branch 'main' into cpu_compile
CaoE Aug 21, 2025
a84d088
Merge branch 'main' into cpu_compile
CaoE Aug 21, 2025
28e5047
Merge branch 'main' into cpu_compile
CaoE Aug 22, 2025
d5736f3
increase timeout-minutes
CaoE Aug 22, 2025
82981c4
Merge branch 'main' into cpu_compile
CaoE Aug 22, 2025
ce41f84
modify extra_args for compile test
CaoE Aug 25, 2025
65ea131
Merge branch 'main' into cpu_compile
CaoE Aug 25, 2025
6f7d22e
move cpu graph tests to test_cpu_graph.py
CaoE Aug 25, 2025
159a103
modify timeout-minutes
CaoE Aug 25, 2025
3a52dcc
Merge branch 'main' into cpu_compile
CaoE Aug 25, 2025
d6109c5
fix merge main
CaoE Aug 25, 2025
1c67857
Merge branch 'main' into cpu_compile
CaoE Aug 25, 2025
d04bdf2
Merge branch 'main' into cpu_compile
CaoE Aug 26, 2025
a4066c4
fix test_cpu_graph
CaoE Aug 26, 2025
9dedc87
reduce bs to shorten test time
CaoE Aug 26, 2025
fc4118a
Merge branch 'main' into cpu_compile
CaoE Aug 26, 2025
b99506f
change can_run_graph back to can_run_cuda_graph
CaoE Aug 29, 2025
3f2f32a
Merge branch 'main' into cpu_compile
CaoE Aug 29, 2025
fea6858
Merge branch 'main' into cpu_compile
CaoE Aug 29, 2025
8f02f28
Merge branch 'main' into cpu_compile
CaoE Sep 1, 2025
fdfa7f9
Merge branch 'main' into cpu_compile
zhyncs Sep 1, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr-test-xeon.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ jobs:

- name: Run unit tests
if: steps.check_amx.outcome == 'success'
timeout-minutes: 30
timeout-minutes: 36
run: |
docker exec -w /sglang-checkout/ ci_sglang_xeon \
bash -c "cd ./test/srt && python3 run_suite.py --suite per-commit-cpu"
Expand Down
7 changes: 6 additions & 1 deletion docs/platforms/cpu_server.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,12 @@ Notes:
export SGLANG_CPU_OMP_THREADS_BIND="0-39|43-82|86-125|128-167|171-210|214-253"
```

3. A warmup step is automatically triggered when the service is started.
3. For optimizing decoding with torch.compile, please add the flag `--enable-torch-compile`.
To specify the maximum batch size when using torch compile, set the flag `--torch-compile-max-bs`.
For example, `--enable-torch-compile --torch-compile-max-bs 4` means using torch compile and setting the
maximum batch size to 4.

4. A warmup step is automatically triggered when the service is started.
The server is ready when you see the log `The server is fired up and ready to roll!`.

## Benchmarking with Requests
Expand Down
7 changes: 4 additions & 3 deletions python/sglang/srt/distributed/parallel_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ class GraphCaptureContext:

TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

# use int value instead of ReduceOp.SUM to support torch compile
REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)


def _split_tensor_dict(
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
Expand Down Expand Up @@ -487,9 +490,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:

if input_.is_cpu:
if is_shm_available(input_.dtype, self.world_size, self.local_size):
torch.ops.sgl_kernel.shm_allreduce(
input_, torch.distributed.ReduceOp.SUM
)
torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
else:
torch.distributed.all_reduce(input_, group=self.device_group)
return input_
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/layers/attention/intel_amx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def init_forward_metadata(self, forward_batch: ForwardBatch):
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
self.forward_metadata = (attn_logits, max_extend_len)

def get_graph_seq_len_fill_value(self):
return 1

def forward_extend(
self,
q,
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,9 @@ def process_weights_after_loading(self, layer: Module) -> None:
_is_cpu_amx_available
), "Fp8LinearMethod on CPU requires that CPU has AMX support"
_amx_process_weight_after_loading(layer, ["weight"])
layer.weight_scale_inv = torch.nn.Parameter(
layer.weight_scale_inv.data, requires_grad=False
)
return
else:
weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
Expand Down
12 changes: 5 additions & 7 deletions python/sglang/srt/layers/quantization/w8a8_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
_is_cpu_amx_available
), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
_amx_process_weight_after_loading(layer, ["weight"])
return

layer.weight = Parameter(layer.weight.t(), requires_grad=False)
else:
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
Comment thread
CaoE marked this conversation as resolved.
Outdated

def create_weights(
Expand Down Expand Up @@ -472,10 +471,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
_is_cpu_amx_available
), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
return

layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
else:
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
Expand Down
9 changes: 4 additions & 5 deletions python/sglang/srt/managers/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def __init__(
f"max_prefill_tokens={self.max_prefill_tokens}, "
f"max_running_requests={self.max_running_requests}, "
f"context_len={self.model_config.context_len}, "
f"available_gpu_mem={avail_mem:.2f} GB"
f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
)

# Init memory pool and cache
Expand Down Expand Up @@ -2317,10 +2317,9 @@ def get_internal_state(self, recv_req: GetInternalStateReq):
"token_capacity": int(self.max_total_num_tokens),
}

if not _is_cpu:
ret["memory_usage"]["cuda_graph"] = round(
self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
)
ret["memory_usage"]["graph"] = round(
self.tp_worker.worker.model_runner.graph_mem_usage, 2
)

if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
ret["avg_spec_accept_length"] = (
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/managers/scheduler_metrics_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def log_decode_stats(
msg += f"#retracted-req: {len(self.disagg_decode_prealloc_queue.retracted_queue)}, "

msg += (
f"cuda graph: {can_run_cuda_graph}, "
f"{'cpu graph' if self.device == 'cpu' else 'cuda graph'}: {can_run_cuda_graph}, "
f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
f"#queue-req: {len(self.waiting_queue)}, "
)
Expand Down
Loading
Loading