Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions cookbooks/training_judge_model/grpo/grader_rl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,14 @@ def is_prompt_valid(doc):
if prompt is None and prompt_key in doc:
prompt = doc[prompt_key]

# Keep samples where prompt can't be extracted (safer than dropping data)
# Fallback to top-level 'query' field (common in new training data format)
if prompt is None and "query" in doc and isinstance(doc["query"], str):
prompt = doc["query"]

# Log warning if prompt cannot be extracted (for debugging)
if not prompt or not isinstance(prompt, str):
return True
logger.warning(f"Cannot extract prompt from doc keys: {list(doc.keys())}")
return True # Keep sample as fallback

# Check token length
return len(tokenizer.encode(prompt)) <= max_length
Expand Down Expand Up @@ -406,8 +411,8 @@ def _build_messages(self, example: dict) -> List[dict]:
"""Build chat messages from example - Pointwise mode with text format only."""
messages = []
# Check if example has 'query' directly at top level
if "query" in example and isinstance(example["query"], str) and example["query"]:
query = example["query"]
if "query" in example:
query = example.get("query", "")
messages.append({"role": "user", "content": query})
# Check if example has 'input' key with nested structure
elif "input" in example and isinstance(example["input"], dict) and "query" in example["input"]:
Expand Down Expand Up @@ -611,7 +616,7 @@ def _format_grader_template(self, messages: List[dict], example: dict, grader_pr
query = ""

# Check if example has fields directly at top level
if "query" in example and isinstance(example["query"], str):
if "query" in example:
query = example.get("query", "")
if "context" in example:
context = example.get("context", "")
Expand All @@ -625,10 +630,10 @@ def _format_grader_template(self, messages: List[dict], example: dict, grader_pr
except (json.JSONDecodeError, TypeError, Exception):
pass
elif isinstance(context, dict):
context = context.get("task_context", "")
tool_definitions = context.get("tool_definitions", "")
history = context.get("history", "")
reference_response = example.get("reference_response", "")
context = context.get("task_context", "")
reference_response = example.get("reference", "")
# Extract fields directly from example top level
response = example.get("response", "")
tool_calls = example.get("tool_calls", "")
Expand All @@ -644,9 +649,9 @@ def _format_grader_template(self, messages: List[dict], example: dict, grader_pr
if context:
if isinstance(context, dict):
# Extract fields directly if context is already a dictionary
context = context.get("task_context", "")
tool_definitions = context.get("tool_definitions", "")
history = context.get("history", "")
context = context.get("task_context", "")
elif isinstance(context, str):
try:
# Attempt to parse JSON string into a dictionary
Expand Down Expand Up @@ -736,11 +741,12 @@ def _extract_ground_truth(self, row_dict):
"""Extract pointwise ground truth label with configurable fields."""
try:
score_value = 0
# Check if it's the new JSON structure
if "input" in row_dict and isinstance(row_dict["input"], dict) and "query" in row_dict["input"]:
# New JSON format - extract score value
if "score" in row_dict:
score_value = row_dict["score"]
# Check if score is directly at top level of row_dict
if "score" in row_dict and (
"query" in row_dict
or ("input" in row_dict and isinstance(row_dict["input"], dict) and "query" in row_dict["input"])
):
score_value = row_dict["score"]
else:
# Old format - use original logic
output_key = self.dataset_config.output_field
Expand Down
9 changes: 2 additions & 7 deletions openjudge/runner/grading_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,7 @@ def __init__(
self.show_progress = show_progress
self.executor = executor or SemaphoreResourceExecutor(max_concurrency)
self.enable_timing = enable_timing or timing_collector is not None
self.timing_collector = timing_collector or (
TimingCollector() if self.enable_timing else None
)
self.timing_collector = timing_collector or (TimingCollector() if self.enable_timing else None)

# Handle aggregators
if not aggregators:
Expand Down Expand Up @@ -438,10 +436,7 @@ async def arun(
grader_results[aggregator_name] = [None] * len(dataset)
for i in range(len(dataset)):
grader_results[aggregator_name][i] = aggregator(
{
grader_name: grader_results[grader_name][i]
for grader_name in self.grader_configs.keys()
},
{grader_name: grader_results[grader_name][i] for grader_name in self.grader_configs.keys()},
)
return grader_results

Expand Down