@@ -12,8 +12,10 @@ def model_path():
     return llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
 
 
-def create_llm(model_dir):
-    """Create LLM with specific overlap scheduler setting"""
+def _create_llm_base(model_dir, enable_trtllm_sampler):
+    """Base LLM creation with configurable sampler."""
+    pytorch_config = dict(enable_trtllm_sampler=enable_trtllm_sampler)
+
     trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False)
 
     return LLM(
@@ -22,10 +24,20 @@ def create_llm(model_dir):
         trust_remote_code=True,
         enable_chunked_prefill=True,
         cuda_graph_config=CudaGraphConfig(),
+        **pytorch_config,
         kv_cache_config=trt_kv_cache_config,
-        max_num_tokens=
-        128  # Only one request longer than max_num_tokens is required to test chunked prefill
-    )
+        max_num_tokens=128
+    )  # Only one request longer than max_num_tokens is required to test chunked prefill
+
+
+def create_llm(model_dir):
+    """Create LLM with the TRTLLM sampler enabled."""
+    return _create_llm_base(model_dir, enable_trtllm_sampler=True)
+
+
+def create_llm_with_torch_sampler(model_dir):
+    """Create LLM with TorchSampler."""
+    return _create_llm_base(model_dir, enable_trtllm_sampler=False)
 
 
 @pytest.mark.high_cuda_memory
@@ -67,3 +79,69 @@ def test_trtllm_sampler(model_path):
     # Verify outputs are consistent
     for text, expected in zip(texts, expected_outputs):
         assert similar(text, expected), f"text: {text}, expected: {expected}"
+
+
+@pytest.mark.high_cuda_memory
+def test_trtllm_sampler_with_stop_token_ids(model_path):
+    """Test sampler with stop_token_ids (fast path optimization)."""
+
+    llm = create_llm_with_torch_sampler(model_path)
+    tokenizer = llm.tokenizer
+
+    prompt = "The capital of France is"
+    target_sentence = "The capital of France is Paris"
+
+    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
+    target_tokens = tokenizer.encode(target_sentence, add_special_tokens=False)
+
+    # Use the first token after the prompt as the stop token
+    assert len(target_tokens) > len(
+        prompt_tokens), "Target must be longer than prompt"
+    stop_token_id = target_tokens[len(prompt_tokens)]
+
+    sampling_config = SamplingParams(max_tokens=100,
+                                     n=1,
+                                     stop_token_ids=[stop_token_id],
+                                     temperature=0.0)
+
+    outputs = llm.generate([prompt], sampling_params=sampling_config)
+    text = outputs[0].outputs[0].text
+
+    output_tokens = tokenizer.encode(text, add_special_tokens=False)
+
+    llm.shutdown()
+    assert stop_token_id not in output_tokens, f"Output should not contain stop token {stop_token_id}"
+    assert len(output_tokens
+               ) < 10, "Should stop very early with first-token stop_token_id"
+
+
+@pytest.mark.high_cuda_memory
+def test_torch_sampler_with_multi_token_stop_words(model_path):
+    """Test TorchSampler with multi-token stop words (slow path)."""
+
+    llm = create_llm_with_torch_sampler(model_path)
+    tokenizer = llm.tokenizer
+
+    prompt = "The capital of France is"
+
+    # Use a string that will tokenize to multiple tokens
+    stop_string = "\n\n"
+    stop_tokens = tokenizer.encode(stop_string, add_special_tokens=False)
+
+    assert len(
+        stop_tokens
+    ) > 1, f"Stop string should be multi-token, got {len(stop_tokens)} tokens"
+
+    sampling_config = SamplingParams(
+        max_tokens=100,
+        n=1,
+        stop=[stop_string],  # Use 'stop' parameter for multi-token
+        temperature=0.0)
+
+    outputs = llm.generate([prompt], sampling_params=sampling_config)
+    text = outputs[0].outputs[0].text
+
+    llm.shutdown()
+
+    assert len(text) > 0, "Should generate some text"
+    assert stop_string not in text, f"Stop string '{repr(stop_string)}' should not appear in the output"