From ac34a5eecdd9d33f5516267ae06af8120430be17 Mon Sep 17 00:00:00 2001
From: noeyy-mino <174223378+noeyy-mino@users.noreply.github.com>
Date: Tue, 2 Dec 2025 06:29:24 +0000
Subject: [PATCH 1/6] add new eagle3 test cases

Signed-off-by: noeyy-mino <174223378+noeyy-mino@users.noreply.github.com>
---
 tests/examples/llm_ptq/test_deploy.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/tests/examples/llm_ptq/test_deploy.py b/tests/examples/llm_ptq/test_deploy.py
index 5467714d6..26ed313c9 100644
--- a/tests/examples/llm_ptq/test_deploy.py
+++ b/tests/examples/llm_ptq/test_deploy.py
@@ -374,6 +374,12 @@ def test_kimi(command):
             tensor_parallel_size=4,
             mini_sm=89,
         ),
+        *ModelDeployerList(
+            model_id="nvidia/Llama-3_1-Nemotron-Ultra-253B-v1-FP8",
+            backend=("vllm",),
+            tensor_parallel_size=8,
+            mini_sm=89,
+        ),
     ],
     ids=idfn,
 )
@@ -457,6 +463,20 @@ def test_medusa(command):
             tensor_parallel_size=8,
             mini_sm=89,
         ),
+        *ModelDeployerList(
+            base_model="openai/gpt-oss-120b",
+            model_id="nvidia/gpt-oss-120b-Eagle3-v2",
+            backend=("trtllm", "sglang"),
+            tensor_parallel_size=8,
+            mini_sm=89,
+        ),
+        *ModelDeployerList(
+            base_model="nvidia/Llama-3.3-70B-Instruct-FP8",
+            model_id="nvidia/Llama-3.3-70B-Instruct-Eagle3",
+            backend=("trtllm", "sglang"),
+            tensor_parallel_size=8,
+            mini_sm=89,
+        ),
     ],
     ids=idfn,
 )

From ce806e65216286963590993bc091bf61ba6fcf73 Mon Sep 17 00:00:00 2001
From: noeyy-mino <174223378+noeyy-mino@users.noreply.github.com>
Date: Wed, 3 Dec 2025 01:45:42 +0000
Subject: [PATCH 2/6] fix thread leak issue

Signed-off-by: noeyy-mino <174223378+noeyy-mino@users.noreply.github.com>
---
 tests/_test_utils/deploy_utils.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/_test_utils/deploy_utils.py b/tests/_test_utils/deploy_utils.py
index 85a97b616..8930c58d3 100644
--- a/tests/_test_utils/deploy_utils.py
+++ b/tests/_test_utils/deploy_utils.py
@@ -47,6 +47,7 @@ def __init__(
             model_id: Path to the model
             tensor_parallel_size: Tensor parallel size for distributed inference
             mini_sm: Minimum SM (Streaming Multiprocessor) requirement for the model
+            attn_backend: Attention backend to use for TRT-LLM deployment
         """
         self.backend = backend
         self.model_id = model_id
@@ -130,12 +131,12 @@ def _deploy_trtllm(self):
         )
 
         outputs = llm.generate(COMMON_PROMPTS, sampling_params)
 
-        # Print outputs
         for output in outputs:
             prompt = output.prompt
             generated_text = output.outputs[0].text
             print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        del llm
 
     def _deploy_vllm(self):
         """Deploy a model using vLLM."""
@@ -172,6 +173,7 @@ def _deploy_vllm(self):
             print(f"Model: {self.model_id}")
             print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")
             print("-" * 50)
+        del llm
 
     def _deploy_sglang(self):
         """Deploy a model using SGLang."""
@@ -190,6 +192,7 @@ def _deploy_sglang(self):
         )
         print(llm.generate(["What's the age of the earth? "]))
         llm.shutdown()
+        del llm
 
 
 class ModelDeployerList:

From 136acc2d80cfcb585933f0be48f367726a0eeb7b Mon Sep 17 00:00:00 2001
From: noeyy-mino <174223378+noeyy-mino@users.noreply.github.com>
Date: Wed, 3 Dec 2025 06:24:34 +0000
Subject: [PATCH 3/6] test eagle3 in sglang

Signed-off-by: noeyy-mino <174223378+noeyy-mino@users.noreply.github.com>
---
 tests/_test_utils/deploy_utils.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/tests/_test_utils/deploy_utils.py b/tests/_test_utils/deploy_utils.py
index 8930c58d3..3a7bd0487 100644
--- a/tests/_test_utils/deploy_utils.py
+++ b/tests/_test_utils/deploy_utils.py
@@ -184,12 +184,25 @@ def _deploy_sglang(self):
         quantization_method = "modelopt"
         if "FP4" in self.model_id:
             quantization_method = "modelopt_fp4"
-        llm = sgl.Engine(
-            model_path=self.model_id,
-            quantization=quantization_method,
-            tp_size=self.tensor_parallel_size,
-            trust_remote_code=True,
-        )
+        if "eagle" in self.model_id.lower():
+            llm = sgl.Engine(
+                model_path=self.base_model,
+                quantization=quantization_method,
+                speculative_algorithm="EAGLE3",
+                speculative_num_steps=3,
+                speculative_eagle_topk=1,
+                speculative_num_draft_tokens=4,
+                speculative_draft_model_path=self.model_id,
+                tp_size=self.tensor_parallel_size,
+                trust_remote_code=True,
+            )
+        else:
+            llm = sgl.Engine(
+                model_path=self.model_id,
+                quantization=quantization_method,
+                tp_size=self.tensor_parallel_size,
+                trust_remote_code=True,
+            )
         print(llm.generate(["What's the age of the earth? "]))
         llm.shutdown()
         del llm

From 80f46b18513950a8eeb4bcbd3f275af0ded5dbb3 Mon Sep 17 00:00:00 2001
From: noeyy-mino <174223378+noeyy-mino@users.noreply.github.com>
Date: Wed, 3 Dec 2025 10:17:23 +0000
Subject: [PATCH 4/6] fix oom

Signed-off-by: noeyy-mino <174223378+noeyy-mino@users.noreply.github.com>
---
 tests/_test_utils/deploy_utils.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/_test_utils/deploy_utils.py b/tests/_test_utils/deploy_utils.py
index 3a7bd0487..09b0d1dd1 100644
--- a/tests/_test_utils/deploy_utils.py
+++ b/tests/_test_utils/deploy_utils.py
@@ -195,6 +195,8 @@ def _deploy_sglang(self):
                 speculative_draft_model_path=self.model_id,
                 tp_size=self.tensor_parallel_size,
                 trust_remote_code=True,
+                mem_fraction_static=0.7,
+                context_length=1024,
             )
         else:
             llm = sgl.Engine(

From 6524e73c2027efdeecaab2981b383c2927eceb6b Mon Sep 17 00:00:00 2001
From: noeyy-mino <174223378+noeyy-mino@users.noreply.github.com>
Date: Thu, 4 Dec 2025 06:40:29 +0000
Subject: [PATCH 5/6] remove quant method for eagle

Signed-off-by: noeyy-mino <174223378+noeyy-mino@users.noreply.github.com>
---
 tests/_test_utils/deploy_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/_test_utils/deploy_utils.py b/tests/_test_utils/deploy_utils.py
index 09b0d1dd1..7a1897ccb 100644
--- a/tests/_test_utils/deploy_utils.py
+++ b/tests/_test_utils/deploy_utils.py
@@ -187,7 +187,6 @@ def _deploy_sglang(self):
         if "eagle" in self.model_id.lower():
             llm = sgl.Engine(
                 model_path=self.base_model,
-                quantization=quantization_method,
                 speculative_algorithm="EAGLE3",
                 speculative_num_steps=3,
                 speculative_eagle_topk=1,

From 6e3cfa10c5f88ac26cc65c7dfdb22e5d55ad131b Mon Sep 17 00:00:00 2001
From: noeyy-mino <174223378+noeyy-mino@users.noreply.github.com>
Date: Fri, 5 Dec 2025 01:48:04 +0000
Subject: [PATCH 6/6] add Kimi-K2-Thinking-NVFP4

Signed-off-by: noeyy-mino <174223378+noeyy-mino@users.noreply.github.com>
---
 tests/examples/gpt_oss/test_gpt_oss_qat.py | 192 +++++++++++++++------
 tests/examples/llm_ptq/test_deploy.py      |   6 +
 2 files changed, 144 insertions(+), 54 deletions(-)

diff --git a/tests/examples/gpt_oss/test_gpt_oss_qat.py b/tests/examples/gpt_oss/test_gpt_oss_qat.py
index e1ad3ea2e..e5f9b8ab9 100644
--- a/tests/examples/gpt_oss/test_gpt_oss_qat.py
+++ b/tests/examples/gpt_oss/test_gpt_oss_qat.py
@@ -47,7 +47,10 @@ def __init__(self, model_path):
         self.model_path = model_path
 
     def gpt_oss_sft_training(self, tmp_path):
-        """Test supervised fine-tuning (SFT) of GPT-OSS-20B model - Step 1."""
+        """Test supervised fine-tuning (SFT) of GPT-OSS-20B model - Step 1.
+        Returns:
+            Path to SFT output directory if successful, None otherwise
+        """
         model_name = self.model_path.split("/")[-1]
         output_dir = tmp_path / f"{model_name}-sft"
 
@@ -70,9 +73,16 @@ def gpt_oss_sft_training(self, tmp_path):
         # Verify SFT output directory exists
         assert output_dir.exists(), "SFT output directory should exist after training"
 
+        # Return the path to the SFT checkpoint
+        return output_dir
+
     def gpt_oss_qat_training_lora(self, tmp_path):
+        """Test QAT training with LoRA for GPT-OSS-120B model - Step 1.
+        Returns:
+            Path to QAT-LoRA output directory if successful, None otherwise
+        """
         model_name = self.model_path.split("/")[-1]
-        qat_output_dir = tmp_path / f"{model_name}-qat"
+        qat_output_dir = tmp_path / f"{model_name}-qat-lora"
         cmd_parts = [
             "python",
             "sft.py",
@@ -90,13 +100,22 @@ def gpt_oss_qat_training_lora(self, tmp_path):
 
         # Verify QAT output directory exists
         assert qat_output_dir.exists(), "QAT output directory should exist after training"
-
-    def gpt_oss_qat_training(self, tmp_path):
-        """Test quantization-aware training (QAT) with MXFP4 configuration - Step 2."""
+        # Return the path to the QAT-LoRA checkpoint
+        return qat_output_dir
+
+    def gpt_oss_qat_training(self, tmp_path, sft_dir=None):
+        """Test quantization-aware training (QAT) with MXFP4 configuration - Step 2.
+        Args:
+            tmp_path: Base path for outputs
+            sft_dir: Path to SFT checkpoint from Step 1. If None, creates a mock one for standalone testing
+        Returns:
+            Path to QAT output directory if successful, None otherwise
+        """
         # This test assumes test_gpt_oss_sft_training has been run first
         # Look for the SFT output directory from step 1
         model_name = self.model_path.split("/")[-1]
-        sft_dir = tmp_path / f"{model_name}-sft"
+        if sft_dir is None:
+            sft_dir = tmp_path / f"{model_name}-sft"
 
         # If SFT directory doesn't exist, create a mock one for standalone testing
         if not sft_dir.exists():
@@ -140,32 +159,24 @@ def gpt_oss_qat_training(self, tmp_path):
 
         # Verify QAT output directory exists
         assert qat_output_dir.exists(), "QAT output directory should exist after training"
-
-    def gpt_oss_mxfp4_conversion(self, tmp_path):
-        """Test conversion to MXFP4 weight-only format - Step 3."""
+        # Return the path to the QAT checkpoint
+        return qat_output_dir
+
+    def gpt_oss_mxfp4_conversion(self, tmp_path, qat_dir=None):
+        """Test conversion to MXFP4 weight-only format - Step 3.
+        Args:
+            tmp_path: Base path for outputs
+            qat_dir: Path to QAT checkpoint from Step 2. If None, uses the default path
+
+        Returns:
+            Path to MXFP4 conversion output directory if successful, None otherwise
+        """
         # This test assumes test_gpt_oss_qat_training has been run first
         # Look for the QAT output directory from step 2
         model_name = self.model_path.split("/")[-1]
-        qat_dir = tmp_path / f"{model_name}-qat"
-        # If QAT directory doesn't exist, create a mock one for standalone testing
-        if not qat_dir.exists():
-            qat_dir.mkdir()
-
-            # Create minimal config.json for the mock model
-            config_content = {
-                "model_type": "gpt_oss",
-                "hidden_size": 5120,
-                "num_attention_heads": 40,
-                "num_hidden_layers": 44,
-                "vocab_size": 100000,
-                "torch_dtype": "bfloat16",
-            }
-
-            import json
-
-            with open(qat_dir / "config.json", "w") as f:
-                json.dump(config_content, f)
+        if qat_dir is None:
+            qat_dir = tmp_path / f"{model_name}-qat"
 
         conversion_output_dir = tmp_path / f"{model_name}-qat-real-mxfp4"
 
@@ -183,12 +194,57 @@ def gpt_oss_mxfp4_conversion(self, tmp_path):
 
         # Verify conversion output directory exists
         assert conversion_output_dir.exists(), "MXFP4 conversion output directory should exist"
+        # Return the path to the MXFP4 checkpoint
+        return conversion_output_dir
+
+    def gpt_oss_mxfp4_conversion_lora(self, tmp_path, qat_lora_dir=None):
+        """Test conversion to MXFP4 weight-only format for LoRA model - Step 2.
+        Args:
+            tmp_path: Base path for outputs
+            qat_lora_dir: Path to QAT-LoRA checkpoint from Step 1. If None, uses default path
+        Returns:
+            Path to MXFP4 conversion output directory if successful, None otherwise
+        """
+        # This test assumes gpt_oss_qat_training_lora has been run first
+        # Look for the QAT-LoRA output directory from step 1
+        model_name = self.model_path.split("/")[-1]
+        if qat_lora_dir is None:
+            qat_lora_dir = tmp_path / f"{model_name}-qat-lora"
+
+        conversion_output_dir = tmp_path / f"{model_name}-qat-real-mxfp4"
 
-    def deploy_gpt_oss_trtllm(self, tmp_path):
-        """Deploy GPT-OSS model with TensorRT-LLM."""
+        # Command for MXFP4 conversion (Step 3)
+        cmd_parts = [
+            "python",
+            "convert_oai_mxfp4_weight_only.py",
+            "--lora_path",
+            str(qat_lora_dir),
+            "--base_path",
+            self.model_path,
+            "--output_path",
+            str(conversion_output_dir),
+        ]
+
+        run_example_command(cmd_parts, "gpt-oss")
+
+        # Verify conversion output directory exists
+        assert conversion_output_dir.exists(), "MXFP4 conversion output directory should exist"
+        # Return the path to the MXFP4 checkpoint
+        return conversion_output_dir
+
+    def deploy_gpt_oss_trtllm(self, tmp_path, model_path_override=None):
+        """Deploy GPT-OSS model with TensorRT-LLM.
+        Args:
+            tmp_path: Path for temporary files (benchmark data, reports)
+            model_path_override: Optional path to the model to deploy (e.g., MXFP4 checkpoint).
+                If None, uses self.model_path
+        """
         # Skip if tensorrt_llm is not available
         pytest.importorskip("tensorrt_llm")
 
+        # Use override path if provided, otherwise use original model path
+        deploy_model_path = model_path_override if model_path_override else self.model_path
+
         # Prepare benchmark data
         tensorrt_llm_workspace = "/app/tensorrt_llm"
         script = os.path.join(tensorrt_llm_workspace, "benchmarks", "cpp", "prepare_dataset.py")
@@ -214,6 +270,8 @@ def deploy_gpt_oss_trtllm(self, tmp_path):
             "trtllm-bench",
             "--model",
             self.model_path,
+            "--model_path",
+            str(deploy_model_path),
             "throughput",
             "--backend",
             "pytorch",
@@ -236,32 +294,58 @@ def deploy_gpt_oss_trtllm(self, tmp_path):
 )
 def test_gpt_oss_complete_pipeline(model_path, tmp_path):
     """Test the complete GPT-OSS optimization pipeline by executing all 3 steps in sequence."""
+    import pathlib
+
+    # Use current directory instead of tmp_path for checkpoints
+    current_dir = pathlib.Path.cwd()
 
     # Create GPTOSS instance with model path
     gpt_oss = GPTOSS(model_path)
-    model_name = model_path.split("/")[-1]
-    # Execute Step 1: SFT Training
     if model_path == "openai/gpt-oss-20b":
-        gpt_oss.gpt_oss_sft_training(tmp_path)
-        # Execute Step 2: QAT Training
-        gpt_oss.gpt_oss_qat_training(tmp_path)
-    elif model_path == "openai/gpt-oss-120b":
-        # Execute QAT Training with LoRA
-        gpt_oss.gpt_oss_qat_training_lora(tmp_path)
+        # Step 1: SFT Training
+        sft_checkpoint = gpt_oss.gpt_oss_sft_training(current_dir)
+        if not sft_checkpoint or not sft_checkpoint.exists():
+            print("Step 1 failed: SFT checkpoint not found, stopping pipeline.")
+            return
+        print(f"Step 1 completed: SFT checkpoint at {sft_checkpoint}")
+
+        # Step 2: QAT Training (depends on Step 1)
+        qat_checkpoint = gpt_oss.gpt_oss_qat_training(current_dir, sft_dir=sft_checkpoint)
+        if not qat_checkpoint or not qat_checkpoint.exists():
+            print("Step 2 failed: QAT checkpoint not found, stopping pipeline.")
+            return
+        print(f"Step 2 completed: QAT checkpoint at {qat_checkpoint}")
+
+        # Step 3: MXFP4 Conversion (depends on Step 2)
+        mxfp4_checkpoint = gpt_oss.gpt_oss_mxfp4_conversion(current_dir, qat_dir=qat_checkpoint)
+        if not mxfp4_checkpoint or not mxfp4_checkpoint.exists():
+            print("Step 3 failed: MXFP4 checkpoint not found, stopping pipeline.")
+            return
+        print(f"Step 3 completed: MXFP4 checkpoint at {mxfp4_checkpoint}")
+
+        # Step 4: Deploy with TensorRT-LLM (depends on Step 3)
+        print("Step 4: Running deployment with MXFP4 checkpoint...")
+        gpt_oss.deploy_gpt_oss_trtllm(current_dir, model_path_override=mxfp4_checkpoint)
+        print("Step 4 completed: Deployment successful")
 
-    # Execute Step 3: MXFP4 Conversion
-    gpt_oss.gpt_oss_mxfp4_conversion(tmp_path)
-
-    # Verify all output directories exist
-    qat_dir = tmp_path / f"{model_name}-qat"
-    conversion_dir = tmp_path / f"{model_name}-qat-real-mxfp4"
-
-    assert qat_dir.exists(), "QAT output directory should exist after Step 2"
-    assert conversion_dir.exists(), "MXFP4 conversion output directory should exist after Step 3"
-
-    print(f"Complete pipeline executed successfully for {model_path}!")
-    print(f"QAT output: {qat_dir}")
-    print(f"MXFP4 conversion output: {conversion_dir}")
-
-    # Deploy with TensorRT-LLM
-    gpt_oss.deploy_gpt_oss_trtllm(tmp_path)
+    elif model_path == "openai/gpt-oss-120b":
+        # Step 1: QAT Training with LoRA
+        qat_lora_checkpoint = gpt_oss.gpt_oss_qat_training_lora(current_dir)
+        if not qat_lora_checkpoint or not qat_lora_checkpoint.exists():
+            print("Step 1 failed: QAT-LoRA checkpoint not found, stopping pipeline.")
+            return
+        print(f"Step 1 completed: QAT-LoRA checkpoint at {qat_lora_checkpoint}")
+
+        # Step 2: MXFP4 Conversion for LoRA model (depends on Step 1)
+        mxfp4_checkpoint = gpt_oss.gpt_oss_mxfp4_conversion_lora(
+            current_dir, qat_lora_dir=qat_lora_checkpoint
+        )
+        if not mxfp4_checkpoint or not mxfp4_checkpoint.exists():
+            print("Step 2 failed: MXFP4 checkpoint not found, stopping pipeline.")
+            return
+        print(f"Step 2 completed: MXFP4 checkpoint at {mxfp4_checkpoint}")
+
+        # Step 3: Deploy with TensorRT-LLM (depends on Step 2)
+        print("Step 3: Running deployment with MXFP4 checkpoint...")
+        gpt_oss.deploy_gpt_oss_trtllm(current_dir, model_path_override=mxfp4_checkpoint)
+        print("Step 3 completed: Deployment successful")
diff --git a/tests/examples/llm_ptq/test_deploy.py b/tests/examples/llm_ptq/test_deploy.py
index 26ed313c9..3d3229e01 100644
--- a/tests/examples/llm_ptq/test_deploy.py
+++ b/tests/examples/llm_ptq/test_deploy.py
@@ -346,6 +346,12 @@ def test_phi(command):
             tensor_parallel_size=8,
             mini_sm=100,
         ),
+        *ModelDeployerList(
+            model_id="nvidia/Kimi-K2-Thinking-NVFP4",
+            backend=("trtllm", "vllm", "sglang"),
+            tensor_parallel_size=8,
+            mini_sm=100,
+        ),
     ],
     ids=idfn,
 )