@@ -154,20 +154,23 @@ def test_generate_with_return_logits(
154154 else :
155155 assert output .context_logits is None
156156
157- if gather_generation_logits :
158- gen_logits = output .outputs [0 ].generation_logits
159- assert gen_logits is not None
160- assert gen_logits .ndim == 2
161- assert gen_logits .shape [0 ] == sampling_params .max_tokens
162- assert torch .argmax (gen_logits ,
163- dim = 1 ).tolist () == output .outputs [0 ].token_ids
164- else :
165- assert output .outputs [0 ].generation_logits is None
157+ for sequence in output .outputs :
158+ assert sequence .length == sampling_params .max_tokens
159+
160+ if gather_generation_logits :
161+ gen_logits = sequence .generation_logits
162+ assert gen_logits is not None
163+ assert gen_logits .ndim == 2
164+ assert gen_logits .shape [0 ] == sampling_params .max_tokens
165+ assert torch .argmax (gen_logits ,
166+ dim = 1 ).tolist () == sequence .token_ids
167+ else :
168+ assert sequence .generation_logits is None
166169
167- if return_log_probs :
168- assert len (output . outputs [ 0 ] .logprobs ) == sampling_params .max_tokens
169- else :
170- assert len (output . outputs [ 0 ] .logprobs ) == 0
170+ if return_log_probs :
171+ assert len (sequence .logprobs ) == sampling_params .max_tokens
172+ else :
173+ assert len (sequence .logprobs ) == 0
171174
172175
173176@force_ampere # Save H100 resource
@@ -218,22 +221,24 @@ def test_generate_async_with_return_logits(
218221 else :
219222 assert output .context_logits is None
220223
221- if gather_generation_logits :
222- gen_logits = output .outputs [0 ].generation_logits
223- assert gen_logits is not None
224- assert gen_logits .ndim == 2
225- assert gen_logits .shape [0 ] == 1
226- try :
227- assert torch .argmax (
228- gen_logits ,
229- dim = 1 ).tolist ()[0 ] == output .outputs [0 ].token_ids [- 1 ]
230- except AssertionError :
231- # FIXME: Remove xfail once the bug is fixed
232- pytest .xfail ("Known bug: https://nvbugs/5573238" )
233- else :
234- assert output .outputs [0 ].generation_logits is None
224+ for sequence in output .outputs :
225+ assert sequence .length == idx + 1
226+
227+ if gather_generation_logits :
228+ gen_logits = sequence .generation_logits
229+ assert gen_logits is not None
230+ assert gen_logits .ndim == 2
231+ assert gen_logits .shape [0 ] == 1
232+ try :
233+ assert torch .argmax (
234+ gen_logits , dim = 1 ).tolist ()[0 ] == sequence .token_ids [- 1 ]
235+ except AssertionError :
236+ # FIXME: Remove xfail once the bug is fixed
237+ pytest .xfail ("Known bug: https://nvbugs/5573238" )
238+ else :
239+ assert sequence .generation_logits is None
235240
236- if return_log_probs :
237- assert len (output . outputs [ 0 ] .logprobs ) == idx + 1
238- else :
239- assert len (output . outputs [ 0 ] .logprobs ) == 0
241+ if return_log_probs :
242+ assert len (sequence .logprobs ) == idx + 1
243+ else :
244+ assert len (sequence .logprobs ) == 0
0 commit comments