Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 46 additions & 10 deletions synthetic_data_kit/core/curate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,28 @@ def curate_qa_pairs(
with open(input_path, 'r', encoding='utf-8') as f:
data = json.load(f)

# Extract QA pairs
# Extract QA pairs or CoT examples
qa_pairs = data.get("qa_pairs", [])
cot_examples = data.get("cot_examples", [])
summary = data.get("summary", "")

# If there are no QA pairs or they're already filtered
# Determine which format we're working with
is_cot_format = False
if cot_examples and not qa_pairs:
# Convert CoT examples to QA format for curation
qa_pairs = []
for example in cot_examples:
qa_pair = {
"question": example.get("question", ""),
"answer": example.get("answer", ""),
"reasoning": example.get("reasoning", "") # Keep reasoning for reference
}
qa_pairs.append(qa_pair)
is_cot_format = True

# If there are no QA pairs or CoT examples
if not qa_pairs:
raise ValueError("No QA pairs found in the input file")
raise ValueError("No QA pairs or CoT examples found in the input file")

# Initialize LLM client
client = LLMClient(
Expand Down Expand Up @@ -269,13 +284,34 @@ def curate_qa_pairs(
# Convert to conversation format
conversations = convert_to_conversation_format(filtered_pairs)

# Create result with filtered pairs
result = {
"summary": summary,
"qa_pairs": filtered_pairs,
"conversations": conversations,
"metrics": metrics
}
# Create result with filtered pairs in the appropriate format
if is_cot_format:
# Convert back to CoT format
filtered_cot_examples = []
for pair in filtered_pairs:
cot_example = {
"question": pair.get("question", ""),
"reasoning": pair.get("reasoning", ""),
"answer": pair.get("answer", "")
}
# Keep rating if it exists
if "rating" in pair:
cot_example["rating"] = pair["rating"]
filtered_cot_examples.append(cot_example)

result = {
"summary": summary,
"cot_examples": filtered_cot_examples,
"conversations": conversations,
"metrics": metrics
}
else:
result = {
"summary": summary,
"qa_pairs": filtered_pairs,
"conversations": conversations,
"metrics": metrics
}

# Ensure output directory exists
os.makedirs(os.path.dirname(output_path), exist_ok=True)
Expand Down
14 changes: 11 additions & 3 deletions tests/unit/test_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,19 @@ def test_curate_input_validation(patch_config, test_env):
{
"question": "What is synthetic data?",
"answer": "Synthetic data is artificially generated data.",
},
}
]

cot_examples = [
{
"question": "What is synthetic data?",
"reasoning": "Synthetic data is artificially generated data.",
"answer": "Synthetic data is artificially generated data.",
}
]

with tempfile.NamedTemporaryFile(mode="w+", suffix=".json", delete=False) as f:
json.dump({"qa_pairs": qa_pairs}, f)
json.dump({"qa_pairs": qa_pairs, "cot_examples": cot_examples}, f)
file_path = f.name

# Create temporary output directory
Expand All @@ -182,7 +190,7 @@ def test_curate_input_validation(patch_config, test_env):
curate.curate_qa_pairs(input_path=empty_file_path, output_path=output_path)

# Check that the error message is helpful
assert "No QA pairs found" in str(excinfo.value)
assert "No QA pairs or CoT examples found" in str(excinfo.value)
finally:
# Clean up
if os.path.exists(file_path):
Expand Down