Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions openseek/competition/pz/UCI001/gsm8k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re


def extract_solution(solution_str, method="strict"):
    """Extract the final answer string from a model response.

    Args:
        solution_str: Full text of the model response.
        method: ``"strict"`` parses the text following the last occurrence of
            ``boxed`` (normally ``\\boxed{...}``); ``"flexible"`` picks the
            last number-like token anywhere in the text.

    Returns:
        The extracted answer with ``,`` and ``$`` removed (strict mode) or the
        raw matched token (flexible mode), or ``None`` when nothing usable is
        found. In strict mode an empty ``\\boxed{}`` yields ``""``, not ``None``.

    Raises:
        AssertionError: If ``method`` is neither ``"strict"`` nor ``"flexible"``.
    """
    assert method in ["strict", "flexible"]

    if method == "strict":
        # Strict: extract content of the last \boxed{...} (align with evaluator's parser)
        if "boxed" not in solution_str:
            return None
        tail = solution_str.split("boxed")[-1]
        if not tail:
            # The response ends right after "boxed": there is nothing to parse.
            return None
        if tail[0] == "{":
            # Walk the text after the opening brace and stop at its matching
            # close, so nested groups such as \frac{1}{2} are kept intact.
            depth = 1
            chars = []
            for c in tail[1:]:
                if c == "{":
                    depth += 1
                elif c == "}":
                    depth -= 1
                    if depth == 0:
                        break
                chars.append(c)
            content = "".join(chars)
        else:
            # Brace-less form (e.g. "\boxed 42$"): take the text up to the next "$".
            content = tail.split("$")[0].strip()
        return content.replace(",", "").replace("$", "")

    # Flexible: the last number-like token that is not a bare "." wins.
    candidates = re.findall("(\\-?[0-9\\.\\,]+)", solution_str)
    invalid_str = ["", "."]
    for candidate in reversed(candidates):
        if candidate not in invalid_str:
            return candidate
    # No reward if there is no answer. Returning None here (instead of letting
    # the loop variable leak out) keeps a lone "." from being reported as one.
    return None


def compute_score(solution_str, ground_truth, method="strict", format_score=0.2, score=1.0):
    """GSM8K reward with \\boxed{} strict parsing and numeric equivalence.

    Scoring policy:
    - Strict format (\\boxed{...}) and numerically correct -> score (default 1.0)
    - Else if strict format present (even if wrong) -> format_score (default 0.2)
    - Else -> 0.0

    Args:
        solution_str: Full text of the model response.
        ground_truth: Reference answer; compared numerically when both sides
            parse as floats, otherwise as stripped strings.
        method: Accepted for interface compatibility but ignored; extraction
            is always performed with ``method="strict"``.
        format_score: Reward for a response that has the strict format but a
            wrong answer.
        score: Reward for a response that has the strict format and is correct.

    Returns:
        The reward as a float.
    """
    answer_strict = extract_solution(solution_str=solution_str, method="strict")

    # 3) No reward: the response carries no strictly formatted answer.
    if answer_strict is None:
        return 0.0

    def _to_float(x):
        """Parse ``x`` as a float after removing ',' and '$'; None if it is not numeric."""
        try:
            return float(str(x).replace(",", "").replace("$", "").strip())
        except ValueError:
            return None

    def _num_equal(a, b):
        """Compare numerically when both values parse, else as stripped strings."""
        av = _to_float(a)
        bv = _to_float(b)
        if av is not None and bv is not None:
            return av == bv
        return str(a).strip() == str(b).strip()

    # 1) Strict-correct -> full score
    if _num_equal(answer_strict, ground_truth):
        return float(score)

    # 2) Format-only bonus. Uses the caller-supplied format_score; a literal
    #    0.2 here would silently ignore the parameter.
    return float(format_score)
168 changes: 168 additions & 0 deletions openseek/competition/pz/UCI001/gsm8k_lxm2_newprompt_trainval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Preprocess the GSM8k dataset to parquet format
"""

import argparse
import os
import re

import datasets
from glob import glob

# from verl.utils.hdfs_io import copy, makedirs
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

在脚本的末尾(166-168行)调用了 makedirs 和 copy 函数,但它们的导入语句在第25行被注释掉了。如果提供了 hdfs_dir 参数,这将导致 NameError。请取消此行的注释以修复该错误。

Suggested change
# from verl.utils.hdfs_io import copy, makedirs
from verl.utils.hdfs_io import copy, makedirs



def extract_solution(solution_str):
    """Return the number that follows the ``#### `` marker in a solution string.

    Args:
        solution_str: Solution text expected to contain ``#### <number>``.

    Returns:
        The first matched number as a string, with thousands separators
        (``,``) removed.

    Raises:
        AssertionError: If no ``#### <number>`` marker is present.
    """
    solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
    if solution is None:
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        raise AssertionError(f"no '#### <number>' answer marker found in: {solution_str!r}")
    # Group 1 is the number itself, so there is no need to re-split the match.
    return solution.group(1).replace(",", "")
Comment on lines +29 to +33
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

使用 solution.group(0) 获取整个匹配(例如 "#### 123"),然后再用 split 来提取数字,这种方式有点迂回且不够健壮。直接使用 solution.group(1) 可以更简洁、直接地获取正则表达式中捕获组匹配到的数字部分。同时,建议在断言失败时提供更有用的错误信息。

Suggested change
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
assert solution is not None
final_solution = solution.group(0)
final_solution = final_solution.split("#### ")[1].replace(",", "")
return final_solution
solution = re.search("#### (\\-?[0-9\\.\\,]+)", solution_str)
assert solution is not None, f"无法在字符串中找到解决方案格式 '#### ...':{solution_str}"
final_solution = solution.group(1).replace(",", "")
return final_solution



if __name__ == "__main__":
    # Script entry point: load the most recently modified training arrow file
    # under --ms_base_dir, keep rows from selected sources whose solution is
    # numeric, convert them to the record layout built in process_fn, split
    # them into train/validation, and write two parquet files to --local_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", default="/aipilot/ai-platform/datasets/openseek_data/sft_data/Big-Math-RL-Verified-Processed_pri-mid")
    parser.add_argument("--ms_base_dir", default="/aipilot/ai-platform/datasets/openseek_data/sft_data/Big-Math-RL-Verified-Processed")
    parser.add_argument("--hdfs_dir", default=None)
    parser.add_argument("--val_ratio", type=float, default=0.05, help="Validation split ratio (0-1)")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for shuffling/splitting")

    args = parser.parse_args()

    ms_base_dir = args.ms_base_dir

    # Only a training arrow file is read; the validation set is carved out of
    # it further below.
    train_arrow_name = "big-math-rl-verified-processed-train.arrow"
    train_candidates = glob(os.path.join(ms_base_dir, "**", train_arrow_name), recursive=True)

    # The message names the file that is actually searched for.
    assert len(train_candidates) > 0, f"未在 {ms_base_dir} 下找到 {train_arrow_name}"

    # Several copies may match the recursive glob; use the most recently modified one.
    train_data_source = max(train_candidates, key=os.path.getmtime)

    print(f"[Info] train arrow: {train_data_source}")

    from datasets import Dataset
    train_dataset = Dataset.from_file(train_data_source)

    print(f"[Info] train rows: {train_dataset.num_rows}")

    assert train_dataset.num_rows > 0, "train 数据为空"

    # Prefix prepended to every raw question to form the user prompt.
    instruction_following = 'Please reason step by step, and put your final answer within \\boxed{}.\nQuestion:\n'

    # Filter first, so the map step below never has to return None for a
    # row that should be dropped.
    allowed_sources = {"orca_math", "cn_k12", "gsm8k"}

    def _filter_source(ex):
        """Keep rows whose "source" field is one of the allowed sources."""
        return ex.get("source") in allowed_sources

    def _is_float_convertible_solution(ex):
        """Keep rows whose "solution" parses as a float once commas are removed."""
        val = ex.get("solution")
        if val is None:
            return False
        try:
            float(str(val).strip().replace(",", ""))
            return True
        except (ValueError, TypeError):
            return False

    print(f"[Info] before filter rows: {train_dataset.num_rows}")
    train_dataset = train_dataset.filter(_filter_source)
    print(f"[Info] after source filter (source in {sorted(allowed_sources)}) rows: {train_dataset.num_rows}")
    train_dataset = train_dataset.filter(_is_float_convertible_solution)
    print(f"[Info] after float-convertible solution filter rows: {train_dataset.num_rows}")

    def make_map_fn(split):
        """Build the per-row converter; ``split`` is recorded in ``extra_info``."""

        def process_fn(example, idx):
            # The raw fields are popped from the incoming row before the new
            # record is built; "prompt" is then re-emitted in chat format.
            question_raw = example.pop("prompt")

            question = instruction_following + question_raw

            answer_raw = example.pop("solution")
            solution = answer_raw
            data = {
                "data_source": "openai/gsm8k",
                "prompt": [
                    {
                        "role": "user",
                        "content": question,
                    }
                ],
                "ability": "math",
                "reward_model": {"style": "rule", "ground_truth": solution},
                "extra_info": {
                    "split": split,
                    # Row index within the filtered dataset, used as an id.
                    "index": idx,
                    "answer": answer_raw,
                    "question": question_raw,
                },
            }
            return data

        return process_fn

    # Convert every row first, then split randomly.
    processed_dataset = train_dataset.map(function=make_map_fn("trainval"), with_indices=True)
    print(f"[Info] processed rows: {processed_dataset.num_rows}")

    # Random train/validation split (validation share is --val_ratio, 5% by
    # default); the fixed seed keeps the split reproducible.
    split_result = processed_dataset.train_test_split(test_size=args.val_ratio, seed=args.seed, shuffle=True)
    train_dataset = split_result["train"]
    val_dataset = split_result["test"]

    def _set_split_field(split_name):
        """Build a mapper that overwrites ``extra_info["split"]`` with ``split_name``."""

        def _fn(ex):
            extra = dict(ex.get("extra_info", {}))
            extra["split"] = split_name
            return {"extra_info": extra}

        return _fn

    # Replace the provisional "trainval" label; validation rows are labelled "test".
    train_dataset = train_dataset.map(_set_split_field("train"))
    val_dataset = val_dataset.map(_set_split_field("test"))
    print(f"[Info] split rows -> train: {train_dataset.num_rows}, val: {val_dataset.num_rows}")

    local_dir = args.local_dir
    hdfs_dir = args.hdfs_dir

    os.makedirs(local_dir, exist_ok=True)

    # Write the train and validation sets to two separate parquet files.
    base_name = "big-math-rl-verified-processed_orca_cnk12_gsm8k_newprompt_2"
    train_output = os.path.join(local_dir, f"{base_name}_train.parquet")
    val_output = os.path.join(local_dir, f"{base_name}_val.parquet")
    train_dataset.to_parquet(train_output)
    val_dataset.to_parquet(val_output)
    print(f"[Info] saved -> train: {train_output}\n[Info] saved -> val : {val_output}")

    if hdfs_dir is not None:
        # Imported here because the module-level import of these helpers is
        # commented out; without it both calls below raise NameError. Keeping
        # the import inside the branch means verl is only needed when
        # --hdfs_dir is given.
        from verl.utils.hdfs_io import copy, makedirs

        makedirs(hdfs_dir)

        copy(src=local_dir, dst=hdfs_dir)
Loading