Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions tests/evals/tasks/test_aime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest

from skythought.evals.tasks.aime.aime_handler import AIMETaskHandler


class MockTaskConfig:
templating_parameters = {
"template": "Problem: {prompt}\n\nProvide a numerical answer."
}
answer_key = "answer"
question_key = "question"


@pytest.mark.parametrize(
"problem, response, expected",
[
(
{
"question": "Find the sum of the first 10 positive integers.",
"answer": "55",
},
"The sum is 55",
True,
),
(
{
"question": "What is the value of (3^4 - 2^5)?",
"answer": "49",
},
"48",
False,
),
],
)
def test_check_correctness(problem, response, expected):
handler = AIMETaskHandler(task_config=MockTaskConfig)
assert handler.check_correctness(problem, generation=response) == expected


@pytest.mark.parametrize(
"problem, expected",
[
(
{
"question": "Find the sum of the first 10 positive integers.",
"answer": "4",
},
"Problem: Find the sum of the first 10 positive integers.\n\nProvide a numerical answer.",
),
],
)
def test_generate_prompt(problem, expected):
print(problem)
handler = AIMETaskHandler(task_config=MockTaskConfig)
assert handler.generate_prompt(problem) == expected
47 changes: 47 additions & 0 deletions tests/evals/tasks/test_amc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest

from skythought.evals.tasks.amc23.amc23_handler import AMC23TaskHandler


class MockTaskConfig:
templating_parameters = {
"template": "Return the answer to the following: {question}"
}
answer_key = "answer"
question_key = "question"
choices_key = "choices"


@pytest.mark.parametrize(
"problem, response, expected",
[
(
{"question": "2+2", "answer": "4"},
"5",
False,
),
(
{"question": "3* 25 percent", "answer": " 75%"},
"My reply is $0.75.", # ignores dollar signs and normalizes percentages
True,
),
],
)
def test_check_correctness(problem, response, expected):
handler = AMC23TaskHandler(task_config=MockTaskConfig)
print(handler.check_correctness(problem, generation=response))
assert handler.check_correctness(problem, generation=response) == expected


@pytest.mark.parametrize(
"problem, expected",
[
(
{"question": "What is the result of 2+2?", "answer": "4"},
"Return the answer to the following: What is the result of 2+2?",
),
],
)
def test_generate_prompt(problem, expected):
handler = AMC23TaskHandler(task_config=MockTaskConfig)
assert handler.generate_prompt(problem) == expected
21 changes: 20 additions & 1 deletion tests/evals/tasks/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@


class MockTaskConfig:
templating_parameters = {"template": "{question}"}
templating_parameters = {
"template": "Return the answer to the following: {question}"
}
answer_key = "answer"
question_key = "question"

Expand Down Expand Up @@ -42,3 +44,20 @@ def test_check_correctness(
):
handler = MathTaskHandler(task_config=MockTaskConfig)
assert handler.check_correctness(problem, generation=response) == expected


@pytest.mark.parametrize(
"problem, expected",
[
(
{"question": "What is the result of 2+2?", "answer": "4"},
"Return the answer to the following: What is the result of 2+2?",
),
],
)
def test_generate_prompt(
problem,
expected,
):
handler = MathTaskHandler(task_config=MockTaskConfig)
assert handler.generate_prompt(problem) == expected
56 changes: 56 additions & 0 deletions tests/evals/tasks/test_mmlu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest

from skythought.evals.tasks.mmlu.mmlu_handler import MMLUTaskHandler


class MockTaskConfig:
templating_parameters = {"template": "{prompt}"}
answer_key = "answer"
question_key = "question"
choices_key = "choices"


@pytest.mark.parametrize(
"problem, response, expected",
[
(
{
"question": "What is the capital of France?",
"choices": "A) London\nB) Paris\nC) Berlin\nD) Madrid",
"answer": 1,
},
"The answer is B) Paris",
True,
),
(
{
"question": "Which element has the atomic number 1?",
"choices": "A) Helium\nB) Oxygen\nC) Hydrogen\nD) Carbon",
"answer": 2,
},
"A",
False,
),
],
)
def test_check_correctness(problem, response, expected):
handler = MMLUTaskHandler(task_config=MockTaskConfig)
assert handler.check_correctness(problem, generation=response) == expected


@pytest.mark.parametrize(
"problem, expected",
[
(
{
"question": "What is the capital of France?",
"answer": "B",
"choices": ["London", "Paris", "Berlin", "Madrid"],
},
"What is the capital of France?\nAnswer Choices: (A) London (B) Paris (C) Berlin (D) Madrid",
),
],
)
def test_generate_prompt(problem, expected):
handler = MMLUTaskHandler(task_config=MockTaskConfig)
assert handler.generate_prompt(problem) == expected
69 changes: 69 additions & 0 deletions tests/evals/tasks/test_mmlu_pro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest

from skythought.evals.tasks.mmlu.mmlu_handler import MMLUProTaskHandler


class MockTaskConfig:
templating_parameters = {"template": "Question: {prompt}"}
answer_key = "answer"
question_key = "question"
choices_key = "choices"
context_key = "context"


@pytest.mark.parametrize(
"problem, response, expected",
[
(
{
"question": "What is the main function of the left ventricle?",
"choices": "A) Pumps blood to the lungs\nB) Pumps blood to the body\nC) Collects blood from the body\nD) Stores blood",
"answer": "B",
"answer_index": 1,
},
"B) Pumps blood to the body",
True,
),
(
{
"question": "What does GDP stand for?",
"choices": "A) Gross Domestic Product\nB) General Development Plan\nC) Global Distribution Process\nD) Geographic Data Point",
"answer": "A",
"answer_index": 0,
},
"I think it's B",
False,
),
],
)
def test_check_correctness(problem, response, expected):
handler = MMLUProTaskHandler(task_config=MockTaskConfig)
assert handler.check_correctness(problem, generation=response) == expected


@pytest.mark.parametrize(
"problem, expected",
[
(
{
"question": "What is the main function of the left ventricle?",
"options": [
"Pumps blood to the lungs",
"Pumps blood to the body",
"Collects blood from the body",
"Stores blood",
],
"answer": "B",
"answer_index": 1,
},
"Question: What is the main function of the left ventricle?\n"
"Answer Choices: (A) Pumps blood to the lungs (B) Pumps blood to the body (C) Collects blood from the body (D) Stores blood",
),
],
)
def test_generate_prompt(
problem,
expected,
):
handler = MMLUProTaskHandler(task_config=MockTaskConfig)
assert handler.generate_prompt(problem) == expected