diff --git a/runners/anthropic_runner.py b/runners/anthropic_runner.py index 2081afb..76c108c 100644 --- a/runners/anthropic_runner.py +++ b/runners/anthropic_runner.py @@ -189,7 +189,7 @@ def run_anthropic_eval(args): db_name = row["db_name"] db_type = row["db_type"] try: - is_correct = compare_query_results( + exact_match, correct = compare_query_results( query_gold=expected_query, query_gen=query_gen, db_name=db_name, @@ -199,7 +199,7 @@ def run_anthropic_eval(args): query_category=row["query_category"], decimal_points=args.decimal_points, ) - if is_correct: + if correct: total_correct += 1 row["is_correct"] = 1 row["error_msg"] = "" diff --git a/runners/openai_runner.py b/runners/openai_runner.py index 5d207ef..2c35194 100644 --- a/runners/openai_runner.py +++ b/runners/openai_runner.py @@ -23,10 +23,7 @@ def generate_prompt( db_type, instructions="", k_shot_prompt="", - glossary="", table_metadata_string="", - prev_invalid_sql="", - prev_error_msg="", public_data=True, shuffle=True, ): @@ -88,6 +85,7 @@ def generate_prompt( table_metadata_string=pruned_metadata_str, k_shot_prompt=k_shot_prompt, ) + return prompt @@ -100,10 +98,7 @@ def process_row(row, model_name, args): db_type=args.db_type, instructions=row["instructions"], k_shot_prompt=row["k_shot_prompt"], - glossary=row["glossary"], table_metadata_string=row["table_metadata_string"], - prev_invalid_sql=row["prev_invalid_sql"], - prev_error_msg=row["prev_error_msg"], public_data=not args.use_private_data, shuffle=args.shuffle_metadata, ) @@ -195,7 +190,7 @@ def run_openai_eval(args): db_name = row["db_name"] db_type = row["db_type"] try: - is_correct = compare_query_results( + exact_match, correct = compare_query_results( query_gold=expected_query, query_gen=query_gen, db_name=db_name, @@ -204,7 +199,7 @@ def run_openai_eval(args): query_category=row["query_category"], db_creds=db_creds_all[db_type], ) - if is_correct: + if correct: total_correct += 1 row["is_correct"] = 1 row["error_msg"] = "" @@ -212,6 +207,7 @@ def run_openai_eval(args): row["is_correct"] = 0 row["error_msg"] = "INCORRECT RESULTS" except Exception as e: + row["is_correct"] = 0 row["error_db_exec"] = 1 row["error_msg"] = f"EXECUTION ERROR: {str(e)}" output_rows.append(row)