-
Notifications
You must be signed in to change notification settings - Fork 45
[PZ COMPETITION] UCI001(liximeng0824) #168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| # Copyright 2024 Bytedance Ltd. and/or its affiliates | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import re | ||
|
|
||
|
|
||
| def extract_solution(solution_str, method="strict"): | ||
| assert method in ["strict", "flexible"] | ||
|
|
||
| if method == "strict": | ||
| # Strict: extract content of the last \boxed{...} (align with evaluator's parser) | ||
| if "boxed" not in solution_str: | ||
| final_answer = None | ||
| else: | ||
| ans = solution_str.split("boxed")[-1] | ||
| if len(ans) == 0: | ||
| final_answer = None | ||
| elif ans[0] == "{": | ||
| stack = 1 | ||
| a = "" | ||
| for c in ans[1:]: | ||
| if c == "{": | ||
| stack += 1 | ||
| a += c | ||
| elif c == "}": | ||
| stack -= 1 | ||
| if stack == 0: | ||
| break | ||
| a += c | ||
| else: | ||
| a += c | ||
| final_answer = a.replace(",", "").replace("$", "") | ||
| else: | ||
| a = ans.split("$")[0].strip() | ||
| final_answer = a.replace(",", "").replace("$", "") | ||
| elif method == "flexible": | ||
| answer = re.findall("(\\-?[0-9\\.\\,]+)", solution_str) | ||
| final_answer = None | ||
| if len(answer) == 0: | ||
| # no reward is there is no answer | ||
| pass | ||
| else: | ||
| invalid_str = ["", "."] | ||
| # find the last number that is not '.' | ||
| for final_answer in reversed(answer): | ||
| if final_answer not in invalid_str: | ||
| break | ||
| return final_answer | ||
|
|
||
|
|
||
| def compute_score(solution_str, ground_truth, method="strict", format_score=0.2, score=1.0): | ||
| """GSM8K reward with \\boxed{} strict parsing and numeric equivalence. | ||
|
|
||
| Scoring policy: | ||
| - Strict format (\\boxed{...}) and numerically correct -> score (default 1.0) | ||
| - Else if strict format present (even if wrong) -> format_score (default 0.2) | ||
| - Else -> 0.0 | ||
| """ | ||
| # Extract strict (format-aware) and flexible answers | ||
| answer_strict = extract_solution(solution_str=solution_str, method="strict") | ||
| has_strict_format = answer_strict is not None | ||
|
|
||
| # Numeric equivalence helper | ||
| def _to_float(x): | ||
| try: | ||
| return float(str(x).replace(",", "").replace("$", "").strip()) | ||
| except Exception: | ||
| return None | ||
|
|
||
| def _num_equal(a, b): | ||
| av = _to_float(a) | ||
| bv = _to_float(b) | ||
| if av is not None and bv is not None: | ||
| return av == bv | ||
| return str(a).strip() == str(b).strip() | ||
|
|
||
| # 1) Strict-correct → full score | ||
| if answer_strict is not None and _num_equal(answer_strict, ground_truth): | ||
| return float(score) | ||
|
|
||
| # 2) Format-only bonus if strict format present | ||
| if has_strict_format: | ||
| return float(0.2) | ||
|
Comment on lines
+93
to
+94
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
| # 3) No reward | ||
| return 0.0 | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,168 @@ | ||||||||||||||||||||
| # Copyright 2024 Bytedance Ltd. and/or its affiliates | ||||||||||||||||||||
| # | ||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||
| # | ||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||
| # | ||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| Preprocess the GSM8k dataset to parquet format | ||||||||||||||||||||
| """ | ||||||||||||||||||||
|
|
||||||||||||||||||||
| import argparse | ||||||||||||||||||||
| import os | ||||||||||||||||||||
| import re | ||||||||||||||||||||
|
|
||||||||||||||||||||
| import datasets | ||||||||||||||||||||
| from glob import glob | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # from verl.utils.hdfs_io import copy, makedirs | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| def extract_solution(solution_str): | ||||||||||||||||||||
| solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str) | ||||||||||||||||||||
| assert solution is not None | ||||||||||||||||||||
| final_solution = solution.group(0) | ||||||||||||||||||||
| final_solution = final_solution.split("#### ")[1].replace(",", "") | ||||||||||||||||||||
| return final_solution | ||||||||||||||||||||
|
Comment on lines
+29
to
+33
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 使用
Suggested change
|
||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||
| parser = argparse.ArgumentParser() | ||||||||||||||||||||
| parser.add_argument("--local_dir", default="/aipilot/ai-platform/datasets/openseek_data/sft_data/Big-Math-RL-Verified-Processed_pri-mid") | ||||||||||||||||||||
| parser.add_argument("--ms_base_dir", default="/aipilot/ai-platform/datasets/openseek_data/sft_data/Big-Math-RL-Verified-Processed") | ||||||||||||||||||||
| parser.add_argument("--hdfs_dir", default=None) | ||||||||||||||||||||
| parser.add_argument("--val_ratio", type=float, default=0.05, help="Validation split ratio (0-1)") | ||||||||||||||||||||
| parser.add_argument("--seed", type=int, default=42, help="Random seed for shuffling/splitting") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| args = parser.parse_args() | ||||||||||||||||||||
|
|
||||||||||||||||||||
| ms_base_dir = args.ms_base_dir | ||||||||||||||||||||
|
|
||||||||||||||||||||
| train_candidates = glob(os.path.join(ms_base_dir, "**", "big-math-rl-verified-processed-train.arrow"), recursive=True) | ||||||||||||||||||||
| # test_candidates = glob(os.path.join(ms_base_dir, "**", "gsm8k-test.arrow"), recursive=True) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| assert len(train_candidates) > 0, f"未在 {ms_base_dir} 下找到 gsm8k-train.arrow" | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 此处的断言消息硬编码为
Suggested change
|
||||||||||||||||||||
| # assert len(test_candidates) > 0, f"未在 {ms_base_dir} 下找到 gsm8k-test.arrow" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| train_data_source = max(train_candidates, key=os.path.getmtime) | ||||||||||||||||||||
| # test_data_source = max(test_candidates, key=os.path.getmtime) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| print(f"[Info] train arrow: {train_data_source}") | ||||||||||||||||||||
| # print(f"[Info] test arrow: {test_data_source}") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| from datasets import Dataset | ||||||||||||||||||||
| train_dataset = Dataset.from_file(train_data_source) | ||||||||||||||||||||
| # test_dataset = Dataset.from_file(test_data_source) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| print(f"[Info] train rows: {train_dataset.num_rows}") | ||||||||||||||||||||
| # print(f"[Info] test rows: {test_dataset.num_rows}") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| assert train_dataset.num_rows > 0, "train 数据为空" | ||||||||||||||||||||
| # assert test_dataset.num_rows > 0, "test 数据为空" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # train_dataset = dataset["train"] | ||||||||||||||||||||
| # test_dataset = dataset["test"] | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # instruction_following = 'Let\'s think step by step and output the final answer after "####".' | ||||||||||||||||||||
| # instruction_following = 'Please reason step by step.\nIn the last line, write the answer after "The answer is:" and don\'t include any other text.' | ||||||||||||||||||||
| instruction_following = 'Please reason step by step, and put your final answer within \\boxed{}.\nQuestion:\n' | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # 先进行过滤,避免在 map 中返回 None 导致错误 | ||||||||||||||||||||
| allowed_sources = {"orca_math", "cn_k12", "gsm8k"} | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _filter_source(ex): | ||||||||||||||||||||
| return ex.get("source") in allowed_sources | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _is_float_convertible_solution(ex): | ||||||||||||||||||||
| val = ex.get("solution") | ||||||||||||||||||||
| if val is None: | ||||||||||||||||||||
| return False | ||||||||||||||||||||
| try: | ||||||||||||||||||||
| float(str(val).strip().replace(",", "")) | ||||||||||||||||||||
| return True | ||||||||||||||||||||
| except Exception: | ||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||
| return False | ||||||||||||||||||||
|
|
||||||||||||||||||||
| print(f"[Info] before filter rows: {train_dataset.num_rows}") | ||||||||||||||||||||
| train_dataset = train_dataset.filter(_filter_source) | ||||||||||||||||||||
| print(f"[Info] after source filter (source in {sorted(list(allowed_sources))}) rows: {train_dataset.num_rows}") | ||||||||||||||||||||
| train_dataset = train_dataset.filter(_is_float_convertible_solution) | ||||||||||||||||||||
| print(f"[Info] after float-convertible solution filter rows: {train_dataset.num_rows}") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # add a row to each data item that represents a unique id | ||||||||||||||||||||
| def make_map_fn(split): | ||||||||||||||||||||
| def process_fn(example, idx): | ||||||||||||||||||||
| # print(example) | ||||||||||||||||||||
| question_raw = example.pop("prompt") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| question = instruction_following + question_raw | ||||||||||||||||||||
|
|
||||||||||||||||||||
| answer_raw = example.pop("solution") | ||||||||||||||||||||
| solution = answer_raw | ||||||||||||||||||||
| data = { | ||||||||||||||||||||
| # "data_source": data_source, | ||||||||||||||||||||
| "data_source": "openai/gsm8k", | ||||||||||||||||||||
| "prompt": [ | ||||||||||||||||||||
| { | ||||||||||||||||||||
| "role": "user", | ||||||||||||||||||||
| "content": question, | ||||||||||||||||||||
| } | ||||||||||||||||||||
| ], | ||||||||||||||||||||
| "ability": "math", | ||||||||||||||||||||
| "reward_model": {"style": "rule", "ground_truth": solution}, | ||||||||||||||||||||
| "extra_info": { | ||||||||||||||||||||
| "split": split, | ||||||||||||||||||||
| "index": idx, | ||||||||||||||||||||
| "answer": answer_raw, | ||||||||||||||||||||
| "question": question_raw, | ||||||||||||||||||||
| }, | ||||||||||||||||||||
| } | ||||||||||||||||||||
| return data | ||||||||||||||||||||
|
|
||||||||||||||||||||
| return process_fn | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # 首先执行统一的 map 处理,随后再进行随机划分 | ||||||||||||||||||||
| processed_dataset = train_dataset.map(function=make_map_fn("trainval"), with_indices=True) | ||||||||||||||||||||
| print(f"[Info] processed rows: {processed_dataset.num_rows}") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # 随机划分 train/validation(默认 10% 为验证集),保证可复现性 | ||||||||||||||||||||
| split_result = processed_dataset.train_test_split(test_size=args.val_ratio, seed=args.seed, shuffle=True) | ||||||||||||||||||||
| train_dataset = split_result["train"] | ||||||||||||||||||||
| val_dataset = split_result["test"] | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # 将 extra_info.split 字段分别标注为 train/validation | ||||||||||||||||||||
| def _set_split_field(split_name): | ||||||||||||||||||||
| def _fn(ex): | ||||||||||||||||||||
| extra = dict(ex.get("extra_info", {})) | ||||||||||||||||||||
| extra["split"] = split_name | ||||||||||||||||||||
| return {"extra_info": extra} | ||||||||||||||||||||
| return _fn | ||||||||||||||||||||
|
|
||||||||||||||||||||
| train_dataset = train_dataset.map(_set_split_field("train")) | ||||||||||||||||||||
| val_dataset = val_dataset.map(_set_split_field("test")) | ||||||||||||||||||||
| print(f"[Info] split rows -> train: {train_dataset.num_rows}, val: {val_dataset.num_rows}") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| local_dir = args.local_dir | ||||||||||||||||||||
| hdfs_dir = args.hdfs_dir | ||||||||||||||||||||
|
|
||||||||||||||||||||
| os.makedirs(local_dir, exist_ok=True) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # 分别输出 train/validation 两个 parquet 文件 | ||||||||||||||||||||
| base_name = "big-math-rl-verified-processed_orca_cnk12_gsm8k_newprompt_2" | ||||||||||||||||||||
| train_output = os.path.join(local_dir, f"{base_name}_train.parquet") | ||||||||||||||||||||
| val_output = os.path.join(local_dir, f"{base_name}_val.parquet") | ||||||||||||||||||||
| train_dataset.to_parquet(train_output) | ||||||||||||||||||||
| val_dataset.to_parquet(val_output) | ||||||||||||||||||||
| print(f"[Info] saved -> train: {train_output}\n[Info] saved -> val : {val_output}") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if hdfs_dir is not None: | ||||||||||||||||||||
| makedirs(hdfs_dir) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| copy(src=local_dir, dst=hdfs_dir) | ||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
函数
compute_score的参数method在函数体内没有被使用,因为extract_solution总是以method="strict"的方式被调用。建议移除这个未使用的参数,以简化函数签名并避免混淆。