diff --git a/synthetic_data_kit/core/curate.py b/synthetic_data_kit/core/curate.py index ec83ee7a..1cadce81 100644 --- a/synthetic_data_kit/core/curate.py +++ b/synthetic_data_kit/core/curate.py @@ -49,13 +49,28 @@ def curate_qa_pairs( with open(input_path, 'r', encoding='utf-8') as f: data = json.load(f) - # Extract QA pairs + # Extract QA pairs or CoT examples qa_pairs = data.get("qa_pairs", []) + cot_examples = data.get("cot_examples", []) summary = data.get("summary", "") - # If there are no QA pairs or they're already filtered + # Determine which format we're working with + is_cot_format = False + if cot_examples and not qa_pairs: + # Convert CoT examples to QA format for curation + qa_pairs = [] + for example in cot_examples: + qa_pair = { + "question": example.get("question", ""), + "answer": example.get("answer", ""), + "reasoning": example.get("reasoning", "") # Keep reasoning for reference + } + qa_pairs.append(qa_pair) + is_cot_format = True + + # If there are no QA pairs or CoT examples if not qa_pairs: - raise ValueError("No QA pairs found in the input file") + raise ValueError("No QA pairs or CoT examples found in the input file") # Initialize LLM client client = LLMClient( @@ -269,13 +284,34 @@ def curate_qa_pairs( # Convert to conversation format conversations = convert_to_conversation_format(filtered_pairs) - # Create result with filtered pairs - result = { - "summary": summary, - "qa_pairs": filtered_pairs, - "conversations": conversations, - "metrics": metrics - } + # Create result with filtered pairs in the appropriate format + if is_cot_format: + # Convert back to CoT format + filtered_cot_examples = [] + for pair in filtered_pairs: + cot_example = { + "question": pair.get("question", ""), + "reasoning": pair.get("reasoning", ""), + "answer": pair.get("answer", "") + } + # Keep rating if it exists + if "rating" in pair: + cot_example["rating"] = pair["rating"] + filtered_cot_examples.append(cot_example) + + result = { + "summary": summary, + "cot_examples": filtered_cot_examples, + "conversations": conversations, + "metrics": metrics + } + else: + result = { + "summary": summary, + "qa_pairs": filtered_pairs, + "conversations": conversations, + "metrics": metrics + } # Ensure output directory exists os.makedirs(os.path.dirname(output_path), exist_ok=True) diff --git a/tests/unit/test_error_handling.py b/tests/unit/test_error_handling.py index 21289a82..66a4ac24 100644 --- a/tests/unit/test_error_handling.py +++ b/tests/unit/test_error_handling.py @@ -158,11 +158,19 @@ def test_curate_input_validation(patch_config, test_env): { "question": "What is synthetic data?", "answer": "Synthetic data is artificially generated data.", - }, + } + ] + + cot_examples = [ + { + "question": "What is synthetic data?", + "reasoning": "Synthetic data is artificially generated data.", + "answer": "Synthetic data is artificially generated data.", + } ] with tempfile.NamedTemporaryFile(mode="w+", suffix=".json", delete=False) as f: - json.dump({"qa_pairs": qa_pairs}, f) + json.dump({"qa_pairs": qa_pairs, "cot_examples": cot_examples}, f) file_path = f.name # Create temporary output directory @@ -182,7 +190,7 @@ def test_curate_input_validation(patch_config, test_env): curate.curate_qa_pairs(input_path=empty_file_path, output_path=output_path) # Check that the error message is helpful - assert "No QA pairs found" in str(excinfo.value) + assert "No QA pairs or CoT examples found" in str(excinfo.value) finally: # Clean up if os.path.exists(file_path):