from typing import Any

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks


@keras_rs_export("keras_rs.losses.ListMLELoss")
class ListMLELoss(keras.losses.Loss):
    """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.

    ListMLE loss is a listwise ranking loss that maximizes the likelihood of
    the ground truth ranking. It works by:
    1. Sorting items by their relevance scores (labels).
    2. Computing the probability of observing this ranking given the
       predicted scores.
    3. Maximizing this likelihood, i.e., minimizing the negative
       log-likelihood.

    The loss is computed as the negative log-likelihood of the ground truth
    ranking given the predicted scores:

    ```
    loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
    ```

    where `s_i` is the predicted score of the item at position `i` in the
    sorted order.

    Args:
        temperature: Temperature parameter for scaling logits. Higher values
            make the probability distribution more uniform. Defaults to 1.0.
        reduction: Type of reduction to apply to the loss. In almost all
            cases, this should be `"sum_over_batch_size"`. Supported options
            are `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. Defaults to
            `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`.

    Examples:
        ```python
        # Basic usage.
        loss_fn = ListMLELoss()

        # With temperature scaling.
        loss_fn = ListMLELoss(temperature=0.5)

        # Example with synthetic data.
        y_true = [[3, 2, 1, 0]]  # Relevance scores
        y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores
        loss = loss_fn(y_true, y_pred)
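
        # Ignoring padded items with an explicit mask (dict keys as
        # documented in `call` below; values here are illustrative).
        y_true = {
            "labels": [[3, 2, 1, 0]],
            "mask": [[True, True, True, False]],
        }
        loss = loss_fn(y_true, y_pred)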
        ```
    """

    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        if temperature <= 0.0:
            raise ValueError(
                "`temperature` should be a positive float. Received: "
                f"`temperature` = {temperature}"
            )

        self.temperature = temperature
        # Small constant to avoid `log(0)` in the normalizer computation.
        self._epsilon = 1e-10

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: types.Tensor | None = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Computes the unreduced ListMLE loss.

        Args:
            labels: Ground truth relevance scores of shape
                `[batch_size, list_size]`.
            logits: Predicted scores of shape `[batch_size, list_size]`.
            mask: Optional mask of shape `[batch_size, list_size]`.

        Returns:
            Tuple of `(losses, weights)`, where `losses` has shape
            `[batch_size, 1]` and `weights` has the same shape.
        """
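        # Items with a negative label (e.g. the `-1` padding convention
        # documented in `call`) are treated as invalid and excluded from the
        # loss computation.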
        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))

        if mask is not None:
            valid_mask = ops.logical_and(
                valid_mask, ops.cast(mask, dtype="bool")
            )

        num_valid_items = ops.sum(
            ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True
        )
        batch_has_valid_items = ops.greater(num_valid_items, 0.0)
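
        # Sort the logits by the ground truth labels, in descending order.
        # Invalid items get a large negative fill value so that they sort to
        # the end of the list.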
        labels_for_sorting = ops.where(
            valid_mask, labels, ops.full_like(labels, -1e9)
        )
        logits_masked = ops.where(
            valid_mask, logits, ops.full_like(logits, -1e9)
        )

        sorted_logits, sorted_valid_mask = sort_by_scores(
            tensors_to_sort=[logits_masked, valid_mask],
            scores=labels_for_sorting,
            mask=None,
            shuffle_ties=False,
            seed=None,
        )
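
        # Temperature scaling: higher temperatures flatten the resulting
        # probability distribution (see the `temperature` argument).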
        sorted_logits = ops.divide(
            sorted_logits,
            ops.cast(self.temperature, dtype=sorted_logits.dtype),
        )
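
        # Subtract the per-list maximum before exponentiating, for numerical
        # stability (the standard log-sum-exp trick).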
        valid_logits_for_max = ops.where(
            sorted_valid_mask,
            sorted_logits,
            ops.full_like(sorted_logits, -1e9),
        )
        raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
        raw_max = ops.where(
            batch_has_valid_items, raw_max, ops.zeros_like(raw_max)
        )
        sorted_logits = sorted_logits - raw_max
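
        # Compute the suffix sums `sum(exp(s_j) for j >= i)` via a reversed
        # cumulative sum: flip the list, cumsum, then flip back.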
        exp_logits = ops.exp(sorted_logits)
        exp_logits = ops.where(
            sorted_valid_mask, exp_logits, ops.zeros_like(exp_logits)
        )

        reversed_exp = ops.flip(exp_logits, axis=1)
        reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
        cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
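
        # Per-item log-probability `s_i - log(sum(exp(s_j) for j >= i))`,
        # matching the formula in the class docstring. The epsilon keeps the
        # `log` finite at positions where the suffix sum is zero.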
        log_normalizers = ops.log(cumsum_from_right + self._epsilon)
        log_probs = sorted_logits - log_normalizers

        log_probs = ops.where(
            sorted_valid_mask, log_probs, ops.zeros_like(log_probs)
        )

        negative_log_likelihood = -ops.sum(log_probs, axis=1, keepdims=True)
        negative_log_likelihood = ops.where(
            batch_has_valid_items,
            negative_log_likelihood,
            ops.zeros_like(negative_log_likelihood),
        )

        weights = ops.ones_like(negative_log_likelihood)

        return negative_log_likelihood, weights

    def call(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
    ) -> types.Tensor:
        """Computes the ListMLE loss.

        Args:
            y_true: tensor or dict. Ground truth values. If tensor, of shape
                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
                for batched inputs. If an item has a label of -1, it is ignored
                in loss computation. If it is a dictionary, it should have two
                keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
                elements in loss computation.
            y_pred: tensor. The predicted values, of shape `(list_size)` for
                unbatched inputs or `(batch_size, list_size)` for batched
                inputs. Should be of the same shape as `y_true`.

        Returns:
            The loss tensor of shape `[batch_size]`.
        """
        mask = None
        if isinstance(y_true, dict):
            if "labels" not in y_true:
                raise ValueError(
                    '`"labels"` should be present in `y_true`. Received: '
                    f"`y_true` = {y_true}"
                )

            mask = y_true.get("mask", None)
            y_true = y_true["labels"]

        y_true = ops.convert_to_tensor(y_true)
        y_pred = ops.convert_to_tensor(y_pred)
        if mask is not None:
            mask = ops.convert_to_tensor(mask)

        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
            y_true, y_pred, mask
        )

        losses, weights = self.compute_unreduced_loss(
            labels=y_true, logits=y_pred, mask=mask
        )
        losses = ops.multiply(losses, weights)
        # Squeeze the trailing axis: `[batch_size, 1]` -> `[batch_size]`.
        losses = ops.squeeze(losses, axis=-1)
        return losses

    def get_config(self) -> dict[str, Any]:
        config: dict[str, Any] = super().get_config()
        config.update({"temperature": self.temperature})
        return config