From 6d2dfb58b4cec4276089aeea2247b958767ba718 Mon Sep 17 00:00:00 2001 From: NourOM02 Date: Fri, 7 Feb 2025 11:20:06 +0100 Subject: [PATCH 1/2] Deal with aliased columns in the group by section --- eval/eval.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/eval/eval.py b/eval/eval.py index 43a3521..c2b460e 100644 --- a/eval/eval.py +++ b/eval/eval.py @@ -168,12 +168,20 @@ def get_all_minimal_queries(query: str) -> "list[str]": left = query[:start] column_str = ", ".join(column_tuple) right = query[end + 1 :] + g_column_str = column_str + len_tuple = len(column_tuple) + # check if the column str contains columns defined with alias AS ... + if " as " in column_str.lower() and "group by {}" in right.lower(): + g_column_str = "" + for i,column in enumerate(column_tuple): + as_index = column.lower().find(" as ")+4 + g_column_str += column[as_index:] if as_index - 3 else column + g_column_str += ", " if i != len_tuple-1 else "" # change group by size dynamically if necessary - right = right.replace("GROUP BY {}", f"GROUP BY {column_str}") + right = right.replace("GROUP BY {}", f"GROUP BY {g_column_str}") result_queries.append(left + column_str + right) return result_queries - def query_postgres_db( query: str, db_name: str, From 40c307cb32df628a4931e1f5e7fd06f94dddfa89 Mon Sep 17 00:00:00 2001 From: NourOM02 Date: Fri, 7 Feb 2025 17:40:15 +0100 Subject: [PATCH 2/2] Fix Black formatting in PR --- eval/eval.py | 7 ++++--- runners/anthropic_runner.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/eval/eval.py b/eval/eval.py index c2b460e..785fe6f 100644 --- a/eval/eval.py +++ b/eval/eval.py @@ -173,15 +173,16 @@ def get_all_minimal_queries(query: str) -> "list[str]": # check if the column str contains columns defined with alias AS ... if " as " in column_str.lower() and "group by {}" in right.lower(): g_column_str = "" - for i,column in enumerate(column_tuple): - as_index = column.lower().find(" as ")+4 + for i, column in enumerate(column_tuple): + as_index = column.lower().find(" as ") + 4 g_column_str += column[as_index:] if as_index - 3 else column - g_column_str += ", " if i != len_tuple-1 else "" + g_column_str += ", " if i != len_tuple - 1 else "" # change group by size dynamically if necessary right = right.replace("GROUP BY {}", f"GROUP BY {g_column_str}") result_queries.append(left + column_str + right) return result_queries + def query_postgres_db( query: str, db_name: str, diff --git a/runners/anthropic_runner.py b/runners/anthropic_runner.py index aa9f571..0779bc8 100644 --- a/runners/anthropic_runner.py +++ b/runners/anthropic_runner.py @@ -15,6 +15,7 @@ from utils.llm import chat_anthropic import json + def generate_prompt( prompt_file, question,