feat(mlxlm): per-record prompts for dataset-style agent tasks (GSM8K STaR)#1065
Merged
Conversation
…SM8K) The mlxlm SFT path shared one scenario-level task_prompt across all records, which fits single-task scenarios but not dataset-style agent tasks (GSM8K: each record solves a DIFFERENT problem). records_to_completions now uses a record's own 'prompt' field when present (falling back to the scenario task_prompt), so each completion trains on its own instruction. This is the missing piece for running STaR / ReST-EM over a problem distribution through run_mlxlm_training. Tests: per-record prompt routing + fallback; end-to-end through write_completion_dataset.
… + prompt-aware dedupe (review #1065) [P2] Per-record prompts only helped seed records: _assess_mlxlm collected {strategy, score} without the prompt, and samples_to_records dropped it, so later ReST-EM rounds trained generated GSM8K-style answers against the fallback scenario prompt, not the problem they were scored on. Now: - _assess_mlxlm draws a problem PER sample for agent tasks (so the loop explores the dataset) and collects {prompt, strategy, score} -- the exact problem each answer was scored against. - samples_to_records preserves the per-sample prompt (single-task samples carry none, unaffected). [P2] Curation deduped by strategy only, so two different problems sharing a completion/answer collapsed under dedupe=True (run_self_improving_loop's default). _strategy_key now includes the record's prompt when present: dataset records key on (problem, answer); games/single-task key on strategy alone (prior behavior preserved). Tests: samples_to_records preserves/omits prompt; dedupe keeps distinct problems with the same answer but still collapses same-problem duplicates.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
The
mlxlmSFT path shared one scenario-leveltask_promptacross all training records. That fits single-task scenarios (cap-sets, grid_ctf) but not dataset-style agent tasks like GSM8K, where each record solves a different problem.records_to_completionsnow uses a record's own"prompt"field when present (falling back to the scenariotask_prompt), so each completion trains on its own instruction. This is the missing piece for running STaR / ReST-EM over a problem distribution throughrun_mlxlm_training.Tests: per-record prompt routing + fallback; end-to-end through
write_completion_dataset.Motivation + honest finding
This was built to run a STaR self-improvement loop on GSM8K locally (Qwen2.5-0.5B-4bit): the model samples solutions to train problems, the exact-integer verifier keeps the correct chains, LoRA-SFT on them, repeat; eval on a disjoint held-out test split.
The wiring works end to end (the per-problem prompts thread correctly through training). But the result was a clean negative: held-out accuracy regressed 25% → 15% → 12.5% over two rounds. The likely causes are well-understood and point to scale, not a bug:
So this PR ships the reusable infrastructure (per-record prompts) that the STaR attempt validated, independent of the negative local result. A positive GSM8K self-improvement demo needs the regime STaR actually requires (a 3B+ base, thousands of problems, gentler training, false-positive filtering) — i.e. a GPU/Modal-scale run, not local 0.5B.