
Commit af7cf97

First draft

Signed-off-by: Michal Guzek <[email protected]>
Parent: 0649b77

File tree: 3 files changed (+100, -7 lines)


tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 16 additions & 2 deletions
@@ -814,12 +814,26 @@ def _meet_max_token_stop_criteria(self, request: LlmRequest):
                 >= self.max_seq_len)
 
     @staticmethod
-    def _meet_stop_token_criteria(request: LlmRequest):
+    def _meet_stop_token_criteria(request: LlmRequest, new_token: int):
         if request.py_stop_words_list:
             assert isinstance(
                 request.py_stop_words_list,
                 list), "request.py_stop_words_list should be a list"
             stop_words_list, prefix_sum = request.py_stop_words_list
+
+            # Determine max stop word length to decide optimization path
+            max_stop_word_length = prefix_sum[0] if prefix_sum else 0
+            for i in range(1, len(prefix_sum)):
+                word_length = prefix_sum[i] - prefix_sum[i - 1]
+                max_stop_word_length = max(max_stop_word_length, word_length)
+
+            # Fast path: all stop words are single tokens
+            if max_stop_word_length == 1:
+                if new_token in stop_words_list:
+                    return True
+                return False
+
+            # Slow path: at least one multi-token stop word exists
             tokens = request.get_tokens(0)
             offset = 0
             for i, offset_end in enumerate(prefix_sum):
@@ -844,7 +858,7 @@ def _handle_stop_criteria(self, request: LlmRequest,
             request.finish_by(FinishReason.LENGTH, self.BEAM)
             return True
 
-        if self._meet_stop_token_criteria(request):
+        if self._meet_stop_token_criteria(request, new_token):
             request.finish_by(FinishReason.STOP_WORDS, self.BEAM)
             return True
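
For orientation, here is a standalone sketch of the logic this hunk introduces: scan the prefix-sum to find the longest stop word, take a plain membership check when every stop word is a single token, and fall back to tail matching otherwise. It assumes py_stop_words_list is stored as a flat token list plus a prefix-sum of word end offsets, as the surrounding code implies; the function name and example values below are illustrative, not part of the commit.

# Standalone sketch (not the TensorRT-LLM implementation): mirrors the
# fast/slow path split added above. Assumes the stop words are stored as
# (flat_token_list, prefix_sum), where prefix_sum[i] is the end offset of
# the i-th stop word inside the flat list.
def meets_stop_criteria(generated_tokens, new_token, stop_words_list, prefix_sum):
    if not prefix_sum:
        return False

    # The longest stop word decides which path to take.
    max_len = prefix_sum[0]
    for i in range(1, len(prefix_sum)):
        max_len = max(max_len, prefix_sum[i] - prefix_sum[i - 1])

    # Fast path: every stop word is a single token, so checking the newly
    # sampled token against the flat list is enough.
    if max_len == 1:
        return new_token in stop_words_list

    # Slow path: compare the tail of the generated sequence (which already
    # includes new_token) against each stop word.
    offset = 0
    for offset_end in prefix_sum:
        word = stop_words_list[offset:offset_end]
        offset = offset_end
        if len(word) <= len(generated_tokens) and generated_tokens[-len(word):] == word:
            return True
    return False

# Single-token stop words [7] and [9]: fast path, membership check only.
print(meets_stop_criteria([1, 2, 7], 7, [7, 9], [1, 2]))     # True
# Multi-token stop word [5, 6]: slow path, matched at the sequence tail.
print(meets_stop_criteria([1, 5, 6], 6, [4, 5, 6], [1, 3]))  # True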

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ l0_a10:
   # ------------- PyTorch tests ---------------
   - unittest/_torch/modeling/test_modeling_mistral.py
   - unittest/_torch/modeling/test_modeling_pixtral.py
+  - unittest/_torch/test_trtllm_sampler.py
   # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
   # test list either).
   - unittest/_torch/models/checkpoints/hf/test_weight_loader.py
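
To run the newly listed tests locally, a pytest invocation along the following lines should work; the file path is taken from the test file changed below, and it assumes the command is run from the repository root with pytest and TensorRT-LLM installed.

# Hypothetical local run of the newly added sampler tests; assumes pytest is
# installed and the working directory is the repository root.
import sys
import pytest

ret = pytest.main([
    "tests/unittest/_torch/sampler/test_trtllm_sampler.py",
    "-k", "stop",  # selects the two new stop-criteria tests by name
    "-q",
])
sys.exit(ret)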

tests/unittest/_torch/sampler/test_trtllm_sampler.py

Lines changed: 83 additions & 5 deletions
@@ -12,8 +12,10 @@ def model_path():
     return llm_models_root() / "llama-models-v2/TinyLlama-1.1B-Chat-v1.0"
 
 
-def create_llm(model_dir):
-    """Create LLM with specific overlap scheduler setting"""
+def _create_llm_base(model_dir, enable_trtllm_sampler):
+    """Base LLM creation with configurable sampler."""
+    pytorch_config = dict(enable_trtllm_sampler=enable_trtllm_sampler)
+
     trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False)
 
     return LLM(
@@ -22,10 +24,20 @@ def create_llm(model_dir):
         trust_remote_code=True,
         enable_chunked_prefill=True,
         cuda_graph_config=CudaGraphConfig(),
+        **pytorch_config,
         kv_cache_config=trt_kv_cache_config,
-        max_num_tokens=
-        128  # Only one request longer than max_num_tokens is required to test chunked prefill
-    )
+        max_num_tokens=128
+    )  # Only one request longer than max_num_tokens is required to test chunked prefill
+
+
+def create_llm(model_dir):
+    """Create LLM with specific overlap scheduler setting"""
+    return _create_llm_base(model_dir, enable_trtllm_sampler=True)
+
+
+def create_llm_with_torch_sampler(model_dir):
+    """Create LLM with TorchSampler."""
+    return _create_llm_base(model_dir, enable_trtllm_sampler=False)
 
 
 @pytest.mark.high_cuda_memory
@@ -67,3 +79,69 @@ def test_trtllm_sampler(model_path):
     # Verify outputs are consistent
     for text, expected in zip(texts, expected_outputs):
         assert similar(text, expected), f"text: {text}, expected: {expected}"
+
+
+@pytest.mark.high_cuda_memory
+def test_trtllm_sampler_with_stop_token_ids(model_path):
+    """Test sampler with stop_token_ids (fast path optimization)."""
+
+    llm = create_llm_with_torch_sampler(model_path)
+    tokenizer = llm.tokenizer
+
+    prompt = "The capital of France is"
+    target_sentence = "The capital of France is Paris"
+
+    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
+    target_tokens = tokenizer.encode(target_sentence, add_special_tokens=False)
+
+    # Use the first token after the prompt as the stop token
+    assert len(target_tokens) > len(
+        prompt_tokens), "Target must be longer than prompt"
+    stop_token_id = target_tokens[len(prompt_tokens)]
+
+    sampling_config = SamplingParams(max_tokens=100,
+                                     n=1,
+                                     stop_token_ids=[stop_token_id],
+                                     temperature=0.0)
+
+    outputs = llm.generate([prompt], sampling_params=sampling_config)
+    text = outputs[0].outputs[0].text
+
+    output_tokens = tokenizer.encode(text, add_special_tokens=False)
+
+    llm.shutdown()
+    assert stop_token_id not in output_tokens, f"Output should not contain stop token {stop_token_id}"
+    assert len(output_tokens
+               ) < 10, "Should stop very early with first-token stop_token_id"
+
+
+@pytest.mark.high_cuda_memory
+def test_torch_sampler_with_multi_token_stop_words(model_path):
+    """Test TorchSampler with multi-token stop words (slow path)."""
+
+    llm = create_llm_with_torch_sampler(model_path)
+    tokenizer = llm.tokenizer
+
+    prompt = "The capital of France is"
+
+    # Use a string that will tokenize to multiple tokens
+    stop_string = "\n\n"
+    stop_tokens = tokenizer.encode(stop_string, add_special_tokens=False)
+
+    assert len(
+        stop_tokens
+    ) > 1, f"Stop string should be multi-token, got {len(stop_tokens)} tokens"
+
+    sampling_config = SamplingParams(
+        max_tokens=100,
+        n=1,
+        stop=[stop_string],  # Use 'stop' parameter for multi-token
+        temperature=0.0)
+
+    outputs = llm.generate([prompt], sampling_params=sampling_config)
+    text = outputs[0].outputs[0].text
+
+    llm.shutdown()
+
+    assert len(text) > 0, "Should generate some text"
+    assert stop_string not in text, f"Stop string '{repr(stop_string)}' should not appear in the output"
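
Taken together, the two new tests cover both branches added in sampler.py: a single-token stop criterion exercises the fast membership check, while a multi-token stop string still goes through the tail-matching slow path. A minimal illustration of the two configurations, assuming the usual from tensorrt_llm import SamplingParams import (the parameter names are the ones used in the tests above):

# Illustrative only: which sampler path each stop configuration exercises.
from tensorrt_llm import SamplingParams  # import path assumed

# Every stop "word" is a single token -> fast path: the newly sampled token
# is checked for membership in the flat stop-token list.
fast_path = SamplingParams(max_tokens=100,
                           stop_token_ids=[13],  # hypothetical token id
                           temperature=0.0)

# A stop string such as "\n\n" usually tokenizes to several tokens -> slow
# path: the tail of the generated sequence is compared against each stop word.
slow_path = SamplingParams(max_tokens=100,
                           stop=["\n\n"],
                           temperature=0.0)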
