diff --git a/dialdoc/models/rag/retrieval_rag_dialdoc.py b/dialdoc/models/rag/retrieval_rag_dialdoc.py index 4424f20..f1942e5 100644 --- a/dialdoc/models/rag/retrieval_rag_dialdoc.py +++ b/dialdoc/models/rag/retrieval_rag_dialdoc.py @@ -453,11 +453,11 @@ def nonlinear(a: List[int]): ids_batched = [] vectors_batched = [] scores_batched = [] - for comb_h_s, curr_h_s, hist_h_s, dom_batch in zip( + for comb_h_s, curr_h_s, hist_h_s in zip( combined_hidden_states_batched, current_hidden_states_batched, history_hidden_states_batched, - domain_batched, + # domain_batched, ): start_time = time.time() if self.config.scoring_func in ["linear", "linear2", "linear3", "nonlinear"]: @@ -619,4 +619,4 @@ def __call__( "doc_scores": doc_scores, }, tensor_type=return_tensors, - ) \ No newline at end of file + )