diff --git a/python/tests/test_transformers.py b/python/tests/test_transformers.py index 3c35445fa..db00ca228 100644 --- a/python/tests/test_transformers.py +++ b/python/tests/test_transformers.py @@ -1020,6 +1020,6 @@ def test_transformers_wav2vec2( predicted_ids = torch.argmax(logits, dim=-1) transcription = processor.decode(predicted_ids, output_word_offsets=True) - transcription = transcription[0].replace(processor.tokenizer.unk_token, "") + transcriptions = transcription[0].replace(processor.tokenizer.unk_token, "") - assert transcription == expected_transcription[0] + assert transcriptions == expected_transcription[0]