Skip to content

Commit 157edd5

Browse files
authored
feat: update the grader training data format to get query and other fields directly from the top (#161)
* feat: update the grader training data format to get query and other fields directly from the top
* feat: fix test_minimax_chat_model
1 parent 08fb1e0 commit 157edd5

File tree

2 files changed

+73
-49
lines changed

2 files changed

+73
-49
lines changed

cookbooks/training_judge_model/grpo/grader_rl_dataset.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -405,21 +405,14 @@ def _parse_config(self, config: Union[DictConfig, Dict[str, Any]]):
405405
def _build_messages(self, example: dict) -> List[dict]:
406406
"""Build chat messages from example - Pointwise mode with text format only."""
407407
messages = []
408-
409-
# Check if it's the new JSON structure (has 'input' key with nested structure)
410-
if "input" in example and isinstance(example["input"], dict) and "query" in example["input"]:
411-
# New JSON format
408+
# Check if example has 'query' directly at top level
409+
if "query" in example and isinstance(example["query"], str) and example["query"]:
410+
query = example["query"]
411+
messages.append({"role": "user", "content": query})
412+
# Check if example has 'input' key with nested structure
413+
elif "input" in example and isinstance(example["input"], dict) and "query" in example["input"]:
412414
query = example["input"].get("query", "")
413-
if query:
414-
messages.append({"role": "user", "content": query})
415-
416-
# Get chosen response (positive example)
417-
if "chosen" in example and isinstance(example["chosen"], dict):
418-
response_data = example["chosen"].get("response", {})
419-
if isinstance(response_data, dict):
420-
response_content = response_data.get("content", "")
421-
if response_content:
422-
messages.append({"role": "assistant", "content": response_content})
415+
messages.append({"role": "user", "content": query})
423416
else:
424417
# Old format - handle standard structure
425418
messages = self._build_old_format_messages(example)
@@ -615,8 +608,37 @@ def _format_grader_template(self, messages: List[dict], example: dict, grader_pr
615608
memory = ""
616609
action = ""
617610
reflection = ""
618-
if "input" in example and isinstance(example["input"], dict) and "query" in example["input"]:
619-
# New JSON format
611+
query = ""
612+
613+
# Check if example has fields directly at top level
614+
if "query" in example and isinstance(example["query"], str):
615+
query = example.get("query", "")
616+
if "context" in example:
617+
context = example.get("context", "")
618+
if isinstance(context, str):
619+
try:
620+
parsed_data = json.loads(context)
621+
if isinstance(parsed_data, dict):
622+
context = parsed_data.get("task_context", "")
623+
tool_definitions = parsed_data.get("tool_definitions", "")
624+
history = parsed_data.get("history", "")
625+
except (json.JSONDecodeError, TypeError, Exception):
626+
pass
627+
elif isinstance(context, dict):
628+
context = context.get("task_context", "")
629+
tool_definitions = context.get("tool_definitions", "")
630+
history = context.get("history", "")
631+
reference_response = example.get("reference_response", "")
632+
# Extract fields directly from example top level
633+
response = example.get("response", "")
634+
tool_calls = example.get("tool_calls", "")
635+
tool_responses = example.get("tool_responses", "")
636+
plan = example.get("plan", "")
637+
observation = example.get("observation", "")
638+
memory = example.get("memory", "")
639+
action = example.get("action", "")
640+
reflection = example.get("reflection", "")
641+
elif "input" in example and isinstance(example["input"], dict) and "query" in example["input"]:
620642
query = example["input"].get("query", "")
621643
context = example["input"].get("context", "")
622644
if context:

cookbooks/training_judge_model/grpo/pointwise/utils/preprocess_grader_data.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -136,51 +136,53 @@ def process_single_file(data_file: str, split_ratio: float, seed: int, sample_nu
136136
is_bin = True
137137

138138
try:
139-
if (
140-
item["chosen"]
141-
and item["chosen"]["response"]
142-
and "tool_calls" in item["chosen"]["response"]
143-
and isinstance(item["chosen"]["response"].get("tool_calls", []), list)
144-
):
145-
item["chosen"]["response"]["tool_calls"] = json.dumps(item["chosen"]["response"]["tool_calls"])
146-
147-
if (
148-
item["rejected"]
149-
and item["rejected"]["response"]
150-
and "tool_calls" in item["rejected"]["response"]
151-
and isinstance(item["rejected"]["response"].get("tool_calls", []), list)
152-
):
153-
item["rejected"]["response"]["tool_calls"] = json.dumps(item["rejected"]["response"]["tool_calls"])
139+
# Process chosen response
140+
if item["chosen"] and item["chosen"].get("response"):
141+
chosen_response = item["chosen"]["response"]
142+
if "tool_calls" in chosen_response and isinstance(chosen_response.get("tool_calls", []), list):
143+
chosen_response["tool_calls"] = json.dumps(chosen_response["tool_calls"])
144+
if "content" in chosen_response:
145+
chosen_response["response"] = chosen_response.pop("content")
146+
147+
# Process rejected response
148+
if item["rejected"] and item["rejected"].get("response"):
149+
rejected_response = item["rejected"]["response"]
150+
if "tool_calls" in rejected_response and isinstance(rejected_response.get("tool_calls", []), list):
151+
rejected_response["tool_calls"] = json.dumps(rejected_response["tool_calls"])
152+
if "content" in rejected_response:
153+
rejected_response["response"] = rejected_response.pop("content")
154154

155155
if item["input"] and item["input"].get("context", "") and not isinstance(item["input"]["context"], str):
156156
item["input"]["context"] = json.dumps(item["input"]["context"])
157157
except Exception as e:
158158
raise e
159+
160+
# Create new_item with response key removed
159161
if "chosen" not in item or not item["chosen"]:
160162
print(f"Warning: Missing chosen answer in item {item} from {data_file}. Skipping item.")
161163
else:
162-
output_data.append(
163-
{
164-
"input": item["input"],
165-
"answer": item["chosen"],
166-
"label": 1, # positive example
167-
"score": 1.0 if is_bin else 5.0,
168-
"task_type": task_type,
169-
}
170-
)
164+
chosen_response = item["chosen"].get("response", item["chosen"])
165+
new_item_chosen = {
166+
**item["input"],
167+
**chosen_response,
168+
"label": 1, # positive example
169+
"score": 1.0 if is_bin else 5.0,
170+
"task_type": task_type,
171+
}
172+
output_data.append(new_item_chosen)
171173

172174
if "rejected" not in item or not item["rejected"]:
173175
print(f"Warning: Missing rejected answer in item {item} from {data_file}. Skipping item.")
174176
else:
175-
output_data.append(
176-
{
177-
"input": item["input"],
178-
"answer": item["rejected"],
179-
"label": 0, # negative example
180-
"score": 0.0 if is_bin else 1.0,
181-
"task_type": task_type,
182-
}
183-
)
177+
rejected_response = item["rejected"].get("response", item["rejected"])
178+
new_item_rejected = {
179+
**item["input"],
180+
**rejected_response,
181+
"label": 0, # negative example
182+
"score": 0.0 if is_bin else 1.0,
183+
"task_type": task_type,
184+
}
185+
output_data.append(new_item_rejected)
184186
except KeyError as e:
185187
print(f"Error: Missing required key {e} in file {data_file}")
186188
return False

0 commit comments

Comments
 (0)