From 3ab90a400b76a6fb96ef19af4c463a7a1f7c7c76 Mon Sep 17 00:00:00 2001
From: Shashank Kapoor
Date: Sun, 9 Mar 2025 14:29:05 -0700
Subject: [PATCH 1/3] Adding logits task on Whisper Backbone

---
 keras_hub/src/models/whisper/whisper_backbone.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/keras_hub/src/models/whisper/whisper_backbone.py b/keras_hub/src/models/whisper/whisper_backbone.py
index 1bbd5e70af..3f04fecb45 100644
--- a/keras_hub/src/models/whisper/whisper_backbone.py
+++ b/keras_hub/src/models/whisper/whisper_backbone.py
@@ -286,3 +286,13 @@ def get_config(self):
             }
         )
         return config
+
+    def logits(self, *args, **kwargs):
+        result = self(*args, **kwargs)
+        token_embedding = None
+        for embedding_type in self.decoder_embeddings:
+            if "token_embedding" in embedding_type.path:
+                token_embedding = embedding_type
+        return keras.ops.matmul(
+            result["decoder_sequence_output"], keras.ops.transpose(token_embedding)
+        )

From d291bd1687ecc3e12a4cbe390f2dee2a02e7d9a6 Mon Sep 17 00:00:00 2001
From: Shashank Kapoor
Date: Sun, 9 Mar 2025 14:37:25 -0700
Subject: [PATCH 2/3] Adding logits task on Whisper Backbone

---
 keras_hub/src/models/whisper/whisper_backbone.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/keras_hub/src/models/whisper/whisper_backbone.py b/keras_hub/src/models/whisper/whisper_backbone.py
index 3f04fecb45..0a1f3d86ba 100644
--- a/keras_hub/src/models/whisper/whisper_backbone.py
+++ b/keras_hub/src/models/whisper/whisper_backbone.py
@@ -290,7 +290,7 @@ def get_config(self):
     def logits(self, *args, **kwargs):
         result = self(*args, **kwargs)
         token_embedding = None
-        for embedding_type in self.decoder_embeddings:
+        for embedding_type in self.decoder_embeddings.weights:
             if "token_embedding" in embedding_type.path:
                 token_embedding = embedding_type
         return keras.ops.matmul(

From 8ac4b959cf74544ee8ecc782857a2a56049a780d Mon Sep 17 00:00:00 2001
From: Shashank Kapoor
Date: Sun, 9 Mar 2025 18:47:16 -0700
Subject: [PATCH 3/3] Adding Logits function to WhisperBackbone

---
 .../models/whisper/whisper_backbone_test.py   | 21 ++++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/keras_hub/src/models/whisper/whisper_backbone_test.py b/keras_hub/src/models/whisper/whisper_backbone_test.py
index c80e0e3c34..e6a3dcc593 100644
--- a/keras_hub/src/models/whisper/whisper_backbone_test.py
+++ b/keras_hub/src/models/whisper/whisper_backbone_test.py
@@ -70,9 +70,7 @@ def test_smallest_preset(self):
                 "decoder_token_ids": ops.array(
                     [[50257, 50362, 464, 2068, 7586, 21831, 13, 50256, 50256]]
                 ),
-                "decoder_padding_mask": ops.array(
-                    [[1, 1, 1, 1, 1, 1, 1, 1, 0]]
-                ),
+                "decoder_padding_mask": ops.array([[1, 1, 1, 1, 1, 1, 1, 1, 0]]),
             },
             expected_output_shape={
                 "encoder_sequence_output": (1, 1500, 384),
@@ -79,7 +77,24 @@ def test_smallest_preset(self):
                 "decoder_sequence_output": (1, 9, 384),
             },
         )
 
+    @pytest.mark.large
+    def test_logits(self):
+        backbone_cls = WhisperBackbone.from_preset("whisper_tiny_en")
+        input_data = {
+            "encoder_features": ops.ones((1, 3000, 80)),
+            "decoder_token_ids": ops.array(
+                [[50257, 50362, 464, 2068, 7586, 21831, 13, 50256, 50256]]
+            ),
+            "decoder_padding_mask": ops.array([[1, 1, 1, 1, 1, 1, 1, 1, 1]]),
+        }
+        logits = backbone_cls.logits(input_data)
+        self.assertEqual(logits.shape, (1, 9, 51864))
+        self.assertAllEqual(
+            ops.argmax(ops.squeeze(logits, axis=0), axis=-1),
+            [50361, 357, 50256, 395, 263, 50256, 50256, 50256, 50256],
+        )
+
     @pytest.mark.extra_large
     def test_all_presets(self):
         for preset in WhisperBackbone.presets:
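
Note on the series: the `logits` method added above is a weight-tied output projection. Rather than introducing a separate output dense layer, it reuses the decoder's token embedding matrix, multiplying the decoder hidden states by its transpose to score every vocabulary token at every position. A minimal standalone sketch of that computation, using placeholder all-ones tensors (not real model weights) with the whisper_tiny_en shapes from the test above:

    from keras import ops

    # whisper_tiny_en dimensions taken from the test: 51864-token
    # vocabulary, 384-wide decoder hidden states, 9 decoded positions.
    batch, seq_len, hidden_dim, vocab_size = 1, 9, 384, 51864

    # Placeholder stand-ins for the backbone's decoder output and its
    # tied token embedding matrix of shape (vocab_size, hidden_dim).
    decoder_sequence_output = ops.ones((batch, seq_len, hidden_dim))
    token_embedding = ops.ones((vocab_size, hidden_dim))

    # Same projection as the new WhisperBackbone.logits:
    # (batch, seq, hidden) @ (hidden, vocab) -> (batch, seq, vocab).
    logits = ops.matmul(decoder_sequence_output, ops.transpose(token_embedding))
    assert tuple(logits.shape) == (batch, seq_len, vocab_size)

Greedy decoding then falls out of an argmax over the last axis, which is exactly what test_logits asserts against known token ids for the whisper_tiny_en preset.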