Refactor mdlm, bd3lm, ppl tracking, and minor improvements #68
* refactor train args
* Initial plan
* Fix training arguments import and inheritance chain bugs
  - Fix incorrect import of TrainingArguments in mdlm.py (was importing from dllm.utils.data, should be dllm.utils.configs)
  - Fix DreamConfig.__post_init__ to call super().__post_init__() to ensure proper inheritance chain
  - Fix typo in sl.py: 'arg=args' -> 'args=args' for MDLMTrainer.__init__ and BD3LMTrainer.__init__
  - Fix BD3LMAnDSLConfig to inherit from BD3LMTrainer.BD3LMConfig instead of MDLMTrainer.MDLMConfig (to include block_size field)
  Co-authored-by: ZHZisZZ <[email protected]>
* Move scheduler from MDLMConfig to __init__ kwarg
  - Remove scheduler field from MDLMConfig dataclass
  - Add scheduler as keyword argument in MDLMTrainer.__init__ with default LinearAlphaScheduler()
  - Propagate scheduler parameter through BD3LMTrainer.__init__
  - Propagate scheduler parameter through MDLMAnDSLTrainer.__init__ and BD3LMAnDSLTrainer.__init__
  - DreamTrainer inherits the scheduler kwarg automatically from MDLMTrainer
  Co-authored-by: ZHZisZZ <[email protected]>
* Simplify: let inherited trainers pass scheduler via **kwargs
  Removed unnecessary explicit scheduler parameter from inherited trainers (BD3LMTrainer, MDLMAnDSLTrainer, BD3LMAnDSLTrainer) since **kwargs passes it through automatically to MDLMTrainer.__init__.
  Co-authored-by: ZHZisZZ <[email protected]>
* Align EditFlowTrainer with other trainers' pattern
  - Create EditFlowConfig nested class inside EditFlowTrainer with config fields
  - Move time_epsilon, normalize_per_position, max_w from __init__ kwargs to config
  - Update EditFlowTrainer.__init__ to accept args: EditFlowConfig
  - Update examples/editflow/sft.py and pt.py to inherit from EditFlowTrainer.EditFlowConfig
  - Simplify trainer instantiation to just pass args (config handles the rest)
  Co-authored-by: ZHZisZZ <[email protected]>
* Add scheduler_cls back to EditFlowConfig for user customization
  Restored scheduler_cls field to EditFlowConfig so users can specify which kappa scheduler to use (e.g., "LinearKappaScheduler", "CubicKappaScheduler"). The scheduler is created from args.scheduler_cls if no scheduler is passed directly.
  Co-authored-by: ZHZisZZ <[email protected]>
* Move scheduler_cls to example scripts' TrainingArguments
  Keep EditFlowConfig clean with only trainer-specific params. The scheduler_cls field is now in the example scripts' TrainingArguments (which inherits from EditFlowConfig), and is used to create the scheduler that is passed to the EditFlowTrainer.
  Co-authored-by: ZHZisZZ <[email protected]>
* Initial plan
* Clean up unused imports and redundant boolean comparison (#5)
  * Initial plan
  * Address review comments: remove unused imports and fix redundant comparison
    Co-authored-by: ZHZisZZ <[email protected]>
  ---------
  Co-authored-by: copilot-swe-agent[bot] <[email protected]>
  Co-authored-by: ZHZisZZ <[email protected]>
---------
Co-authored-by: copilot-swe-agent[bot] <[email protected]>
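A minimal sketch of the two patterns these commits describe: nested config dataclasses that chain `__post_init__` through `super()`, and a `scheduler` keyword argument that subclasses forward via `**kwargs`. The class names echo the commit messages, but the field names (`learning_rate`, `time_epsilon`, `validated`), the validation rules, and the stand-in `LinearAlphaScheduler` are assumptions for illustration, not the project's actual code.

```python
from dataclasses import dataclass


class LinearAlphaScheduler:
    """Stand-in for the project's scheduler class (hypothetical body)."""


@dataclass
class BaseConfig:
    """Hypothetical base config that validates itself in __post_init__."""
    learning_rate: float = 1e-4
    validated: bool = False

    def __post_init__(self):
        if self.learning_rate <= 0:
            raise ValueError("learning_rate must be positive")
        self.validated = True


class MDLMTrainer:
    @dataclass
    class MDLMConfig(BaseConfig):
        time_epsilon: float = 1e-3  # illustrative field only

    def __init__(self, args, scheduler=None):
        self.args = args
        # The scheduler is a constructor kwarg with a default, not a config field.
        self.scheduler = scheduler if scheduler is not None else LinearAlphaScheduler()


class BD3LMTrainer(MDLMTrainer):
    @dataclass
    class BD3LMConfig(MDLMTrainer.MDLMConfig):
        block_size: int = 32

        def __post_init__(self):
            # Without this call, BaseConfig's validation would be skipped.
            super().__post_init__()
            if self.block_size <= 0:
                raise ValueError("block_size must be positive")

    def __init__(self, args, **kwargs):
        # `scheduler` (if given) travels through **kwargs to MDLMTrainer.__init__.
        super().__init__(args=args, **kwargs)
        self.block_size = args.block_size


cfg = BD3LMTrainer.BD3LMConfig(block_size=16)
custom_scheduler = LinearAlphaScheduler()
trainer = BD3LMTrainer(args=cfg, scheduler=custom_scheduler)
```

Because `BD3LMTrainer.__init__` never names `scheduler`, adding further keyword arguments to the base trainer does not require touching the subclasses.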
* fix bd3lm sampler
* andi temp save
* andi pt
* temp save
* Trainer refactor (#3)
* fix editflow scripts lacking main
* update
* Fix critical bugs in trainer metrics refactoring (#67)
  * Initial plan
  * Fix critical bugs in trainer refactoring: shape mismatch, division by zero, and data integrity
    Co-authored-by: ZHZisZZ <[email protected]>
  * Ensure dtype consistency in meter.update() calls
    Co-authored-by: ZHZisZZ <[email protected]>
  ---------
  Co-authored-by: copilot-swe-agent[bot] <[email protected]>
  Co-authored-by: ZHZisZZ <[email protected]>
* update
---------
Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Summary of Changes
Hello @ZHZisZZ, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.
This pull request is a significant refactoring of the training and sampling components, aimed at improving code structure, consistency, and maintainability. Key changes include standardizing trainer configurations using dataclasses and overhauling the metric logging system.
@copilot Please check for potential bugs and verify whether the changes are functionally equivalent before and after.
Code Review
This pull request introduces a significant and high-quality refactoring across the codebase. Key improvements include:
- A shift to structured dataclass-based configurations for trainers, enhancing maintainability.
- The implementation of a new, robust metrics system using torchmetrics, replacing the older manual approach.
- Cleaner and more efficient loss computation logic in the trainers.
- Numerous code cleanups, such as function renaming, deduplication, and making example scripts self-contained.
- Correctness improvements in the block diffusion sampler for batch generation.
The overall changes greatly improve the quality, clarity, and maintainability of the code. I have a few suggestions for further improvement, including a potential correctness issue regarding a removed chat template and a couple of minor style fixes.
    elif issubclass(model_cls, (A2DQwen2LMHeadModel, A2DQwen3LMHeadModel)):
        tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
        tokenizer.eot_token = "<|im_end|>"
        tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
This refactoring combines the logic for A2DQwen2LMHeadModel and A2DQwen3LMHeadModel, which is great for deduplication. However, the very long and specific chat_template that was previously set for A2DQwen3LMHeadModel has been removed. This template is crucial for the model's chat and tool-use functionality. Its removal will likely lead to incorrect behavior. If this template is not being set by some other mechanism (e.g., from the tokenizer's config on the Hub), it should be restored, perhaps conditionally for A2DQwen3LMHeadModel.
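One way to act on this comment is to restore the template only when the tokenizer does not already carry one, so a template loaded from the Hub config is never overwritten. The sketch below assumes that approach. `FakeTokenizer` is a stand-in for a Hugging Face tokenizer (which exposes a `chat_template` attribute), `ensure_chat_template` is a hypothetical helper name, and the template string is a placeholder rather than the removed Qwen3 template.

```python
class FakeTokenizer:
    """Stand-in for a Hugging Face tokenizer; models only `chat_template`."""

    def __init__(self, chat_template=None):
        self.chat_template = chat_template


# Placeholder only: the real template removed in this PR is much longer.
FALLBACK_TEMPLATE = "{% for message in messages %}{{ message['content'] }}{% endfor %}"


def ensure_chat_template(tokenizer, fallback_template):
    """Set `fallback_template` only if the tokenizer has no chat template.

    Returns True when the fallback was applied, False when an existing
    template was left untouched.
    """
    if not getattr(tokenizer, "chat_template", None):
        tokenizer.chat_template = fallback_template
        return True
    return False


bare = FakeTokenizer()
from_hub = FakeTokenizer(chat_template="template loaded from the Hub")
applied_to_bare = ensure_chat_template(bare, FALLBACK_TEMPLATE)
applied_to_hub = ensure_chat_template(from_hub, FALLBACK_TEMPLATE)
```

In the refactored branch this check could be gated on `issubclass(model_cls, A2DQwen3LMHeadModel)` so that Qwen2 models keep their current behavior.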
    # if done.any():
    #     new_block = torch.where(
    #         done.unsqueeze(1),
    #         torch.full_like(new_block, pad_id),
    #         new_block,
    #     )
This commented-out block is a valuable optimization. By filling new blocks with pad_id for sequences that are already done, you can avoid wasteful computation during the diffusion steps. The current implementation processes these completed sequences unnecessarily. I recommend re-enabling this logic to improve efficiency in batch generation.
    - # if done.any():
    -     # new_block = torch.where(
    -     #     done.unsqueeze(1),
    -     #     torch.full_like(new_block, pad_id),
    -     #     new_block,
    -     # )
    + if done.any():
    +     new_block = torch.where(
    +         done.unsqueeze(1),
    +         torch.full_like(new_block, pad_id),
    +         new_block,
    +     )
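To make the suggested masking concrete, here is a small self-contained sketch of the same `torch.where` pattern on toy tensors. The ids (`pad_id = 0`, `mask_id = 99`), the shapes, and the helper name `pad_finished_rows` are assumptions for illustration. The sketch shows only the row-wise replacement; whether it actually saves compute depends on how the later diffusion steps treat pad positions.

```python
import torch


def pad_finished_rows(new_block: torch.Tensor, done: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Replace every row of `new_block` whose sequence is finished with pad tokens.

    new_block: [batch, block_size] token ids for the block about to be appended.
    done:      [batch] boolean flags, True where the sequence already emitted EOS.
    """
    if done.any():
        new_block = torch.where(
            done.unsqueeze(1),                   # [batch, 1], broadcast over the block
            torch.full_like(new_block, pad_id),  # all-pad replacement block
            new_block,                           # unfinished rows stay unchanged
        )
    return new_block


pad_id, mask_id = 0, 99  # hypothetical token ids
done = torch.tensor([True, False, False])
block = torch.full((3, 4), mask_id, dtype=torch.long)
padded = pad_finished_rows(block, done, pad_id)
```

Only the first row is replaced; the other two rows keep their mask tokens and continue to be denoised.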
dllm/core/trainers/utils/meters.py
Outdated
    # Where you previously passed nll_sum / token_cnt, pass it like this instead:
    # meter.update("train", value=(nll_sum / token_cnt.clamp_min(1)), weight=token_cnt)
    # meter.update("eval", value=(nll_sum / token_cnt.clamp_min(1)), weight=token_cnt)
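The commented calls above pass a per-batch mean NLL as `value` and the token count as `weight`. The sketch below shows why that pairing matters: a weighted running mean recovers the token-level mean over all batches, while a plain average of batch means does not. `WeightedMeanMeter` is a hypothetical pure-Python stand-in, not the project's meter, and it omits the "train"/"eval" split argument.

```python
class WeightedMeanMeter:
    """Running weighted mean: sum(value * weight) / sum(weight)."""

    def __init__(self):
        self.weighted_sum = 0.0
        self.total_weight = 0.0

    def update(self, value, weight):
        self.weighted_sum += float(value) * float(weight)
        self.total_weight += float(weight)

    def compute(self):
        if self.total_weight == 0:
            return float("nan")  # nothing accumulated yet
        return self.weighted_sum / self.total_weight


# (nll_sum, token_cnt) per batch; the last batch has no maskable tokens.
batches = [(6.0, 3), (4.0, 1), (0.0, 0)]

meter = WeightedMeanMeter()
batch_means = []
for nll_sum, token_cnt in batches:
    mean_nll = nll_sum / max(token_cnt, 1)  # same role as clamp_min(1)
    batch_means.append(mean_nll)
    meter.update(value=mean_nll, weight=token_cnt)

token_level_mean = meter.compute()                 # 10.0 / 4 tokens
naive_mean = sum(batch_means) / len(batch_means)   # treats every batch equally
```

The clamp keeps the empty batch from dividing by zero, and its zero weight keeps it from affecting the result.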
dllm/core/trainers/utils/metrics.py
Outdated
    class PerplexityMetric(NLLMetric):
        def compute(self) -> torch.Tensor:
            mean_nll = super().compute()
            return torch.exp(mean_nll)
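The snippet above derives perplexity as the exponential of the mean NLL returned by its parent class. The parent `NLLMetric` is not shown in this thread, so the following is a pure-Python sketch under the assumption that it accumulates a summed NLL and a token count. The `update` signature here is hypothetical, and `math.exp` stands in for `torch.exp`.

```python
import math


class NLLMetric:
    """Assumed behavior: accumulate summed NLL and token count, report the mean."""

    def __init__(self):
        self.nll_sum = 0.0
        self.token_cnt = 0

    def update(self, nll_sum, token_cnt):
        self.nll_sum += float(nll_sum)
        self.token_cnt += int(token_cnt)

    def compute(self):
        # Guard against an empty accumulator.
        return self.nll_sum / max(self.token_cnt, 1)


class PerplexityMetric(NLLMetric):
    def compute(self):
        # Perplexity is exp(mean negative log-likelihood per token).
        return math.exp(super().compute())


ppl_metric = PerplexityMetric()
ppl_metric.update(nll_sum=2 * math.log(2), token_cnt=2)  # two tokens, each with p = 1/2
ppl_metric.update(nll_sum=2 * math.log(8), token_cnt=2)  # two tokens, each with p = 1/8
perplexity = ppl_metric.compute()
```

The mean NLL over the four tokens is `2 * ln 2`, so the perplexity is the geometric mean of the inverse probabilities.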
…and BD3LM sampler improvements (#70)
* Refactor mdlm, bd3lm, ppl tracking, and minor improvements (#68)
  * update (#66)
* Remove unused utils
* minor fix
* fix meters and others
* minor fix
* minor fix
---------
Co-authored-by: copilot-swe-agent[bot] <[email protected]>
This pull request refactors and improves the BD3LM and MDLM trainer and sampler code, focusing on configuration management, loss computation, and sampling robustness. The changes introduce dataclass-based configuration, clarify and correct masking logic, and enhance per-sequence EOS handling in the sampler. Additionally, metric tracking is unified and modernized. The most important changes are grouped below:
Configuration and Trainer Refactoring:
- … MDLMConfig and BD3LMConfig dataclasses for structured configuration of trainers, replacing multiple positional arguments and improving clarity and maintainability (dllm/core/trainers/mdlm.py, dllm/core/trainers/bd3lm.py). [1] [2]
- … OnEvaluateMetricsCallback instead of the previous EpochPPLMeter (dllm/core/trainers/mdlm.py, dllm/core/trainers/bd3lm.py). [1] [2]
Loss Computation and Masking Logic:
- … (masked_indices to masked_mask, and token_cnt_per_seq to maskable_mask), and updated all related logic to use these clearer names. This change ensures consistency and correctness in how loss is computed and normalized (dllm/core/trainers/mdlm.py, dllm/core/trainers/bd3lm.py). [1] [2] [3] [4] [5]
- … (dllm/core/trainers/bd3lm.py).
- … loss_norm_type and fixed normalization logic for batch, sequence, and token modes (dllm/core/trainers/bd3lm.py).
Sampler Improvements:
- … (dllm/core/samplers/bd3lm.py). [1] [2]
- … (dllm/core/samplers/bd3lm.py). [1] [2]
- … (build_staircase_attention_mask to _prepare_for_sampling, diffusion_step_block to _diffusion_step_block) (dllm/core/samplers/bd3lm.py). [1] [2]
Documentation and API Consistency:
- … README.md to use the correct mapping function (dllm.utils.default_sft_map_fn) for consistency with code changes (README.md). [1] [2]
Other Code Quality Improvements:
- … (dllm/core/trainers/bd3lm.py). [1] [2]
- … (dllm/core/trainers/bd3lm.py). [1] [2] [3]
These changes collectively modernize the codebase, improve maintainability, and ensure more robust and correct training and sampling behavior.
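The summary mentions `loss_norm_type` with batch, sequence, and token modes but does not define them. The sketch below assumes one common reading of those three names; the helper `normalize_loss`, its signature, and the exact formulas are assumptions, not the code in dllm/core/trainers/bd3lm.py.

```python
def normalize_loss(token_nll, maskable_mask, loss_norm_type):
    """Reduce per-token losses to a scalar under an assumed normalization mode.

    token_nll:     list of rows, per-token loss values, shape [batch][seq].
    maskable_mask: same shape, True where the token counts toward the loss.
    """
    batch_size = len(token_nll)
    # Per-sequence loss sums and token counts over the maskable positions.
    seq_sums = [
        sum(loss for loss, keep in zip(row, mask_row) if keep)
        for row, mask_row in zip(token_nll, maskable_mask)
    ]
    seq_cnts = [sum(1 for keep in mask_row if keep) for mask_row in maskable_mask]
    total = sum(seq_sums)

    if loss_norm_type == "batch":
        # Assumed: total loss divided by the number of sequences.
        return total / batch_size
    if loss_norm_type == "sequence":
        # Assumed: per-sequence token mean, then averaged across sequences.
        return sum(s / max(c, 1) for s, c in zip(seq_sums, seq_cnts)) / batch_size
    if loss_norm_type == "token":
        # Assumed: total loss divided by the total number of counted tokens.
        return total / max(sum(seq_cnts), 1)
    raise ValueError(f"unknown loss_norm_type: {loss_norm_type!r}")


token_nll = [[1.0, 3.0, 0.0], [5.0, 0.0, 0.0]]
maskable_mask = [[True, True, False], [True, False, False]]
losses = {
    mode: normalize_loss(token_nll, maskable_mask, mode)
    for mode in ("batch", "sequence", "token")
}
```

Under these assumptions the three modes give different values on the same batch, which is why a mismatch between the mode and its denominator would be worth fixing.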