diff --git a/print_results.py b/print_results.py index 4ec671b..d992afe 100644 --- a/print_results.py +++ b/print_results.py @@ -47,6 +47,8 @@ def print_wraped(s: str): print_wraped(f"{answer} | {output}") print() + return clean_solution_output == clean_model_output + def print_results_for_problem(problem_path: str, verbose: bool = False): prompt_in_dir = f"{problem_path}/{Problem.dirs['prompt_in']}" @@ -69,13 +71,14 @@ def print_results_for_problem(problem_path: str, verbose: bool = False): ins = sorted(match_tests_to_prompts(ins, num_tests)) solution_outs = sorted(match_tests_to_prompts(solution_outs, num_tests)) + num_correct = 0 for ( prompt_in_filename, in_filename, model_out_filename, solution_out_filename, ) in zip(prompt_ins, ins, model_outs, solution_outs): - print_evaluation_result_for_testcase( + num_correct += print_evaluation_result_for_testcase( problem_path, open(os.path.join(prompt_in_dir, prompt_in_filename), "r").read(), open(os.path.join(in_dir, in_filename), "r").read(), @@ -83,6 +86,7 @@ def print_results_for_problem(problem_path: str, verbose: bool = False): open(os.path.join(model_out_dir, model_out_filename), "r").read(), verbose, ) + return num_correct / len(model_outs) def main(): @@ -106,10 +110,18 @@ def main(): args = parser.parse_args() if args.folder: - for problem in Problem.read_problems_from_dir(args.path): - print_results_for_problem(problem.id, args.verbose) + problems = Problem.read_problems_from_dir(args.path) + results = [] + for problem in problems: + accuracy = print_results_for_problem(problem.id, args.verbose) + results.append(accuracy) + + for problem, accuracy in sorted(zip(problems, results), key=lambda x: x[0].id): + print(f"Problem: {problem.id} Accuracy: {accuracy:.3f}") + else: - print_results_for_problem(args.path, args.verbose) + accuracy = print_results_for_problem(args.path, args.verbose) + print(f"Accuracy: {accuracy:.3f}") if __name__ == "__main__":