Skip to content
1 change: 1 addition & 0 deletions keras_rs/api/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
since your modifications would be overwritten.
"""

from keras_rs.src.losses.list_mle_loss import ListMLELoss as ListMLELoss
from keras_rs.src.losses.pairwise_hinge_loss import (
PairwiseHingeLoss as PairwiseHingeLoss,
)
Expand Down
247 changes: 247 additions & 0 deletions keras_rs/src/losses/list_mle_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@
from typing import Any

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks


@keras_rs_export("keras_rs.losses.ListMLELoss")
class ListMLELoss(keras.losses.Loss):
    """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.

    ListMLE loss is a listwise ranking loss that maximizes the likelihood of
    the ground truth ranking. It works by:

    1. Sorting items by their relevance scores (labels)
    2. Computing the probability of observing this ranking given the
       predicted scores
    3. Maximizing this likelihood (minimizing negative log-likelihood)

    The loss is computed as the negative log-likelihood of the ground truth
    ranking given the predicted scores:

    ```
    loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
    ```

    where s_i is the predicted score for item i in the sorted order.

    Items whose label is negative, or whose entry in the optional mask is
    false, are excluded from the loss. A list with no valid items
    contributes a loss of zero.

    Args:
        temperature: Temperature parameter for scaling logits. Higher values
            make the probability distribution more uniform. Must be
            strictly positive. Defaults to 1.0.
        debug: If `True`, print every intermediate tensor of the loss
            computation to stderr on each call (best effort; values that
            cannot be converted to NumPy are reported as an error string).
            Intended for development only. Defaults to `False`.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. Defaults to
            `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`.

    Examples:
    ```python
    # Basic usage
    loss_fn = ListMLELoss()

    # With temperature scaling
    loss_fn = ListMLELoss(temperature=0.5)

    # Example with synthetic data
    y_true = [[3, 2, 1, 0]]  # Relevance scores
    y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores
    loss = loss_fn(y_true, y_pred)
    ```
    """

    def __init__(
        self, temperature: float = 1.0, debug: bool = False, **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)

        if temperature <= 0.0:
            raise ValueError(
                f"`temperature` should be a positive float. Received: "
                f"`temperature` = {temperature}."
            )

        self.temperature = temperature
        # Added to the cumulative sums before `log` so that positions whose
        # normalizer is exactly zero do not produce `-inf`.
        self._epsilon = 1e-10
        self.debug = debug

    @staticmethod
    def _debug_print(named_values: list[tuple[str, Any]]) -> None:
        """Print labelled tensors to stderr, never raising on conversion."""
        import sys

        for label, value in named_values:
            try:
                # Prefer the tensor's own `.numpy()` when it has one;
                # otherwise go through the backend-agnostic conversion.
                if hasattr(value, "numpy"):
                    rendered = value.numpy()
                else:
                    rendered = ops.convert_to_numpy(value)
            except Exception as e:
                # Deliberately broad: debugging output must not break the
                # loss computation (e.g. symbolic tensors while tracing).
                rendered = f"<error printing: {e}>"
            print(label, rendered, file=sys.stderr)

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: types.Tensor | None = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Compute the unreduced ListMLE loss.

        Args:
            labels: Ground truth relevance scores of
                shape [batch_size, list_size].
            logits: Predicted scores of shape [batch_size, list_size].
            mask: Optional mask of shape [batch_size, list_size].

        Returns:
            Tuple of (losses, weights) where losses has shape [batch_size, 1]
            and weights has the same shape.
        """
        # An item is valid when its label is non-negative and, if a mask
        # was supplied, the mask is true at that position.
        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))

        if mask is not None:
            valid_mask = ops.logical_and(
                valid_mask, ops.cast(mask, dtype="bool")
            )

        num_valid_items = ops.sum(
            ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True
        )

        batch_has_valid_items = ops.greater(num_valid_items, 0.0)

        # Give invalid items a very negative label/logit so that they do not
        # outrank valid items when sorting by label.
        labels_for_sorting = ops.where(
            valid_mask, labels, ops.full_like(labels, -1e9)
        )
        logits_masked = ops.where(
            valid_mask, logits, ops.full_like(logits, -1e9)
        )
        sorted_logits, sorted_valid_mask = sort_by_scores(
            tensors_to_sort=[logits_masked, valid_mask],
            scores=labels_for_sorting,
            mask=None,
            shuffle_ties=False,
            seed=None,
        )
        sorted_logits = ops.divide(
            sorted_logits, ops.cast(self.temperature, dtype=sorted_logits.dtype)
        )

        # Subtract the per-list maximum over valid items before `exp` for
        # numerical stability; lists with no valid item subtract zero.
        valid_logits_for_max = ops.where(
            sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
        )
        raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
        raw_max = ops.where(
            batch_has_valid_items, raw_max, ops.zeros_like(raw_max)
        )
        sorted_logits = ops.subtract(sorted_logits, raw_max)

        # Set invalid positions to very negative BEFORE exp
        sorted_logits = ops.where(
            sorted_valid_mask, sorted_logits, ops.full_like(sorted_logits, -1e9)
        )
        exp_logits = ops.exp(sorted_logits)

        # Suffix sums: flip, cumulative-sum, flip back, so position i holds
        # sum(exp(s_j) for j >= i).
        reversed_exp = ops.flip(exp_logits, axis=1)
        reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
        cumsum_from_right = ops.flip(reversed_cumsum, axis=1)

        log_normalizers = ops.log(cumsum_from_right + self._epsilon)
        log_probs = ops.subtract(sorted_logits, log_normalizers)

        # Invalid positions contribute nothing to the likelihood.
        log_probs = ops.where(
            sorted_valid_mask, log_probs, ops.zeros_like(log_probs)
        )

        negative_log_likelihood = ops.negative(
            ops.sum(log_probs, axis=1, keepdims=True)
        )

        # Lists without any valid item get a loss of exactly zero.
        negative_log_likelihood = ops.where(
            batch_has_valid_items,
            negative_log_likelihood,
            ops.zeros_like(negative_log_likelihood),
        )

        weights = ops.ones_like(negative_log_likelihood)

        if self.debug:
            self._debug_print(
                [
                    ("valid_mask", valid_mask),
                    ("num_valid_items", num_valid_items),
                    ("batch_has_valid_items", batch_has_valid_items),
                    ("labels_for_sorting", labels_for_sorting),
                    ("logits_masked", logits_masked),
                    ("sorted_logits", sorted_logits),
                    ("sorted_valid_mask", sorted_valid_mask),
                    ("raw_max", raw_max),
                    ("exp_logits", exp_logits),
                    ("reversed_exp", reversed_exp),
                    ("reversed_cumsum", reversed_cumsum),
                    ("cumsum_from_right", cumsum_from_right),
                    ("log_normalizers", log_normalizers),
                    ("log_probs", log_probs),
                    ("negative_log_likelihood", negative_log_likelihood),
                    ("weights", weights),
                ]
            )

        return negative_log_likelihood, weights

    def call(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
    ) -> types.Tensor:
        """Compute the ListMLE loss.

        Args:
            y_true: tensor or dict. Ground truth values. If tensor, of shape
                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
                for batched inputs. Items with a negative label (e.g. -1)
                are ignored in loss computation. If it is a dictionary, it
                should have two keys: `"labels"` and `"mask"`. `"mask"` can
                be used to ignore elements in loss computation.
            y_pred: tensor. The predicted values, of shape `(list_size)` for
                unbatched inputs or `(batch_size, list_size)` for batched
                inputs. Should be of the same shape as `y_true`.

        Returns:
            The loss tensor of shape [batch_size].

        Raises:
            ValueError: If `y_true` is a dict without a `"labels"` key.
        """
        mask = None
        if isinstance(y_true, dict):
            if "labels" not in y_true:
                raise ValueError(
                    '`"labels"` should be present in `y_true`. Received: '
                    f"`y_true` = {y_true}"
                )

            mask = y_true.get("mask", None)
            y_true = y_true["labels"]

        y_true = ops.convert_to_tensor(y_true)
        y_pred = ops.convert_to_tensor(y_pred)
        if mask is not None:
            mask = ops.convert_to_tensor(mask)

        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
            y_true, y_pred, mask
        )

        losses, weights = self.compute_unreduced_loss(
            labels=y_true, logits=y_pred, mask=mask
        )
        losses = ops.multiply(losses, weights)
        # Drop the trailing singleton axis kept by the per-list reduction.
        losses = ops.squeeze(losses, axis=-1)
        return losses

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to recreate this loss."""
        config: dict[str, Any] = super().get_config()
        # `debug` is included so a serialize/deserialize round trip
        # reproduces the instance exactly.
        config.update({"temperature": self.temperature, "debug": self.debug})
        return config
99 changes: 99 additions & 0 deletions keras_rs/src/losses/list_mle_loss_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import keras
from absl.testing import parameterized
from keras import ops
from keras.losses import deserialize
from keras.losses import serialize

from keras_rs.src import testing
from keras_rs.src.losses.list_mle_loss import ListMLELoss


class ListMLELossTest(testing.TestCase, parameterized.TestCase):
    """Unit tests for `ListMLELoss`."""

    def setUp(self):
        # Let the base test cases run their own fixture initialisation
        # before the shared tensors are built.
        super().setUp()
        self.unbatched_scores = ops.array(
            [1.0, 3.0, 2.0, 4.0, 0.8], dtype="float32"
        )
        self.unbatched_labels = ops.array(
            [1.0, 0.0, 1.0, 3.0, 2.0], dtype="float32"
        )
        self.batched_scores = ops.array(
            [[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]],
            dtype="float32",
        )
        self.batched_labels = ops.array(
            [[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]],
            dtype="float32",
        )
        # Per-list losses for the batched fixtures at temperature 1.0; the
        # first entry is also the expected loss for the unbatched fixtures.
        self.expected_output = ops.array([6.865693, 3.088192], dtype="float32")

    def test_unbatched_input(self):
        loss = ListMLELoss(reduction="none")
        output = loss(
            y_true=self.unbatched_labels, y_pred=self.unbatched_scores
        )
        self.assertEqual(output.shape, (1,))
        self.assertGreater(ops.convert_to_numpy(output[0]), 0)
        self.assertAllClose(output, [self.expected_output[0]], atol=1e-5)

    def test_batched_input(self):
        loss = ListMLELoss(reduction="none")
        output = loss(y_true=self.batched_labels, y_pred=self.batched_scores)
        self.assertEqual(output.shape, (2,))
        self.assertGreater(ops.convert_to_numpy(output[0]), 0)
        self.assertGreater(ops.convert_to_numpy(output[1]), 0)
        self.assertAllClose(output, self.expected_output, atol=1e-5)

    def test_temperature(self):
        loss_temp = ListMLELoss(temperature=0.5, reduction="none")
        output_temp = loss_temp(
            y_true=self.batched_labels, y_pred=self.batched_scores
        )
        self.assertAllClose(
            output_temp,
            [10.969891, 2.1283305],
            atol=1e-5,
        )

    def test_invalid_temperature(self):
        # The constructor rejects any non-positive temperature.
        with self.assertRaises(ValueError):
            ListMLELoss(temperature=0.0)
        with self.assertRaises(ValueError):
            ListMLELoss(temperature=-1.0)

    def test_invalid_input_rank(self):
        rank_3_input = ops.ones((2, 3, 4))

        loss = ListMLELoss()
        with self.assertRaises(ValueError):
            loss(y_true=rank_3_input, y_pred=rank_3_input)

    def test_all_items_invalid(self):
        # Negative labels mark items as invalid; a list with no valid item
        # must yield a loss of exactly zero.
        loss = ListMLELoss(reduction="none")
        output = loss(
            y_true=ops.array([[-1.0, -1.0, -1.0]], dtype="float32"),
            y_pred=ops.array([[0.3, 0.2, 0.1]], dtype="float32"),
        )
        self.assertAllClose(output, [0.0], atol=1e-5)

    def test_loss_reduction(self):
        loss = ListMLELoss(reduction="sum_over_batch_size")
        output = loss(y_true=self.batched_labels, y_pred=self.batched_scores)
        self.assertAlmostEqual(
            ops.convert_to_numpy(output), 4.9769425, places=5
        )

    def test_scalar_sample_weight(self):
        sample_weight = ops.array(5.0)
        loss = ListMLELoss(reduction="none")

        output = loss(
            y_true=self.batched_labels,
            y_pred=self.batched_scores,
            sample_weight=sample_weight,
        )

        self.assertAllClose(
            output, self.expected_output * sample_weight, atol=1e-5
        )

    def test_model_fit(self):
        # Smoke test: the loss must be usable in `compile`/`fit`.
        inputs = keras.Input(shape=(20,), dtype="float32")
        outputs = keras.layers.Dense(5)(inputs)
        model = keras.Model(inputs=inputs, outputs=outputs)

        model.compile(loss=ListMLELoss(), optimizer="adam")
        model.fit(
            x=keras.random.normal((2, 20)),
            y=keras.random.randint((2, 5), minval=0, maxval=2),
        )

    def test_serialization(self):
        loss = ListMLELoss(temperature=0.8)
        restored = deserialize(serialize(loss))
        self.assertDictEqual(loss.get_config(), restored.get_config())
Loading