
Commit 10a1cf0

Improve test_return_logits
- Add loop over sequences in test_generate_with_return_logits and test_generate_async_with_return_logits.
- Add assertion on sequence length in test_generate_with_return_logits and test_generate_async_with_return_logits.

Signed-off-by: Robin Kobus <[email protected]>
1 parent cd8f8a8 commit 10a1cf0
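
For orientation before the diff: the change replaces assertions on output.outputs[0] with a loop over every returned sequence. A minimal sketch of the resulting check for the synchronous test, using only names that appear in the diff below; the output and sampling_params objects are assumed to come from the existing test fixtures, which are not shown here.

    # Sketch of the per-sequence checks this commit introduces; names mirror
    # the diff below, fixtures (output, sampling_params, flags) are assumed.
    def check_sequences(output, sampling_params, gather_generation_logits,
                        return_log_probs):
        for sequence in output.outputs:
            # Every returned sequence should contain exactly max_tokens tokens.
            assert sequence.length == sampling_params.max_tokens
            if gather_generation_logits:
                # Generation logits: a 2-D tensor with one row per generated token.
                assert sequence.generation_logits.shape[0] == sampling_params.max_tokens
            else:
                assert sequence.generation_logits is None
            if return_log_probs:
                # One logprob entry per generated token.
                assert len(sequence.logprobs) == sampling_params.max_tokens
            else:
                assert len(sequence.logprobs) == 0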

File tree: 1 file changed (+36, −31 lines)

tests/unittest/_torch/sampler/test_return_logits.py

Lines changed: 36 additions & 31 deletions
@@ -154,20 +154,23 @@ def test_generate_with_return_logits(
     else:
         assert output.context_logits is None

-    if gather_generation_logits:
-        gen_logits = output.outputs[0].generation_logits
-        assert gen_logits is not None
-        assert gen_logits.ndim == 2
-        assert gen_logits.shape[0] == sampling_params.max_tokens
-        assert torch.argmax(gen_logits,
-                            dim=1).tolist() == output.outputs[0].token_ids
-    else:
-        assert output.outputs[0].generation_logits is None
+    for sequence in output.outputs:
+        assert sequence.length == sampling_params.max_tokens
+
+        if gather_generation_logits:
+            gen_logits = sequence.generation_logits
+            assert gen_logits is not None
+            assert gen_logits.ndim == 2
+            assert gen_logits.shape[0] == sampling_params.max_tokens
+            assert torch.argmax(gen_logits,
+                                dim=1).tolist() == sequence.token_ids
+        else:
+            assert sequence.generation_logits is None

-    if return_log_probs:
-        assert len(output.outputs[0].logprobs) == sampling_params.max_tokens
-    else:
-        assert len(output.outputs[0].logprobs) == 0
+        if return_log_probs:
+            assert len(sequence.logprobs) == sampling_params.max_tokens
+        else:
+            assert len(sequence.logprobs) == 0


 @force_ampere  # Save H100 resource
@@ -218,22 +221,24 @@ def test_generate_async_with_return_logits(
     else:
         assert output.context_logits is None

-    if gather_generation_logits:
-        gen_logits = output.outputs[0].generation_logits
-        assert gen_logits is not None
-        assert gen_logits.ndim == 2
-        assert gen_logits.shape[0] == 1
-        try:
-            assert torch.argmax(
-                gen_logits,
-                dim=1).tolist()[0] == output.outputs[0].token_ids[-1]
-        except AssertionError:
-            # FIXME: Remove xfail once the bug is fixed
-            pytest.xfail("Known bug: https://nvbugs/5573238")
-    else:
-        assert output.outputs[0].generation_logits is None
+    for sequence in output.outputs:
+        assert sequence.length == idx + 1
+
+        if gather_generation_logits:
+            gen_logits = sequence.generation_logits
+            assert gen_logits is not None
+            assert gen_logits.ndim == 2
+            assert gen_logits.shape[0] == 1
+            try:
+                assert torch.argmax(
+                    gen_logits, dim=1).tolist()[0] == sequence.token_ids[-1]
+            except AssertionError:
+                # FIXME: Remove xfail once the bug is fixed
+                pytest.xfail("Known bug: https://nvbugs/5573238")
+        else:
+            assert sequence.generation_logits is None

-    if return_log_probs:
-        assert len(output.outputs[0].logprobs) == idx + 1
-    else:
-        assert len(output.outputs[0].logprobs) == 0
+        if return_log_probs:
+            assert len(sequence.logprobs) == idx + 1
+        else:
+            assert len(sequence.logprobs) == 0
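
The async variant asserts generation_logits.shape[0] == 1 and len(sequence.logprobs) == idx + 1, which matches streaming semantics: each streamed step carries logits for only the newly generated token, while the logprobs list grows by one entry per step. Below is a minimal consumption sketch under that assumption; the llm.generate_async(..., streaming=True) call and its exact signature are assumptions taken from the test's context, not something this diff confirms.

    # Hypothetical streaming consumer mirroring the async assertions above.
    # `generate_async` and `streaming=True` are assumed API names; the diff
    # only shows the per-step checks, not the call that produces `output`.
    async def consume(llm, prompt, sampling_params):
        idx = 0
        async for output in llm.generate_async(prompt, sampling_params,
                                               streaming=True):
            for sequence in output.outputs:
                # Each streamed step carries logits for exactly one new token ...
                assert sequence.generation_logits.shape[0] == 1
                # ... while the logprobs list grows by one entry per step.
                assert len(sequence.logprobs) == idx + 1
            idx += 1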
