@@ -174,11 +174,15 @@ def forward(
174
174
[first_DMSA_attn_weights , second_DMSA_attn_weights , combining_weights ],
175
175
) = self ._process (inputs , diagonal_attention_mask )
176
176
177
+ results = {
178
+ "first_DMSA_attn_weights" : first_DMSA_attn_weights ,
179
+ "second_DMSA_attn_weights" : second_DMSA_attn_weights ,
180
+ "combining_weights" : combining_weights ,
181
+ "imputed_data" : imputed_data ,
182
+ }
177
183
if not training :
178
184
# if not in training mode, skip the loss computation and return only the imputed data and the attention/combining weights
179
- return {
180
- "imputed_data" : imputed_data ,
181
- }
185
+ return results
182
186
183
187
ORT_loss = 0
184
188
ORT_loss += self .customized_loss_func (X_tilde_1 , X , masks )
@@ -193,13 +197,10 @@ def forward(
193
197
# `loss` is always the item for backward propagating to update the model
194
198
loss = self .ORT_weight * ORT_loss + self .MIT_weight * MIT_loss
195
199
196
- results = {
197
- "first_DMSA_attn_weights" : first_DMSA_attn_weights ,
198
- "second_DMSA_attn_weights" : second_DMSA_attn_weights ,
199
- "combining_weights" : combining_weights ,
200
- "imputed_data" : imputed_data ,
201
- "ORT_loss" : ORT_loss ,
202
- "MIT_loss" : MIT_loss ,
203
- "loss" : loss , # will be used for backward propagating to update the model
204
- }
200
+ results ["ORT_loss" ] = ORT_loss
201
+ results ["MIT_loss" ] = MIT_loss
202
+
203
+ # will be used for backward propagating to update the model
204
+ results ["loss" ] = loss
205
+
205
206
return results
0 commit comments