-
Couldn't load subscription status.
- Fork 17
Add ListMLE Loss #130
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
LakshmiKalaKadali
wants to merge
11
commits into
keras-team:main
Choose a base branch
from
LakshmiKalaKadali:listmleloss
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add ListMLE Loss #130
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
730bc21
Add STU layer
LakshmiKalaKadali 1906f2d
Update keras_rs/src/layers/stu.py
LakshmiKalaKadali d0add4f
Update keras_rs/src/layers/common.py
LakshmiKalaKadali eb98ae1
Update keras_rs/src/layers/jagged_tensors.py
LakshmiKalaKadali 0347b51
Update keras_rs/src/layers/hstu_mha_attention.py
LakshmiKalaKadali 570df2b
Update keras_rs/src/layers/stu.py
LakshmiKalaKadali 6f3dcef
Update keras_rs/src/layers/stu.py
LakshmiKalaKadali c80e598
Update keras_rs/src/layers/jagged_tensors.py
LakshmiKalaKadali 8073dcb
Update keras_rs/src/layers/hstu_compute_output.py
LakshmiKalaKadali 8cb2d98
Add list mle loss
LakshmiKalaKadali b643778
Debug statements added
LakshmiKalaKadali File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,247 @@ | ||
| from typing import Any | ||
|
|
||
| import keras | ||
| from keras import ops | ||
|
|
||
| from keras_rs.src import types | ||
| from keras_rs.src.api_export import keras_rs_export | ||
| from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores | ||
| from keras_rs.src.metrics.utils import standardize_call_inputs_ranks | ||
|
|
||
|
|
||
| @keras_rs_export("keras_rs.losses.ListMLELoss") | ||
| class ListMLELoss(keras.losses.Loss): | ||
| """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking. | ||
|
|
||
| ListMLE loss is a listwise ranking loss that maximizes the likelihood of | ||
| the ground truth ranking. It works by: | ||
| 1. Sorting items by their relevance scores (labels) | ||
| 2. Computing the probability of observing this ranking given the | ||
| predicted scores | ||
| 3. Maximizing this likelihood (minimizing negative log-likelihood) | ||
|
|
||
| The loss is computed as the negative log-likelihood of the ground truth | ||
| ranking given the predicted scores: | ||
|
|
||
| ``` | ||
| loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i))) | ||
| ``` | ||
|
|
||
| where s_i is the predicted score for item i in the sorted order. | ||
|
|
||
| Args: | ||
| temperature: Temperature parameter for scaling logits. Higher values | ||
| make the probability distribution more uniform. Defaults to 1.0. | ||
| reduction: Type of reduction to apply to the loss. In almost all cases | ||
| this should be `"sum_over_batch_size"`. Supported options are | ||
| `"sum"`, `"sum_over_batch_size"`, `"mean"`, | ||
| `"mean_with_sample_weight"` or `None`. Defaults to | ||
| `"sum_over_batch_size"`. | ||
| name: Optional name for the loss instance. | ||
| dtype: The dtype of the loss's computations. Defaults to `None`. | ||
|
|
||
| Examples: | ||
| ```python | ||
| # Basic usage | ||
| loss_fn = ListMLELoss() | ||
|
|
||
| # With temperature scaling | ||
| loss_fn = ListMLELoss(temperature=0.5) | ||
|
|
||
| # Example with synthetic data | ||
| y_true = [[3, 2, 1, 0]] # Relevance scores | ||
| y_pred = [[0.8, 0.6, 0.4, 0.2]] # Predicted scores | ||
| loss = loss_fn(y_true, y_pred) | ||
| ``` | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, temperature: float = 1.0, debug: bool = True, **kwargs: Any | ||
| ) -> None: | ||
| super().__init__(**kwargs) | ||
|
|
||
| if temperature <= 0.0: | ||
| raise ValueError( | ||
| f"`temperature` should be a positive float. Received: " | ||
| f"`temperature` = {temperature}." | ||
| ) | ||
|
|
||
| self.temperature = temperature | ||
| self._epsilon = 1e-10 | ||
| self.debug = debug | ||
|
|
||
| def compute_unreduced_loss( | ||
| self, | ||
| labels: types.Tensor, | ||
| logits: types.Tensor, | ||
| mask: types.Tensor | None = None, | ||
| ) -> tuple[types.Tensor, types.Tensor]: | ||
| """Compute the unreduced ListMLE loss. | ||
|
|
||
| Args: | ||
| labels: Ground truth relevance scores of | ||
| shape [batch_size,list_size]. | ||
| logits: Predicted scores of shape [batch_size, list_size]. | ||
| mask: Optional mask of shape [batch_size, list_size]. | ||
|
|
||
| Returns: | ||
| Tuple of (losses, weights) where losses has shape [batch_size, 1] | ||
| and weights has the same shape. | ||
| """ | ||
|
|
||
| valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype)) | ||
|
|
||
| if mask is not None: | ||
| valid_mask = ops.logical_and( | ||
| valid_mask, ops.cast(mask, dtype="bool") | ||
| ) | ||
|
|
||
| num_valid_items = ops.sum( | ||
| ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True | ||
| ) | ||
|
|
||
| batch_has_valid_items = ops.greater(num_valid_items, 0.0) | ||
|
|
||
| labels_for_sorting = ops.where( | ||
| valid_mask, labels, ops.full_like(labels, -1e9) | ||
| ) | ||
| logits_masked = ops.where( | ||
| valid_mask, logits, ops.full_like(logits, -1e9) | ||
| ) | ||
| sorted_logits, sorted_valid_mask = sort_by_scores( | ||
| tensors_to_sort=[logits_masked, valid_mask], | ||
| scores=labels_for_sorting, | ||
| mask=None, | ||
| shuffle_ties=False, | ||
| seed=None, | ||
| ) | ||
| sorted_logits = ops.divide( | ||
| sorted_logits, ops.cast(self.temperature, dtype=sorted_logits.dtype) | ||
| ) | ||
|
|
||
| valid_logits_for_max = ops.where( | ||
| sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9) | ||
| ) | ||
| raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True) | ||
| raw_max = ops.where( | ||
| batch_has_valid_items, raw_max, ops.zeros_like(raw_max) | ||
| ) | ||
| sorted_logits = ops.subtract(sorted_logits, raw_max) | ||
|
|
||
| # Set invalid positions to very negative BEFORE exp | ||
| sorted_logits = ops.where( | ||
| sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9) | ||
| ) | ||
| exp_logits = ops.exp(sorted_logits) | ||
|
|
||
| reversed_exp = ops.flip(exp_logits, axis=1) | ||
| reversed_cumsum = ops.cumsum(reversed_exp, axis=1) | ||
| cumsum_from_right = ops.flip(reversed_cumsum, axis=1) | ||
|
|
||
| log_normalizers = ops.log(cumsum_from_right + self._epsilon) | ||
| log_probs = ops.subtract(sorted_logits, log_normalizers) | ||
|
|
||
| log_probs = ops.where( | ||
| sorted_valid_mask, log_probs, ops.zeros_like(log_probs) | ||
| ) | ||
|
|
||
| negative_log_likelihood = ops.negative( | ||
| ops.sum(log_probs, axis=1, keepdims=True) | ||
| ) | ||
|
|
||
| negative_log_likelihood = ops.where( | ||
| batch_has_valid_items, | ||
| negative_log_likelihood, | ||
| ops.zeros_like(negative_log_likelihood), | ||
| ) | ||
|
|
||
| weights = ops.ones_like(negative_log_likelihood) | ||
|
|
||
| # Debug print statements for all intermediate values | ||
| if self.debug: | ||
| import sys | ||
|
|
||
| def safe_print(label, value): | ||
| try: | ||
| # For TensorFlow, only print numpy if in eager mode | ||
| if hasattr(value, "numpy"): | ||
| print(label, value.numpy(), file=sys.stderr) | ||
| else: | ||
| print( | ||
| label, ops.convert_to_numpy(value), file=sys.stderr | ||
| ) | ||
| except Exception as e: | ||
| print(label, f"<error printing: {e}>", file=sys.stderr) | ||
|
|
||
| safe_print("valid_mask", valid_mask) | ||
| safe_print("num_valid_items", num_valid_items) | ||
| safe_print("batch_has_valid_items", batch_has_valid_items) | ||
| safe_print("labels_for_sorting", labels_for_sorting) | ||
| safe_print("logits_masked", logits_masked) | ||
| safe_print("sorted_logits", sorted_logits) | ||
| safe_print("sorted_valid_mask", sorted_valid_mask) | ||
| safe_print("raw_max", raw_max) | ||
| safe_print("exp_logits", exp_logits) | ||
| safe_print("reversed_exp", reversed_exp) | ||
| safe_print("reversed_cumsum", reversed_cumsum) | ||
| safe_print("cumsum_from_right", cumsum_from_right) | ||
| safe_print("log_normalizers", log_normalizers) | ||
| safe_print("log_probs", log_probs) | ||
| safe_print("negative_log_likelihood", negative_log_likelihood) | ||
| safe_print("weights", weights) | ||
|
|
||
| return negative_log_likelihood, weights | ||
|
|
||
| def call( | ||
| self, | ||
| y_true: types.Tensor, | ||
| y_pred: types.Tensor, | ||
| ) -> types.Tensor: | ||
| """Compute the ListMLE loss. | ||
|
|
||
| Args: | ||
| y_true: tensor or dict. Ground truth values. If tensor, of shape | ||
| `(list_size)` for unbatched inputs or `(batch_size, list_size)` | ||
| for batched inputs. If an item has a label of -1, it is ignored | ||
| in loss computation. If it is a dictionary, it should have two | ||
| keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore | ||
| elements in loss computation. | ||
| y_pred: tensor. The predicted values, of shape `(list_size)` for | ||
| unbatched inputs or `(batch_size, list_size)` for batched | ||
| inputs. Should be of the same shape as `y_true`. | ||
|
|
||
| Returns: | ||
| The loss tensor of shape [batch_size]. | ||
| """ | ||
| mask = None | ||
| if isinstance(y_true, dict): | ||
| if "labels" not in y_true: | ||
| raise ValueError( | ||
| '`"labels"` should be present in `y_true`. Received: ' | ||
| f"`y_true` = {y_true}" | ||
| ) | ||
|
|
||
| mask = y_true.get("mask", None) | ||
| y_true = y_true["labels"] | ||
|
|
||
| y_true = ops.convert_to_tensor(y_true) | ||
| y_pred = ops.convert_to_tensor(y_pred) | ||
| if mask is not None: | ||
| mask = ops.convert_to_tensor(mask) | ||
|
|
||
| y_true, y_pred, mask, _ = standardize_call_inputs_ranks( | ||
| y_true, y_pred, mask | ||
| ) | ||
|
|
||
| losses, weights = self.compute_unreduced_loss( | ||
| labels=y_true, logits=y_pred, mask=mask | ||
| ) | ||
| losses = ops.multiply(losses, weights) | ||
| losses = ops.squeeze(losses, axis=-1) | ||
| return losses | ||
|
|
||
| # getting config | ||
| def get_config(self) -> dict[str, Any]: | ||
| config: dict[str, Any] = super().get_config() | ||
| config.update({"temperature": self.temperature}) | ||
| return config |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| import keras | ||
| from absl.testing import parameterized | ||
| from keras import ops | ||
| from keras.losses import deserialize | ||
| from keras.losses import serialize | ||
|
|
||
| from keras_rs.src import testing | ||
| from keras_rs.src.losses.list_mle_loss import ListMLELoss | ||
|
|
||
|
|
||
| class ListMLELossTest(testing.TestCase, parameterized.TestCase): | ||
| def setUp(self): | ||
| self.unbatched_scores = ops.array( | ||
| [1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32" | ||
| ) | ||
| self.unbatched_labels = ops.array( | ||
| [1.0, 0.0, 1.0, 3.0, 2.0], dtype="float32" | ||
| ) | ||
| self.batched_scores = ops.array( | ||
| [[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]], | ||
| dtype="float32", | ||
| ) | ||
| self.batched_labels = ops.array( | ||
| [[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]], | ||
| dtype="float32", | ||
| ) | ||
| self.expected_output = ops.array([6.865693, 3.088192], dtype="float32") | ||
|
|
||
| def test_unbatched_input(self): | ||
| loss = ListMLELoss(reduction="none") | ||
| output = loss( | ||
| y_true=self.unbatched_labels, y_pred=self.unbatched_scores | ||
| ) | ||
| self.assertEqual(output.shape, (1,)) | ||
| self.assertTrue(ops.convert_to_numpy(output[0]) > 0) | ||
| self.assertAllClose(output, [self.expected_output[0]], atol=1e-5) | ||
|
|
||
| def test_batched_input(self): | ||
| loss = ListMLELoss(reduction="none") | ||
| output = loss(y_true=self.batched_labels, y_pred=self.batched_scores) | ||
| self.assertEqual(output.shape, (2,)) | ||
| self.assertTrue(ops.convert_to_numpy(output[0]) > 0) | ||
| self.assertTrue(ops.convert_to_numpy(output[1]) > 0) | ||
| self.assertAllClose(output, self.expected_output, atol=1e-5) | ||
|
|
||
| def test_temperature(self): | ||
| loss_temp = ListMLELoss(temperature=0.5, reduction="none") | ||
| output_temp = loss_temp( | ||
| y_true=self.batched_labels, y_pred=self.batched_scores | ||
| ) | ||
| self.assertAllClose( | ||
| output_temp, | ||
| [10.969891, 2.1283305], | ||
| atol=1e-5, | ||
| ) | ||
|
|
||
| def test_invalid_input_rank(self): | ||
| rank_1_input = ops.ones((2, 3, 4)) | ||
|
|
||
| loss = ListMLELoss() | ||
| with self.assertRaises(ValueError): | ||
| loss(y_true=rank_1_input, y_pred=rank_1_input) | ||
|
|
||
| def test_loss_reduction(self): | ||
| loss = ListMLELoss(reduction="sum_over_batch_size") | ||
| output = loss(y_true=self.batched_labels, y_pred=self.batched_scores) | ||
| self.assertAlmostEqual( | ||
| ops.convert_to_numpy(output), 4.9769425, places=5 | ||
| ) | ||
|
|
||
| def test_scalar_sample_weight(self): | ||
| sample_weight = ops.array(5.0) | ||
| loss = ListMLELoss(reduction="none") | ||
|
|
||
| output = loss( | ||
| y_true=self.batched_labels, | ||
| y_pred=self.batched_scores, | ||
| sample_weight=sample_weight, | ||
| ) | ||
|
|
||
| self.assertAllClose( | ||
| output, self.expected_output * sample_weight, atol=1e-5 | ||
| ) | ||
|
|
||
| def test_model_fit(self): | ||
| inputs = keras.Input(shape=(20,), dtype="float32") | ||
| outputs = keras.layers.Dense(5)(inputs) | ||
| model = keras.Model(inputs=inputs, outputs=outputs) | ||
|
|
||
| model.compile(loss=ListMLELoss(), optimizer="adam") | ||
| model.fit( | ||
| x=keras.random.normal((2, 20)), | ||
| y=keras.random.randint((2, 5), minval=0, maxval=2), | ||
| ) | ||
|
|
||
| def test_serialization(self): | ||
| loss = ListMLELoss(temperature=0.8) | ||
| restored = deserialize(serialize(loss)) | ||
| self.assertDictEqual(loss.get_config(), restored.get_config()) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.