diff --git a/lmms_eval/tasks/stare/_default_template_yaml b/lmms_eval/tasks/stare/_default_template_yaml
new file mode 100644
index 000000000..6c2eff4b5
--- /dev/null
+++ b/lmms_eval/tasks/stare/_default_template_yaml
@@ -0,0 +1,28 @@
+dataset_path: pangyyyyy/STARE
+
+generation_kwargs:
+  max_new_tokens: 8192
+  temperature: 0.0
+  top_p: 1.0
+  num_beams: 1
+  do_sample: false
+
+output_type: generate_until
+doc_to_visual: !function utils.stare_doc_to_visual
+doc_to_text: !function utils.stare_doc_to_text
+doc_to_target: !function utils.stare_doc_to_target
+process_results: !function utils.stare_process_results
+
+metric_list:
+  - metric: stare_score
+    aggregation: !function utils.stare_aggregate_results
+    higher_is_better: true
+
+dataset_kwargs:
+  token: True
+  cache_dir: STARE
+  force_download: true
+
+metadata:
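+  # strategy: "CoT" appends the step-by-step instruction defined in utils.py; any
+  # other value falls back to the "Directly" (answer-only) instruction.
+  # use_lmms_judge: when True, predictions are graded by an external LLM judge.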
+  strategy: CoT # ['CoT', 'Directly']
+  use_lmms_judge: False
\ No newline at end of file
diff --git a/lmms_eval/tasks/stare/stare.yaml b/lmms_eval/tasks/stare/stare.yaml
new file mode 100644
index 000000000..dfdfa844c
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare.yaml
@@ -0,0 +1,5 @@
+# The official STARE paper uses this configuration
+dataset_name: All
+test_split: train
+task: "stare_full"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_2d_text_instruct.yaml b/lmms_eval/tasks/stare/stare_2d_text_instruct.yaml
new file mode 100644
index 000000000..e1c5f2f84
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_2d_text_instruct.yaml
@@ -0,0 +1,4 @@
+dataset_name: 2d_text_instruct
+test_split: test
+task: "stare_2d_text_instruct"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_2d_text_instruct_vsim.yaml b/lmms_eval/tasks/stare/stare_2d_text_instruct_vsim.yaml
new file mode 100644
index 000000000..2ffe97373
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_2d_text_instruct_vsim.yaml
@@ -0,0 +1,4 @@
+dataset_name: 2d_text_instruct_vsim
+test_split: test
+task: "stare_2d_text_instruct_vsim"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_2d_va.yaml b/lmms_eval/tasks/stare/stare_2d_va.yaml
new file mode 100644
index 000000000..4883b4294
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_2d_va.yaml
@@ -0,0 +1,4 @@
+dataset_name: 2d_va
+test_split: test
+task: "stare_2d_va"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_2d_va_vsim.yaml b/lmms_eval/tasks/stare/stare_2d_va_vsim.yaml
new file mode 100644
index 000000000..c112aa5ec
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_2d_va_vsim.yaml
@@ -0,0 +1,4 @@
+dataset_name: 2d_va_vsim
+test_split: test
+task: "stare_2d_va_vsim"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_3d_text_instruct.yaml b/lmms_eval/tasks/stare/stare_3d_text_instruct.yaml
new file mode 100644
index 000000000..679376fe7
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_3d_text_instruct.yaml
@@ -0,0 +1,4 @@
+dataset_name: 3d_text_instruct
+test_split: test
+task: "stare_3d_text_instruct"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_3d_text_instruct_vsim.yaml b/lmms_eval/tasks/stare/stare_3d_text_instruct_vsim.yaml
new file mode 100644
index 000000000..7e49c5c12
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_3d_text_instruct_vsim.yaml
@@ -0,0 +1,4 @@
+dataset_name: 3d_text_instruct_vsim
+test_split: test
+task: "stare_3d_text_instruct_vsim"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_3d_va.yaml b/lmms_eval/tasks/stare/stare_3d_va.yaml
new file mode 100644
index 000000000..aaeb92b81
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_3d_va.yaml
@@ -0,0 +1,4 @@
+dataset_name: 3d_va
+test_split: test
+task: "stare_3d_va"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_3d_va_vsim.yaml b/lmms_eval/tasks/stare/stare_3d_va_vsim.yaml
new file mode 100644
index 000000000..daaa9750e
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_3d_va_vsim.yaml
@@ -0,0 +1,4 @@
+dataset_name: 3d_va_vsim
+test_split: test
+task: "stare_3d_va_vsim"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_folding_nets.yaml b/lmms_eval/tasks/stare/stare_folding_nets.yaml
new file mode 100644
index 000000000..8bdf9d2c2
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_folding_nets.yaml
@@ -0,0 +1,4 @@
+dataset_name: folding_nets
+test_split: test
+task: "stare_folding_nets"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_folding_nets_vsim.yaml b/lmms_eval/tasks/stare/stare_folding_nets_vsim.yaml
new file mode 100644
index 000000000..cbced3d26
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_folding_nets_vsim.yaml
@@ -0,0 +1,4 @@
+dataset_name: folding_nets_vsim
+test_split: test
+task: "stare_folding_nets_vsim"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_perspective.yaml b/lmms_eval/tasks/stare/stare_perspective.yaml
new file mode 100644
index 000000000..9ad22743f
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_perspective.yaml
@@ -0,0 +1,4 @@
+dataset_name: perspective
+test_split: test
+task: "stare_perspective"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_tangram_puzzle.yaml b/lmms_eval/tasks/stare/stare_tangram_puzzle.yaml
new file mode 100644
index 000000000..60814de3f
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_tangram_puzzle.yaml
@@ -0,0 +1,4 @@
+dataset_name: tangram_puzzle
+test_split: test
+task: "stare_tangram_puzzle"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_tangram_puzzle_vsim.yaml b/lmms_eval/tasks/stare/stare_tangram_puzzle_vsim.yaml
new file mode 100644
index 000000000..67893366a
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_tangram_puzzle_vsim.yaml
@@ -0,0 +1,4 @@
+dataset_name: tangram_puzzle_vsim
+test_split: test
+task: "stare_tangram_puzzle_vsim"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/stare_temporal.yaml b/lmms_eval/tasks/stare/stare_temporal.yaml
new file mode 100644
index 000000000..37b37eddd
--- /dev/null
+++ b/lmms_eval/tasks/stare/stare_temporal.yaml
@@ -0,0 +1,4 @@
+dataset_name: temporal
+test_split: test
+task: "stare_temporal"
+include: _default_template_yaml
diff --git a/lmms_eval/tasks/stare/utils.py b/lmms_eval/tasks/stare/utils.py
new file mode 100644
index 000000000..3c1c16c32
--- /dev/null
+++ b/lmms_eval/tasks/stare/utils.py
@@ -0,0 +1,328 @@
+import logging
+import os
+import re
+from collections import defaultdict
+from pathlib import Path
+
+import yaml
+from latex2sympy2 import latex2sympy
+from sympy import simplify
+from word2number import w2n
+
+from lmms_eval.llm_judge import get_server
+from lmms_eval.llm_judge.protocol import Request, ServerConfig
+
+eval_logger = logging.getLogger("lmms-eval")
+
+dir_name = os.path.dirname(os.path.abspath(__file__))
+
+
+stare_config = {
+    "Strategy_Instruction": {"CoT": "Please solve the problem step by step.", "Directly": "Please ensure that your output only contains the final answer without any additional content (such as intermediate reasoning steps)."},
+    "multi_choice_format": '\n{question}\nAnswer with the option\'s letter from the given choices and put the letter in one "\\boxed{{}}". ',
+}
+
+with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
+    raw_data = f.readlines()
+    safe_data = []
+    for i, line in enumerate(raw_data):
+        # remove function definition since yaml load cannot handle it
+        if "!function" not in line:
+            safe_data.append(line)
+
+    config = yaml.safe_load("".join(safe_data))
+hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/")
+
+# Initialize the LLM judge server
+if config["metadata"]["use_lmms_judge"]:
+    eval_logger.info("Using LMMS judge server for STARE task.")
+    API_TYPE = os.getenv("API_TYPE", "azure")  # Default to azure based on .env
+    # For Azure OpenAI, use DEPLOYMENT_NAME as the model_name
+    DEPLOYMENT_NAME = os.getenv("DEPLOYMENT_NAME", "gpt-4o")
+
+    server_config = ServerConfig(
+        model_name=DEPLOYMENT_NAME,  # Use deployment name for Azure OpenAI
+    )
+    server = get_server(server_name=API_TYPE, config=server_config)
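+    # Note: the judge path assumes provider credentials for the selected API_TYPE
+    # (e.g. the Azure OpenAI key and endpoint) are available in the environment;
+    # only the deployment/model name is configured explicitly here.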
+
+
+# This is taken directly from the official STARE codebase
+def build_query(sample):
+    """
+    Construct a multiple-choice query. Inserts an <image> placeholder if missing.
+    Appends either chain-of-thought (CoT) or direct instruction from config.
+    """
+    question = sample["question"]
+    images = sample.get("images", [])
+    answer = sample["answer"].strip().upper()
+    strategy = config["metadata"]["strategy"]
+
+    # Ensure an <image> placeholder is present if images are provided
+    if "<image>" not in question and images:
+        question += "\n<image>"
+
+    # Fill the template with the question
+    prompt_template = stare_config["multi_choice_format"]
+    filled_prompt = prompt_template.format(question=question)
+
+    # Append the selected strategy instruction
+    if strategy == "CoT":
+        query = f"{filled_prompt}\n{stare_config['Strategy_Instruction']['CoT']}"
+    else:
+        query = f"{filled_prompt}\n{stare_config['Strategy_Instruction']['Directly']}"
+
+    # Build the result dictionary
+    res_dict = {"query": query, "gt_content": answer, **sample}
+    return res_dict
+
+
+def stare_doc_to_text(doc):
+    """Extracts the text prompt from a STARE dataset sample.
+
+    Args:
+        doc (dict): A STARE dataset sample dictionary.
+
+    Returns:
+        str: The input prompt text.
+    """
+    res_dict = build_query(doc)
+    return res_dict["query"]
+
+
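+# Minimal target extractor referenced by `doc_to_target` in _default_template_yaml.
+# A sketch under the assumption that the target is the ground-truth answer string.
+def stare_doc_to_target(doc):
+    """Returns the ground-truth answer for a STARE dataset sample."""
+    return doc["answer"].strip().upper()
+
+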
+def stare_doc_to_visual(doc):
+    """Extracts and converts visual data from a STARE dataset sample.
+
+    This function iterates over the 'images' key in the dataset sample, calling
+    the .convert("RGB") method on each item (presumed to be a PIL/Pillow Image).
+
+    Args:
+        doc (dict): A STARE dataset sample dictionary.
+
+    Returns:
+        list: A list of visual elements (e.g., PIL Images) converted to 'RGB' format.
+    """
+    try:
+        return [visual.convert("RGB") for visual in doc["images"]]
+    except Exception:
+        eval_logger.error(f"Failed to convert visuals to RGB for sample {doc['qid']}.")
+        raise
+
+
+def stare_process_results(doc, results):
+    key_name = "stare_score"
+    for pred in results:
+        res_dict = build_query(doc)
+        gt = doc["answer"]
+        query = res_dict["query"]
+
+        if config["metadata"]["use_lmms_judge"]:
+            # Use LMM judge to evaluate the prediction
+            eval_logger.debug(f"Judging sample {doc['qid']} with prediction: {pred}")
+            submit_prompt = create_test_prompt(score_demo_prompt, doc, pred)
+            try:
+                # Create a Request object for the unified judge API
+                request = Request(messages=[{"role": "user", "content": submit_prompt}], config=server_config)
+
+                # Send the request to the LMM judge server
+                judge_response_obj = server.evaluate(request)
+                judge_response = judge_response_obj.content
+                judge_result = judge_response.strip().lower()
+
+                # Parse the judge result to determine correctness
+                is_correct = "correct" in judge_result and "incorrect" not in judge_result
+
+                stare_submission = {"id": doc["qid"], "query": query, "gt_content": gt, "pred": pred, "category": doc["category"], "judge_response": judge_response, "is_correct": is_correct}
+
+            except Exception as e:
+                eval_logger.error(f"Error using LMM judge: {e}")
+                # Fall back to fast_extract_answer if the judge fails
+                pred_extracted = fast_extract_answer(pred)
+                is_correct = is_equal(pred_extracted, gt)
+
+                stare_submission = {"id": doc["qid"], "query": query, "gt_content": gt, "pred": pred, "category": doc["category"], "judge_error": str(e), "is_correct": is_correct}
+
+        else:
+            # Without an LMM judge, use fast_extract_answer only
+            pred = fast_extract_answer(pred)
+            stare_submission = {"id": doc["qid"], "query": query, "gt_content": gt, "pred": pred, "category": doc["category"], "is_correct": is_equal(pred, gt)}
+    return {key_name: stare_submission}
+
+
+def stare_aggregate_results(results):
+    category_to_eval_samples = defaultdict(list)
+    total_samples = len(results)
+    total_correct = 0
+
+    for sample in results:
+        category = sample["category"]
+
+        # Check if using LMM judge results or traditional evaluation
+        if "is_correct" in sample:
+            # Use LMM judge result
+            is_correct = sample["is_correct"]
+        else:
+            # Use traditional evaluation method
+            is_correct = is_equal(sample["pred"], sample["gt_content"])
+
+        if is_correct:
+            total_correct += 1
+            category_to_eval_samples[category].append(1)
+        else:
+            category_to_eval_samples[category].append(0)
+
+    accuracy = total_correct / total_samples if total_samples > 0 else 0
+    category_accuracies = {category: sum(scores) / len(scores) for category, scores in category_to_eval_samples.items()}
+    print(f"{'Total Samples':<20}: {total_samples}")
+    print(f"{'Total Correct':<20}: {total_correct}")
+    print(f"{'Overall Accuracy':<20}: {accuracy:.4f}")
+    print()
+
+    print(f"{'Per-Category Accuracy':<40}")
+    print("-" * 40)
+    for category, acc in category_accuracies.items():
+        print(f"{category:<20}: {acc:.4f}")
+    print("=" * 40)
+    return accuracy
+
+
+#################################################
+# Helper functions from the official STARE repo.
+#################################################
+
+
+def extract_full_boxed_content(s):
+    r"""
+    Extract the full content inside \boxed{}, handling nested braces {{}} properly.
+    """
+    results = []
+
+    i = 0
+    while i < len(s):
+        if s[i : i + 7] == r"\boxed{":
+            brace_stack = []
+            start = i + 7
+            i = start
+
+            while i < len(s):
+                if s[i] == "{":
+                    brace_stack.append(i)
+                elif s[i] == "}":
+                    if brace_stack:
+                        brace_stack.pop()
+                    else:
+                        results.append(s[start:i])
+                        break
+                i += 1
+        i += 1
+
+    return results
+
+
+def is_number(s):
+    try:
+        float(s)
+        return True
+    except ValueError:
+        return False
+
+
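+# Answer comparison used below: exact string match first, then a word-to-number
+# pass via word2number (e.g. "three" vs "3"), then a LaTeX/sympy check that
+# compares rounded numeric values and the simplified symbolic difference.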
+def is_equal(md_ans, gt_ans):
+
+    md_ans = md_ans.lower()
+    gt_ans = gt_ans.lower()
+
+    if md_ans.strip() == gt_ans.strip():
+        return True
+
+    try:
+        md_ans_cache = str(w2n.word_to_num(md_ans))
+        if md_ans_cache.strip() == gt_ans.strip():
+            return True
+    except ValueError:
+        pass
+
+    # For Math
+    try:
+        # Parse LaTeX expressions into sympy and compare numerical values
+        md_sympy = latex2sympy(md_ans)
+        gt_sympy = latex2sympy(gt_ans)
+
+        # Compare evaluated results, rounded to 2 decimal places
+        if round(float(md_sympy.evalf()), 2) == round(float(gt_sympy.evalf()), 2):
+            return True
+
+        # Additionally, compare simplified symbolic expressions
+        if simplify(md_sympy - gt_sympy) == 0:
+            return True
+    except Exception:
+        pass  # Ignore parsing errors or evaluation failures
+
+    return False
+
+
+def fast_extract_answer(response):
+    response = response.strip()
+    for ch in "ABCDEFGH":
+        if response.upper() == ch or response.startswith(f"{ch}:") or response.startswith(f"{ch}."):
+            return ch
+
+    # Direct Strategy Open-ended
+    if is_number(response):
+        return response
+
+    # CoT strategy
+    if "boxed{" in response:
+        try:
+            model_answers = extract_full_boxed_content(response)
+            if model_answers:
+                try:
+                    text_content = re.findall(r"\\text{(.*?)}", model_answers[-1])
+                    if text_content:
+                        return text_content[-1].strip()
+                except Exception:
+                    pass
+                return model_answers[-1].strip()
+        except Exception:
+            pass
+
+    for flag in ["final answer is", "correct answer is", "answer should be", "answer is", "answer:"]:
+        if flag in response.lower():
+            try:
+                model_answer = response.lower().split(flag)[-1].strip()
+                return model_answer.split("\n")[0].split(".")[0]
+            except Exception:
+                pass
+
+    return ""
+
+
+def create_test_prompt(score_prompt, problem, pred):
+    score_prompt = score_prompt.strip()
+    response = pred
+    answer = problem["answer"]
+    full_prompt = f"{score_prompt}\n" + f"Response: {response}\n" + f"Answer: {answer}\n" + "Correct_or_not:"
+    return full_prompt
+
+
+score_demo_prompt = """Please read the following example. Then determine whether the response is correct and type it
+at the end of the prompt. It is worth noting that the final answer in the response is usually in \\boxed{},
+You only need to compare the final answer in the response with the answer, without considering the logical
+correctness of the response itself.
+
+Response: The correct answer is:\n\nA
+
+Answer: A
+
+Correct_or_not: Correct
+
+Response: The correct option is:\n\n\\[\n\\boxed{E}\n\\]
+
+Answer: C
+
+Correct_or_not: Incorrect
+"""
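+
+
+# Example invocation (model name is a placeholder; flags follow the standard
+# lmms-eval CLI):
+#   python -m lmms_eval --model <model_name> --tasks stare_full --batch_size 1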