From a4aa5de5e5fd21618d51291881ee2fbba245b53a Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Mon, 21 Apr 2025 17:17:12 -0700 Subject: [PATCH 1/2] add more tests Signed-off-by: SumanthRH --- tests/evals/tasks/test_aime.py | 55 +++++++++++++++++++++++++ tests/evals/tasks/test_amc.py | 47 ++++++++++++++++++++++ tests/evals/tasks/test_math.py | 21 +++++++++- tests/evals/tasks/test_mmlu.py | 52 ++++++++++++++++++++++++ tests/evals/tasks/test_mmlu_pro.py | 64 ++++++++++++++++++++++++++++++ 5 files changed, 238 insertions(+), 1 deletion(-) create mode 100644 tests/evals/tasks/test_aime.py create mode 100644 tests/evals/tasks/test_amc.py create mode 100644 tests/evals/tasks/test_mmlu.py create mode 100644 tests/evals/tasks/test_mmlu_pro.py diff --git a/tests/evals/tasks/test_aime.py b/tests/evals/tasks/test_aime.py new file mode 100644 index 00000000..5e855003 --- /dev/null +++ b/tests/evals/tasks/test_aime.py @@ -0,0 +1,55 @@ +import pytest + +from skythought.evals.tasks.aime.aime_handler import AIMETaskHandler + + +class MockTaskConfig: + templating_parameters = { + "template": "Problem: {prompt}\n\nProvide a numerical answer." + } + answer_key = "answer" + question_key = "question" + + +@pytest.mark.parametrize( + "problem, response, expected", + [ + ( + { + "question": "Find the sum of the first 10 positive integers.", + "answer": "55", + }, + "The sum is 55", + True, + ), + ( + { + "question": "What is the value of (3^4 - 2^5)?", + "answer": "49", + }, + "48", + False, + ), + ], +) +def test_check_correctness(problem, response, expected): + handler = AIMETaskHandler(task_config=MockTaskConfig) + assert handler.check_correctness(problem, generation=response) == expected + + +@pytest.mark.parametrize( + "problem, expected", + [ + ( + { + "question": "Find the sum of the first 10 positive integers.", + "answer": "4", + }, + "Problem: Find the sum of the first 10 positive integers.\n\nProvide a numerical answer.", + ), + ], +) +def test_generate_prompt(problem, expected): + print(problem) + handler = AIMETaskHandler(task_config=MockTaskConfig) + assert handler.generate_prompt(problem) == expected diff --git a/tests/evals/tasks/test_amc.py b/tests/evals/tasks/test_amc.py new file mode 100644 index 00000000..b7dc7cca --- /dev/null +++ b/tests/evals/tasks/test_amc.py @@ -0,0 +1,47 @@ +import pytest + +from skythought.evals.tasks.amc23.amc23_handler import AMC23TaskHandler + + +class MockTaskConfig: + templating_parameters = { + "template": "Return the answer to the following: {question}" + } + answer_key = "answer" + question_key = "question" + choices_key = "choices" + + +@pytest.mark.parametrize( + "problem, response, expected", + [ + ( + {"question": "2+2", "answer": "4"}, + "5", + False, + ), + ( + {"question": "3* 25 percent", "answer": " 75%"}, + "My reply is $0.75.", # ignores dollar signs and normalizes percentages + True, + ), + ], +) +def test_check_correctness(problem, response, expected): + handler = AMC23TaskHandler(task_config=MockTaskConfig) + print(handler.check_correctness(problem, generation=response)) + assert handler.check_correctness(problem, generation=response) == expected + + +@pytest.mark.parametrize( + "problem, expected", + [ + ( + {"question": "What is the result of 2+2?", "answer": "4"}, + "Return the answer to the following: What is the result of 2+2?", + ), + ], +) +def test_generate_prompt(problem, expected): + handler = AMC23TaskHandler(task_config=MockTaskConfig) + assert handler.generate_prompt(problem) == expected diff --git a/tests/evals/tasks/test_math.py b/tests/evals/tasks/test_math.py index dc230fc8..bd5104b8 100644 --- a/tests/evals/tasks/test_math.py +++ b/tests/evals/tasks/test_math.py @@ -4,7 +4,9 @@ class MockTaskConfig: - templating_parameters = {"template": "{question}"} + templating_parameters = { + "template": "Return the answer to the following: {question}" + } answer_key = "answer" question_key = "question" @@ -42,3 +44,20 @@ def test_check_correctness( ): handler = MathTaskHandler(task_config=MockTaskConfig) assert handler.check_correctness(problem, generation=response) == expected + + +@pytest.mark.parametrize( + "problem, expected", + [ + ( + {"question": "What is the result of 2+2?", "answer": "4"}, + "Return the answer to the following: What is the result of 2+2?", + ), + ], +) +def test_generate_prompt( + problem, + expected, +): + handler = MathTaskHandler(task_config=MockTaskConfig) + assert handler.generate_prompt(problem) == expected diff --git a/tests/evals/tasks/test_mmlu.py b/tests/evals/tasks/test_mmlu.py new file mode 100644 index 00000000..babb6a3f --- /dev/null +++ b/tests/evals/tasks/test_mmlu.py @@ -0,0 +1,52 @@ +import pytest + +from skythought.evals.tasks.mmlu.mmlu_handler import MMLUTaskHandler + + +class MockTaskConfig: + templating_parameters = {"template": "{question}\n\nChoices:\n{choices}"} + answer_key = "answer" + question_key = "question" + choices_key = "choices" + + +@pytest.mark.parametrize( + "problem, response, expected", + [ + ( + { + "question": "What is the capital of France?", + "choices": "A) London\nB) Paris\nC) Berlin\nD) Madrid", + "answer": "B", + }, + "The answer is B) Paris", + True, + ), + ( + { + "question": "Which element has the atomic number 1?", + "choices": "A) Helium\nB) Oxygen\nC) Hydrogen\nD) Carbon", + "answer": "C", + }, + "C", + False, + ), + ], +) +def test_check_correctness(problem, response, expected): + handler = MMLUTaskHandler(task_config=MockTaskConfig) + assert handler.check_correctness(problem, generation=response) == expected + + +@pytest.mark.parametrize( + "problem, expected", + [ + ( + {"question": "What is the capital of France?", "answer": "B"}, + "What is the capital of France?\n\nChoices:\nA) London\nB) Paris\nC) Berlin\nD) Madrid", + ), + ], +) +def test_generate_prompt(problem, expected): + handler = MMLUTaskHandler(task_config=MockTaskConfig) + assert handler.generate_prompt(problem) == expected diff --git a/tests/evals/tasks/test_mmlu_pro.py b/tests/evals/tasks/test_mmlu_pro.py new file mode 100644 index 00000000..10f9d61b --- /dev/null +++ b/tests/evals/tasks/test_mmlu_pro.py @@ -0,0 +1,64 @@ +import pytest + +from skythought.evals.tasks.mmlu.mmlu_handler import MMLUProTaskHandler + + +class MockTaskConfig: + templating_parameters = {"template": "Question: {question}\n\nChoices:\n{choices}"} + answer_key = "answer" + question_key = "question" + choices_key = "choices" + context_key = "context" + + +@pytest.mark.parametrize( + "problem, response, expected", + [ + ( + { + "question": "What is the main function of the left ventricle?", + "choices": "A) Pumps blood to the lungs\nB) Pumps blood to the body\nC) Collects blood from the body\nD) Stores blood", + "answer": "B", + "answer_index": 1, + }, + "B) Pumps blood to the body", + True, + ), + ( + { + "question": "What does GDP stand for?", + "choices": "A) Gross Domestic Product\nB) General Development Plan\nC) Global Distribution Process\nD) Geographic Data Point", + "answer": "A", + "answer_index": 0, + }, + "I think it's B", + False, + ), + ], +) +def test_check_correctness(problem, response, expected): + handler = MMLUProTaskHandler(task_config=MockTaskConfig) + assert handler.check_correctness(problem, generation=response) == expected + + +@pytest.mark.parametrize( + "problem, expected", + [ + ( + { + "question": "What is the main function of the left ventricle?", + "choices": "A) Pumps blood to the lungs\nB) Pumps blood to the body\nC) Collects blood from the body\nD) Stores blood", + "answer": "B", + "answer_index": 1, + }, + "Question: What is the main function of the left ventricle?\n\nChoices:" + "\nA) Pumps blood to the lungs\nB) Pumps blood to the body\nC) Collects blood from the body\nD) Stores blood", + ), + ], +) +def test_generate_prompt( + problem, + expected, +): + handler = MMLUProTaskHandler(task_config=MockTaskConfig) + assert handler.generate_prompt(problem) == expected From 0a42a960f8175679bcf64bf3f7069e9ad007598f Mon Sep 17 00:00:00 2001 From: SumanthRH Date: Mon, 21 Apr 2025 17:32:32 -0700 Subject: [PATCH 2/2] fix Signed-off-by: SumanthRH --- tests/evals/tasks/test_mmlu.py | 16 ++++++++++------ tests/evals/tasks/test_mmlu_pro.py | 13 +++++++++---- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/evals/tasks/test_mmlu.py b/tests/evals/tasks/test_mmlu.py index babb6a3f..f2dc65e5 100644 --- a/tests/evals/tasks/test_mmlu.py +++ b/tests/evals/tasks/test_mmlu.py @@ -4,7 +4,7 @@ class MockTaskConfig: - templating_parameters = {"template": "{question}\n\nChoices:\n{choices}"} + templating_parameters = {"template": "{prompt}"} answer_key = "answer" question_key = "question" choices_key = "choices" @@ -17,7 +17,7 @@ class MockTaskConfig: { "question": "What is the capital of France?", "choices": "A) London\nB) Paris\nC) Berlin\nD) Madrid", - "answer": "B", + "answer": 1, }, "The answer is B) Paris", True, @@ -26,9 +26,9 @@ class MockTaskConfig: { "question": "Which element has the atomic number 1?", "choices": "A) Helium\nB) Oxygen\nC) Hydrogen\nD) Carbon", - "answer": "C", + "answer": 2, }, - "C", + "A", False, ), ], @@ -42,8 +42,12 @@ def test_check_correctness(problem, response, expected): "problem, expected", [ ( - {"question": "What is the capital of France?", "answer": "B"}, - "What is the capital of France?\n\nChoices:\nA) London\nB) Paris\nC) Berlin\nD) Madrid", + { + "question": "What is the capital of France?", + "answer": "B", + "choices": ["London", "Paris", "Berlin", "Madrid"], + }, + "What is the capital of France?\nAnswer Choices: (A) London (B) Paris (C) Berlin (D) Madrid", ), ], ) diff --git a/tests/evals/tasks/test_mmlu_pro.py b/tests/evals/tasks/test_mmlu_pro.py index 10f9d61b..82eba007 100644 --- a/tests/evals/tasks/test_mmlu_pro.py +++ b/tests/evals/tasks/test_mmlu_pro.py @@ -4,7 +4,7 @@ class MockTaskConfig: - templating_parameters = {"template": "Question: {question}\n\nChoices:\n{choices}"} + templating_parameters = {"template": "Question: {prompt}"} answer_key = "answer" question_key = "question" choices_key = "choices" @@ -47,12 +47,17 @@ def test_check_correctness(problem, response, expected): ( { "question": "What is the main function of the left ventricle?", - "choices": "A) Pumps blood to the lungs\nB) Pumps blood to the body\nC) Collects blood from the body\nD) Stores blood", + "options": [ + "Pumps blood to the lungs", + "Pumps blood to the body", + "Collects blood from the body", + "Stores blood", + ], "answer": "B", "answer_index": 1, }, - "Question: What is the main function of the left ventricle?\n\nChoices:" - "\nA) Pumps blood to the lungs\nB) Pumps blood to the body\nC) Collects blood from the body\nD) Stores blood", + "Question: What is the main function of the left ventricle?\n" + "Answer Choices: (A) Pumps blood to the lungs (B) Pumps blood to the body (C) Collects blood from the body (D) Stores blood", ), ], )