From 442cad5971b32edbeecd0d154aa125433d95b7e7 Mon Sep 17 00:00:00 2001 From: Simon Meoni Date: Thu, 27 Feb 2025 16:34:38 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9A=97=EF=B8=8F=20Synthetic=20KG:=20add=20fi?= =?UTF-8?q?ltering=20in=20scoring=20step?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lib/synth-kg/training_steps/score/run.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/synth-kg/training_steps/score/run.py b/lib/synth-kg/training_steps/score/run.py index ea3fa48..49fdf82 100644 --- a/lib/synth-kg/training_steps/score/run.py +++ b/lib/synth-kg/training_steps/score/run.py @@ -1,5 +1,6 @@ import argparse +import numpy as np import pandas as pd from sentence_transformers import SentenceTransformer from sentence_transformers.util import cos_sim @@ -73,10 +74,17 @@ def main(): final_dataset["chosen"] = public_dataset.apply( lambda row: row[f"response_{best_response_idx[row.name]}"], axis=1 ) + final_dataset["chosen_score"] = public_dataset.apply( + lambda row: row[f"similarity_score_{best_response_idx[row.name]}"], axis=1 + ) final_dataset["rejected"] = public_dataset.apply( lambda row: row[f"response_{worst_response_idx[row.name]}"], axis=1 ) + percentile = np.percentile(final_dataset["chosen_score"], 65) + final_dataset = final_dataset.loc[final_dataset["chosen_score"] > percentile] + + final_dataset = final_dataset[final_dataset["chosen"].apply(lambda x: len(x.split()) >= 20)] final_dataset.to_parquet( f"{args.output_path}/model={args.evaluator_path.replace('/','-')}_dpo.parquet" )