diff --git a/lib/synth-kg/training_steps/score/run.py b/lib/synth-kg/training_steps/score/run.py index ea3fa48..49fdf82 100644 --- a/lib/synth-kg/training_steps/score/run.py +++ b/lib/synth-kg/training_steps/score/run.py @@ -1,5 +1,6 @@ import argparse +import numpy as np import pandas as pd from sentence_transformers import SentenceTransformer from sentence_transformers.util import cos_sim @@ -73,10 +74,17 @@ def main(): final_dataset["chosen"] = public_dataset.apply( lambda row: row[f"response_{best_response_idx[row.name]}"], axis=1 ) + final_dataset["chosen_score"] = public_dataset.apply( + lambda row: row[f"similarity_score_{best_response_idx[row.name]}"], axis=1 + ) final_dataset["rejected"] = public_dataset.apply( lambda row: row[f"response_{worst_response_idx[row.name]}"], axis=1 ) + percentile = np.percentile(final_dataset["chosen_score"], 65) + final_dataset = final_dataset.loc[final_dataset["chosen_score"] > percentile] + + final_dataset = final_dataset[final_dataset["chosen"].apply(lambda x: len(x.split()) >= 20)] final_dataset.to_parquet( f"{args.output_path}/model={args.evaluator_path.replace('/','-')}_dpo.parquet" )