diff --git a/README.md b/README.md index e87bc6eb..724f5d7d 100644 --- a/README.md +++ b/README.md @@ -219,7 +219,7 @@ See [Features](#features) for specific training recipes. # Preprocess SFT data python dllm/tools/preprocess_sft_dataset.py \ --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \ - --sft_map_fn_path "dllm.utils.default_mdlm_sft_map_fn" \ + --sft_map_fn_path "dllm.utils.default_sft_map_fn" \ --dataset_args "allenai/tulu-3-sft-mixture" \ --output_dir "data/sft/llada/tulu-3-sft-mixture" \ --num_proc 64 @@ -238,7 +238,7 @@ See [Features](#features) for specific training recipes. # Preprocess SFT data + python dllm/tools/preprocess_sft_dataset.py \ + --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \ - + --sft_map_fn_path "dllm.utils.default_mdlm_sft_map_fn" \ + + --sft_map_fn_path "dllm.utils.default_sft_map_fn" \ + --dataset_args "allenai/tulu-3-sft-mixture" \ + --output_dir "data/sft/llada/tulu-3-sft-mixture" \ + --num_proc 64 diff --git a/dllm/core/samplers/bd3lm.py b/dllm/core/samplers/bd3lm.py index 15d30b1d..ec2a3f96 100644 --- a/dllm/core/samplers/bd3lm.py +++ b/dllm/core/samplers/bd3lm.py @@ -13,7 +13,7 @@ from dllm.core.samplers.utils import add_gumbel_noise, get_num_transfer_tokens -def build_staircase_attention_mask( +def _prepare_for_sampling( x: torch.Tensor, block_size: int, pad_token_id: int, @@ -81,7 +81,7 @@ def build_staircase_attention_mask( return attn_mask, position_ids -def diffusion_step_block( +def _diffusion_step_block( logits: torch.Tensor, # [B, L, V] x_block: torch.Tensor, # [B, L] mask_block: torch.Tensor, # [B, L] bool @@ -201,9 +201,10 @@ def sample( mask_id = self.tokenizer.mask_token_id bos_id = self.tokenizer.bos_token_id pad_id = self.tokenizer.pad_token_id # used as padding here + eos_id = self.tokenizer.eos_token_id # ---- normalize inputs to tensors ---- - # If right_shift_logits is true and a sequence has length 0, replace that sequence with [eos]. 
+ # If right_shift_logits is true and a sequence has length 0, replace that sequence with [bos]. if right_shift_logits: inputs = [ [bos_id] if isinstance(p, list) and len(p) == 0 else p for p in inputs @@ -228,16 +229,21 @@ def sample( # ========================================================== # 1) Initialize with prompt only (left padded with pad_id) + # pad prefix length to a multiple of block_size # ========================================================== + padded_prompt_len = ( + (max_prompt_len + block_size - 1) // block_size + ) * block_size + x = torch.full( - (B, max_prompt_len), + (B, padded_prompt_len), pad_id, dtype=torch.long, device=self.model.device, ) for b, p in enumerate(inputs): L = prompt_lens[b] - offset = max_prompt_len - L # left padding + offset = padded_prompt_len - L # left padding x[b, offset : offset + L] = p # Tokens considered "given" for unconditional branch in CFG. @@ -248,6 +254,9 @@ def sample( ) unmasked_index = unmasked_index & (~keep_mask) + # track done per sequence (EOS) + done = torch.zeros((B,), dtype=torch.bool, device=self.model.device) + # ---- block scheduling ---- num_blocks = math.ceil(max_new_tokens / block_size) if steps_per_block is None: @@ -260,15 +269,13 @@ def sample( # 2) Block-by-block generation loop # ========================================================== for b_idx in range(num_blocks): - # Align sampling block to physical block boundaries + if done.all(): + break + T_prefix = x.shape[1] # current total length before appending this block - offset = T_prefix % block_size - if offset == 0: - block_room = block_size - else: - block_room = block_size - offset - cur_block_len = min(block_room, max_new_tokens - generated) + # With padded_prompt_len aligned, we always append whole blocks (except possibly final) + cur_block_len = min(block_size, max_new_tokens - generated) if cur_block_len <= 0: break @@ -278,7 +285,7 @@ def sample( x_prefix = x # [B, T_prefix] B_cur, T_prefix = x_prefix.shape - prefix_attn, 
prefix_pos = build_staircase_attention_mask( + prefix_attn, prefix_pos = _prepare_for_sampling( x=x_prefix, block_size=block_size, pad_token_id=pad_id, @@ -317,6 +324,13 @@ def sample( new_block = torch.full( (B, cur_block_len), mask_id, dtype=torch.long, device=self.model.device ) + # if done.any(): + # new_block = torch.where( + # done.unsqueeze(1), + # torch.full_like(new_block, pad_id), + # new_block, + # ) + x = torch.cat([x, new_block], dim=1) # [B, T_prefix + cur_block_len] unmasked_index = torch.cat( @@ -342,7 +356,7 @@ def sample( effective_steps = num_transfer_tokens.size(1) # Full staircase attention mask + pos for prefix + current block - full_attention_mask, full_position_ids = build_staircase_attention_mask( + full_attention_mask, full_position_ids = _prepare_for_sampling( x=x, block_size=block_size, pad_token_id=pad_id, @@ -406,7 +420,7 @@ def sample( logits_block = shifted # ---- One diffusion step over this block ---- - x_block_updated = diffusion_step_block( + x_block_updated = _diffusion_step_block( logits=logits_block, x_block=x_block, mask_block=mask_block, @@ -421,8 +435,10 @@ def sample( if histories is not None: histories.append(x.clone()) - if self.tokenizer.eos_token_id in x[:, T_prefix:T_total]: - break + # per-sequence EOS stopping (after finishing denoising the block) + if eos_id is not None: + eos_in_block = (x[:, T_prefix:T_total] == eos_id).any(dim=1) + done = done | eos_in_block generated += cur_block_len @@ -437,7 +453,7 @@ def sample( @torch.no_grad() def infill( self, - inputs: list[torch.Tensor, list], + inputs: list[torch.Tensor | list], config: SamplerConfig | None = None, **kwargs, ) -> SamplerOutput: diff --git a/dllm/core/trainers/bd3lm.py b/dllm/core/trainers/bd3lm.py index 5beb8c36..b6f9b95b 100644 --- a/dllm/core/trainers/bd3lm.py +++ b/dllm/core/trainers/bd3lm.py @@ -17,7 +17,6 @@ from dllm.utils.collators import CollatorWrapper from .mdlm import MDLMTrainer -from .utils import EpochPPLMeter @dataclass @@ -40,7 +39,7 @@ 
def before(self, features): return features -def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None): +def _bd3lm_attention_mask(b, h, q_idx, kv_idx, block_size=None, n=None): """ Constructs the specialized block diffusion attention mask for training composed of three masks: @@ -85,17 +84,18 @@ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None): class BD3LMTrainer(MDLMTrainer): + @dataclass + class BD3LMConfig(MDLMTrainer.MDLMConfig): + block_size: int = 32 + def __init__( self, - block_size: int = 32, - *args, + args: BD3LMConfig, + *pargs, **kwargs, ): - super().__init__(*args, **kwargs) - self.block_size = block_size - - self.epoch_meter = EpochPPLMeter(self, train_prefix="train", eval_prefix="eval") - self.add_callback(self.epoch_meter) + super().__init__(args=args, *pargs, **kwargs) + self.block_size = args.block_size def compute_loss( self, @@ -126,7 +126,7 @@ def compute_loss( inputs.get("attention_mask", None), ) b, l = input_ids.shape - token_cnt_per_seq = torch.sum(labels != -100, dim=1, keepdim=True) # [b, 1] + maskable_mask = labels != -100 # [b, l] # === 1. Sample diffusion timesteps === # Each example draws a random timestep t ∈ [ε, 1), where ε avoids degenerate values near 0. @@ -139,12 +139,12 @@ def compute_loss( # === 2. Apply stochastic masking === # Tokens are masked independently according to p_mask(t). # Positions with label = -100 are excluded (ignored in loss). - masked_indices = (torch.rand((b, l), device=input_ids.device) < p_mask) & ( - labels != -100 - ) + masked_mask = ( + torch.rand((b, l), device=input_ids.device) < p_mask + ) & maskable_mask # Replace masked tokens with the special [MASK] token. noised_input_ids = torch.where( - masked_indices, self.processing_class.mask_token_id, input_ids + masked_mask, self.processing_class.mask_token_id, input_ids ) # === 3. 
Forward pass through the model (block-diffusion) === @@ -155,7 +155,7 @@ def compute_loss( # [TODO]: others like flash attention 2 if self.accelerator.unwrap_model(model).config._attn_implementation == "sdpa": - attention_mask = block_diff_mask( + attention_mask = _bd3lm_attention_mask( b=None, h=None, q_idx=torch.arange(l * 2)[:, None], @@ -174,7 +174,7 @@ def compute_loss( from torch.nn.attention.flex_attention import create_block_mask attention_mask = create_block_mask( - partial(block_diff_mask, block_size=self.block_size, n=l), + partial(_bd3lm_attention_mask, block_size=self.block_size, n=l), B=None, H=None, Q_LEN=l * 2, @@ -201,46 +201,51 @@ def compute_loss( # === 4. Handle degenerate cases (no tokens masked) === # If no positions were masked, return a zero loss to keep gradients valid. # This step is necessary for Deepspeed Zero-{2,3} - if not masked_indices.any(): - self.epoch_meter.update( - split="train" if model.training else "eval", - nll_sum=logits.sum() * 0.0, - token_cnt=token_cnt_per_seq.sum(), - ) - return ( - (logits.sum() * 0.0, outputs) if return_outputs else logits.sum() * 0.0 + if not masked_mask.any(): + zero = logits.sum() * 0.0 # scalar, grad-safe + self.meter.update( + split="train" if model.training else "eval", + value=torch.zeros_like(maskable_mask, dtype=logits.dtype), + weight=maskable_mask.to(dtype=logits.dtype).detach(), ) + return (zero, outputs) if return_outputs else zero # === 5. Compute per-token loss weights === # Depending on the configuration, weights may depend on timestep t # (e.g., scheduler-based) or be uniform (ones). loss_weights = self._compute_loss_weights( - t=t, inputs=inputs, masked_indices=masked_indices + t=t, inputs=inputs, masked_mask=masked_mask ) # === 6. Compute weighted cross-entropy === - # Only masked tokens contribute to the loss. 
- assert (input_ids[masked_indices] == labels[masked_indices]).all() - token_loss = F.cross_entropy( - logits[masked_indices], input_ids[masked_indices], reduction="none" + # Sanity check: ensure input_ids and labels match at valid positions + assert ( + input_ids[maskable_mask] == labels[maskable_mask] + ).all(), "Mismatch between input_ids and labels at valid positions" + + token_nll = F.cross_entropy( + logits.transpose(1, 2), # [b, V, l] + input_ids, # [b, l] + reduction="none", # [b, l] + ) + token_nll = token_nll * loss_weights * masked_mask.to(token_nll.dtype) # [b, l] + + self.meter.update( + split="train" if model.training else "eval", + value=token_nll.detach(), + weight=maskable_mask.to(dtype=logits.dtype).detach(), ) - token_loss = token_loss * loss_weights[masked_indices] # === 7. Normalize loss === - if self.loss_normalization_type == "batch": - token_loss_normalized = token_loss / b - elif self.loss_normalization_type == "sequence": - token_loss_normalized = token_loss / token_cnt_per_seq.expand(-1, l)[masked_indices] / b - elif self.loss_normalization_type == "token": - token_loss_normalized = token_loss / token_cnt_per_seq.sum() + if self.loss_norm_type == "batch": + token_nll /= b + elif self.loss_norm_type == "sequence": + token_nll /= maskable_mask.sum(-1, keepdim=True).clamp_min(1) * b + elif self.loss_norm_type == "token": + token_nll /= maskable_mask.sum().clamp_min(1) else: - raise ValueError("Invalid loss_normalization_type.") - loss = token_loss_normalized.sum() + raise ValueError("Invalid loss_norm_type.") + loss = token_nll.sum() # === 8. 
Return final loss (and optionally model outputs) === - self.epoch_meter.update( - split="train" if model.training else "eval", - nll_sum=token_loss.sum(), - token_cnt=token_cnt_per_seq.sum(), - ) # `nll_sum / token_cnt` is equivalent to `loss` when `self.loss_normalization_type == "token"`` return (loss, outputs) if return_outputs else loss diff --git a/dllm/core/trainers/mdlm.py b/dllm/core/trainers/mdlm.py index 316669d8..5bb7d771 100644 --- a/dllm/core/trainers/mdlm.py +++ b/dllm/core/trainers/mdlm.py @@ -9,6 +9,7 @@ """ from typing import Any +from dataclasses import dataclass import torch import torch.nn as nn @@ -16,35 +17,48 @@ import transformers from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler +from dllm.utils.configs import TrainingArguments from dllm.utils.data import prepend_bos -from .utils import EpochPPLMeter +from .utils import NLLMetric, PerplexityMetric, OnEvaluateMetricsCallback class MDLMTrainer(transformers.Trainer): + @dataclass + class MDLMConfig(TrainingArguments): + time_epsilon: float = 1e-3 + loss_weight_type: str = "scheduler" # "scheduler", "uniform" + loss_norm_type: str = "sequence" # "batch", "sequence", "token" + right_shift_logits: bool = False + def __init__( self, + args: MDLMConfig, scheduler: BaseAlphaScheduler | None = None, - time_epsilon: float = 1e-3, - loss_weight_type: str = "scheduler", # "scheduler", "uniform" - loss_normalization_type: str = "sequence", # "batch", "sequence", "token" - right_shift_logits: bool = False, - *args, + *pargs, **kwargs, ): - super().__init__(*args, **kwargs) + super().__init__(args=args, *pargs, **kwargs) - if not (0.0 < time_epsilon < 1.0): + if not (0.0 < args.time_epsilon < 1.0): raise ValueError("time_epsilon must be in (0, 1)") - self.scheduler = scheduler or LinearAlphaScheduler() - self.time_epsilon = time_epsilon - self.loss_weight_type = loss_weight_type - self.loss_normalization_type = loss_normalization_type - self.right_shift_logits = right_shift_logits - 
- self.epoch_meter = EpochPPLMeter(self, train_prefix="train", eval_prefix="eval") - self.add_callback(self.epoch_meter) + self.scheduler = scheduler if scheduler is not None else LinearAlphaScheduler() + self.time_epsilon = args.time_epsilon + self.loss_weight_type = args.loss_weight_type + self.loss_norm_type = args.loss_norm_type + self.right_shift_logits = args.right_shift_logits + + # self.epoch_meter = EpochPPLMeter(self) + self.meter = OnEvaluateMetricsCallback( + trainer=self, + splits=("train", "eval"), + metrics_map={ + "diff_nll": NLLMetric(), + "diff_ppl": PerplexityMetric(), + }, + ) + self.add_callback(self.meter) def _preprocess_inputs(self, inputs): if self.right_shift_logits: @@ -133,7 +147,7 @@ def compute_loss( inputs.get("attention_mask", None), ) b, l = input_ids.shape - token_cnt_per_seq = torch.sum(labels != -100, dim=1, keepdim=True) # [b, 1] + maskable_mask = labels != -100 # [b, l] # === 1. Sample diffusion timesteps === # Each example draws a random timestep t ∈ [ε, 1), where ε avoids degenerate values near 0. @@ -146,12 +160,12 @@ def compute_loss( # === 2. Apply stochastic masking === # Tokens are masked independently according to p_mask(t). # Positions with label = -100 are excluded (ignored in loss). - masked_indices = (torch.rand((b, l), device=input_ids.device) < p_mask) & ( - labels != -100 - ) + masked_mask = ( + torch.rand((b, l), device=input_ids.device) < p_mask + ) & maskable_mask # Replace masked tokens with the special [MASK] token. noised_input_ids = torch.where( - masked_indices, self.processing_class.mask_token_id, input_ids + masked_mask, self.processing_class.mask_token_id, input_ids ) # === 3. Forward pass through the model === @@ -163,46 +177,51 @@ def compute_loss( # === 4. Handle degenerate cases (no tokens masked) === # If no positions were masked, return a zero loss to keep gradients valid. 
# This step is necessary for Deepspeed Zero-{2,3} - if not masked_indices.any(): - self.epoch_meter.update( - split="train" if model.training else "eval", - nll_sum=logits.sum() * 0.0, - token_cnt=token_cnt_per_seq.sum(), - ) - return ( - (logits.sum() * 0.0, outputs) if return_outputs else logits.sum() * 0.0 + if not masked_mask.any(): + zero = logits.sum() * 0.0 # scalar, grad-safe + self.meter.update( + split="train" if model.training else "eval", + value=torch.zeros_like(maskable_mask, dtype=logits.dtype), + weight=maskable_mask.to(dtype=logits.dtype).detach(), ) + return (zero, outputs) if return_outputs else zero # === 5. Compute per-token loss weights === # Depending on the configuration, weights may depend on timestep t # (e.g., scheduler-based) or be uniform (ones). loss_weights = self._compute_loss_weights( - t=t, inputs=inputs, masked_indices=masked_indices + t=t, inputs=inputs, masked_mask=masked_mask ) # === 6. Compute weighted cross-entropy === - # Only masked tokens contribute to the loss. - assert (input_ids[masked_indices] == labels[masked_indices]).all() - token_loss = F.cross_entropy( - logits[masked_indices], input_ids[masked_indices], reduction="none" + # Sanity check: ensure input_ids and labels match at valid positions + assert ( + input_ids[maskable_mask] == labels[maskable_mask] + ).all(), "Mismatch between input_ids and labels at valid positions" + + token_nll = F.cross_entropy( + logits.transpose(1, 2), # [b, V, l] + input_ids, # [b, l] + reduction="none", # [b, l] + ) + token_nll = token_nll * loss_weights * masked_mask.to(token_nll.dtype) # [b, l] + + self.meter.update( + split="train" if model.training else "eval", + value=token_nll.detach(), + weight=maskable_mask.to(dtype=logits.dtype).detach(), ) - token_loss = token_loss * loss_weights[masked_indices] # === 7. 
Normalize loss === - if self.loss_normalization_type == "batch": - token_loss_normalized = token_loss / b - elif self.loss_normalization_type == "sequence": - token_loss_normalized = token_loss / token_cnt_per_seq.expand(-1, l)[masked_indices] / b - elif self.loss_normalization_type == "token": - token_loss_normalized = token_loss / token_cnt_per_seq.sum() + if self.loss_norm_type == "batch": + token_nll /= b + elif self.loss_norm_type == "sequence": + token_nll /= maskable_mask.sum(-1, keepdim=True).clamp_min(1) * b + elif self.loss_norm_type == "token": + token_nll /= maskable_mask.sum().clamp_min(1) else: - raise ValueError("Invalid loss_normalization_type.") - loss = token_loss_normalized.sum() + raise ValueError("Invalid loss_norm_type.") + loss = token_nll.sum() # === 8. Return final loss (and optionally model outputs) === - self.epoch_meter.update( - split="train" if model.training else "eval", - nll_sum=token_loss.sum(), - token_cnt=token_cnt_per_seq.sum(), - ) # `nll_sum / token_cnt` is equivalent to `loss` when `self.loss_normalization_type == "token"`` return (loss, outputs) if return_outputs else loss diff --git a/dllm/core/trainers/utils.py b/dllm/core/trainers/utils.py index 2cb9ceff..ce150437 100644 --- a/dllm/core/trainers/utils.py +++ b/dllm/core/trainers/utils.py @@ -1,30 +1,18 @@ import math - import torch import transformers class EpochPPLMeter(transformers.TrainerCallback): """ - Keeps running sums for dataset-level NLL/token and logs PPL once per epoch. - - Usage: - - Trainer calls: self.ppl_meter.update(split, nll_sum, token_cnt) - - Callback hooks: - * on_epoch_begin: reset train accumulators - * on_epoch_end: finalize+log train PPL - * on_evaluate: finalize+log eval PPL (one per evaluate call) + Keeps running sums for dataset-level NLL/token and logs PPL. + Convention: + - Train: keys are unprefixed, e.g. "diff_nll", "diff_ppl" + - Eval : keys are prefixed with "eval_", e.g. 
"eval_diff_nll", "eval_diff_ppl" """ - def __init__( - self, - trainer: "transformers.Trainer", - train_prefix: str = "train", - eval_prefix: str = "eval", - ): + def __init__(self, trainer: "transformers.Trainer"): self.trainer = trainer - self.train_prefix = train_prefix - self.eval_prefix = eval_prefix self._train_nll_sum = 0.0 self._train_token_cnt = 0.0 @@ -41,8 +29,9 @@ def reset(self, split: str) -> None: else: raise ValueError(f"Unknown split={split}") - def update(self, split: str, nll_sum: torch.Tensor, token_cnt: torch.Tensor) -> None: - # detach -> float64 -> python float + def update( + self, split: str, nll_sum: torch.Tensor, token_cnt: torch.Tensor + ) -> None: nll_sum_f = float(nll_sum.detach().double().cpu().item()) tok_cnt_f = float(token_cnt.detach().double().cpu().item()) @@ -60,9 +49,8 @@ def _finalize(self, split: str): All-reduce (sum) across processes, then compute: mean_nll = total_nll / total_tokens ppl = exp(mean_nll) - - Returns (mean_nll, ppl) as python floats, or (None, None) if no tokens. - Also resets that split after finalizing. + Returns (mean_nll, ppl) or (None, None) if no tokens. + Resets the split accumulators when called. 
""" if split == "train": local_nll, local_tok = self._train_nll_sum, self._train_token_cnt @@ -93,22 +81,44 @@ def _finalize(self, split: str): # ---- callback hooks ---- - def on_epoch_begin(self, args, state, control, **kwargs): + def on_train_begin(self, args, state, control, **kwargs): self.reset("train") return control - def on_epoch_end(self, args, state, control, **kwargs): - mean_nll, ppl = self._finalize("train") - if mean_nll is not None and self.trainer.is_world_process_zero(): - logs = {f"{self.train_prefix}_nll": mean_nll, f"{self.train_prefix}_ppl": ppl} - self.trainer.log(logs) - print(f"[epoch {state.epoch}] {self.train_prefix}_nll={mean_nll:.6f} {self.train_prefix}_ppl={ppl:.6f}") + def on_evaluate_begin(self, args, state, control, **kwargs): + self.reset("eval") return control def on_evaluate(self, args, state, control, metrics=None, **kwargs): - mean_nll, ppl = self._finalize("eval") - if mean_nll is not None and self.trainer.is_world_process_zero(): - logs = {f"{self.eval_prefix}_nll": mean_nll, f"{self.eval_prefix}_ppl": ppl} - self.trainer.log(logs) - print(f"[epoch {state.epoch}] {self.eval_prefix}_nll={mean_nll:.6f} {self.eval_prefix}_ppl={ppl:.6f}") + train_mean_nll, train_ppl = self._finalize("train") + eval_mean_nll, eval_ppl = self._finalize("eval") + + if self.trainer.is_world_process_zero(): + logs = {} + + # TRAIN: NO "train_" prefix + if train_mean_nll is not None: + logs.update( + { + "diff_nll": train_mean_nll, + "diff_ppl": train_ppl, + } + ) + + # EVAL: MUST be "eval_" prefixed + if eval_mean_nll is not None: + logs.update( + { + "eval_diff_nll": eval_mean_nll, + "eval_diff_ppl": eval_ppl, + } + ) + + if logs: + self.trainer.log(logs) + print( + f"[step {state.global_step} epoch {state.epoch}] " + + " ".join(f"{k}={v:.6f}" for k, v in logs.items()) + ) + return control diff --git a/dllm/core/trainers/utils/__init__.py b/dllm/core/trainers/utils/__init__.py new file mode 100644 index 00000000..b01daf8b --- /dev/null +++ 
b/dllm/core/trainers/utils/__init__.py @@ -0,0 +1,3 @@ +from . import meters, metrics +from .meters import * +from .metrics import * diff --git a/dllm/core/trainers/utils/meters.py b/dllm/core/trainers/utils/meters.py new file mode 100644 index 00000000..aaf5d99f --- /dev/null +++ b/dllm/core/trainers/utils/meters.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +from typing import Any, Dict, Iterable, Optional +import copy + +import torch +import transformers +import torchmetrics + + +def _ddp_initialized() -> bool: + return torch.distributed.is_available() and torch.distributed.is_initialized() + + +class BaseMetricsCallback(transformers.TrainerCallback): + """ + Generic split-aware metric accumulator for HF Trainer. + + Fixes vs old version: + 1) Per-split metrics are independent (deep-copied) to avoid train/eval contamination. + 2) DDP-safe: sync/compute/reset run on ALL ranks; only rank0 logs/prints (no deadlock). + 3) Metrics are moved to trainer device (avoids CPU/GPU mismatch). + 4) Optional dtype set if metric supports it. 
+ + Smart key prefixing: + - split == "train": no prefix (e.g., "loss", "ppl") + - otherwise : f"{split}_" prefix (e.g., "eval_loss", "test_ppl") + + You provide: + - metrics_map: + * {name: Metric} -> broadcast to all splits (deep-copied per split) + * {split: {name: Metric}} -> per-split metrics + update(): + - calls metric.update(*args, **kwargs) by default + """ + + def __init__( + self, + trainer: "transformers.Trainer", + splits: Iterable[str] = ("train", "eval"), + metrics_map: Optional[Dict[str, Any]] = None, + dtype: torch.dtype = torch.float64, + ): + self.trainer = trainer + self.splits = tuple(splits) + metrics_map = metrics_map or {} + + # Create per-split independent metric dicts + self._metrics: Dict[str, Dict[str, torchmetrics.Metric]] = {} + + # Detect broadcast map: {name: Metric} + is_broadcast = len(metrics_map) > 0 and all( + isinstance(v, torchmetrics.Metric) for v in metrics_map.values() + ) + + device = getattr(self.trainer.args, "device", torch.device("cpu")) + + for split in self.splits: + if is_broadcast: + # IMPORTANT: deepcopy so each split has independent state + mdict = {k: copy.deepcopy(v) for k, v in metrics_map.items()} + else: + mdict = { + k: copy.deepcopy(v) for k, v in metrics_map.get(split, {}).items() + } + + # Configure dtype / device + for m in mdict.values(): + # Many torchmetrics ignore this, but keep your hook + if hasattr(m, "set_dtype"): + m.set_dtype(dtype) + # Ensure state buffers are on the right device + try: + m.to(device) + except Exception: + pass + + self._metrics[split] = mdict + + # ---------- key naming ---------- + + @staticmethod + def key_for(split: str, name: str) -> str: + return name if split == "train" else f"{split}_{name}" + + # ---------- lifecycle ---------- + + def reset(self, split: str) -> None: + for m in self._metrics[split].values(): + m.reset() + + @torch.no_grad() + def update(self, split: str, *args, **kwargs) -> None: + for m in self._metrics[split].values(): + m.update(*args, **kwargs) 
+ + @torch.no_grad() + def finalize(self, split: str) -> Dict[str, float]: + """ + DDP-safe finalize: + - Must be called on ALL ranks (because sync uses collectives). + - Returns local dict of python floats. + - Resets split metrics. + """ + mdict = self._metrics[split] + + # Make sure metrics live on current device (in case trainer device changes) + device = getattr(self.trainer.args, "device", torch.device("cpu")) + for m in mdict.values(): + try: + m.to(device) + except Exception: + pass + + # Sync across ranks (collectives) -- MUST run on all ranks + if _ddp_initialized(): + for m in mdict.values(): + # torchmetrics usually has sync/unsync; prefer sync if available + if hasattr(m, "sync"): + m.sync() + + out: Dict[str, float] = {} + for name, m in mdict.items(): + v = m.compute() + if isinstance(v, torch.Tensor): + if v.numel() == 0: + continue + v = v.detach() + v = v.item() if v.numel() == 1 else v.double().mean().cpu().item() + out[name] = float(v) + + # IMPORTANT: reset after compute so next window starts clean + self.reset(split) + return out + + @torch.no_grad() + def log_and_print( + self, + state: transformers.TrainerState, + splits: Iterable[str] | None = None, + ) -> None: + """ + DDP-safe: + - finalize() (and thus sync/compute/reset) runs on ALL ranks + - only rank0 logs/prints + """ + splits = self.splits if splits is None else tuple(splits) + + # All ranks finalize (avoid DDP deadlock) + all_vals: Dict[str, Dict[str, float]] = {} + for split in splits: + if split in self._metrics: + all_vals[split] = self.finalize(split) + + # Only rank0 logs/prints + if not self.trainer.is_world_process_zero(): + return + + logs: Dict[str, float] = {} + for split, vals in all_vals.items(): + logs.update({self.key_for(split, k): v for k, v in vals.items()}) + + if logs: + self.trainer.log(logs) + print( + f"[step {state.global_step} epoch {state.epoch}] " + + " ".join(f"{k}={v:.6f}" for k, v in logs.items()) + ) + + # ---------- HF callback hooks (optional 
defaults) ---- + + def on_train_begin(self, args, state, control, **kwargs): + if "train" in self._metrics: + self.reset("train") + return control + + def on_evaluate_begin(self, args, state, control, **kwargs):  # NOTE(review): transformers.TrainerCallback defines no "on_evaluate_begin" hook, so the Trainer never calls this — confirm the intended eval-reset path + if "eval" in self._metrics: + self.reset("eval") + return control + + + class OnEvaluateMetricsCallback(BaseMetricsCallback): + def on_evaluate(self, args, state, control, metrics=None, **kwargs): + # Log both train + eval by default (matches the previous behavior). + self.log_and_print(state, splits=("train", "eval")) + return control diff --git a/dllm/core/trainers/utils/metrics.py b/dllm/core/trainers/utils/metrics.py new file mode 100644 index 00000000..926093b6 --- /dev/null +++ b/dllm/core/trainers/utils/metrics.py @@ -0,0 +1,12 @@ +import torch +import torchmetrics + + +class NLLMetric(torchmetrics.aggregation.MeanMetric): + pass + + +class PerplexityMetric(NLLMetric): + def compute(self) -> torch.Tensor: + mean_nll = super().compute() + return torch.exp(mean_nll) diff --git a/dllm/data/s1k.py b/dllm/data/s1k.py new file mode 100644 index 00000000..a68f87cc --- /dev/null +++ b/dllm/data/s1k.py @@ -0,0 +1,43 @@ +from datasets import DatasetDict, load_dataset + +# messages = [ +# {"role": "user", "content": "Solve 13 * 17"}, +# { +# "role": "assistant", +# "reasoning_content": "We need to multiply 13 and 17 step by step.", +# "content": "13 * 17 = 221."
+# } +# ] + + +def load_dataset_s1k(dataset_name_or_path: str) -> DatasetDict: + + dataset = load_dataset(dataset_name_or_path) + + def map_fn(example): + + return { + "messages": [ + {"role": "user", "content": example["question"]}, + { + "role": "assistant", + "reasoning_content": example["thinking_trajectories"][0], + "content": example["attempt"], + }, + ] + } + + dataset = dataset.map( + map_fn, remove_columns=dataset["train"].column_names, num_proc=4 + ) + return dataset + + +if __name__ == "__main__": + from dllm.utils import resolve_with_base_env + + dataset_name_or_path = resolve_with_base_env( + "simplescaling/s1K", "BASE_DATASETS_DIR" + ) + dataset = load_dataset_s1k(dataset_name_or_path) + breakpoint()  # FIXME(review): debugging leftover — remove before merge diff --git a/dllm/pipelines/a2d/convert.py b/dllm/pipelines/a2d/convert.py index 1b0ffc64..38ddf2a7 100644 --- a/dllm/pipelines/a2d/convert.py +++ b/dllm/pipelines/a2d/convert.py @@ -6,7 +6,6 @@ import dllm A2D_CONFIG_MAP = { - # "gpt2": dllm.pipelines.a2d.A2DGPT2Config, "llama": dllm.pipelines.a2d.A2DLlamaConfig, "qwen2": dllm.pipelines.a2d.A2DQwen2Config, "qwen3": dllm.pipelines.a2d.A2DQwen3Config, @@ -17,6 +16,7 @@ class ScriptArguments: model_name_or_path: str = "Qwen/Qwen2.5-0.5B" output_dir: str = "models/a2d/Qwen2.5-0.5B" + random_init: bool = False def __post_init__(self): self.model_name_or_path = dllm.utils.resolve_with_base_env( @@ -54,10 +54,14 @@ def main(): tgt_config = tgt_config_cls(**cfg_dict) with dllm.utils.init_device_context_manager(): - # Build A2D model tgt_model = transformers.AutoModel.from_config(tgt_config) - # Direct weight copy (models match exactly) - tgt_model.load_state_dict(src_model.state_dict()) + + if not args.random_init: + missing, unexpected = tgt_model.load_state_dict( + src_model.state_dict(), strict=False + ) + print("missing:", missing) + print("unexpected:", unexpected) # Save model and config tgt_model.save_pretrained(args.output_dir) diff --git a/dllm/pipelines/a2d/models/gpt2/modeling_gpt2.py
b/dllm/pipelines/a2d/models/gpt2/modeling_gpt2.py deleted file mode 100644 index ffeffb8d..00000000 --- a/dllm/pipelines/a2d/models/gpt2/modeling_gpt2.py +++ /dev/null @@ -1,304 +0,0 @@ -from typing import Optional, Union - -import torch -from torch import nn - -import transformers -from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa -from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions -from transformers.utils import logging -from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask - -from transformers.models.gpt2.modeling_gpt2 import GPT2Attention as GPT2Attention - -if transformers.utils.is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size - from torch.nn.attention.flex_attention import BlockMask, create_block_mask -else: - BlockMask = torch.Tensor - -logger = logging.get_logger(__name__) - - -class A2DGPT2Config(transformers.GPT2Config): - model_type = "a2d-gpt2" # <- NEW model_type - - -# >>> A2D modification: -# Minimal override of GPT2Attention to disable the internal causal mask while -# keeping all original behavior / comments intact. -class A2DGPT2Attention(GPT2Attention): - def __init__(self, config, is_cross_attention=False, layer_idx=None): - super().__init__(config=config, is_cross_attention=is_cross_attention, layer_idx=layer_idx) - - # Disable causal behavior for all attention backends (eager/sdpa/flash) - # This ensures full bidirectional attention as required by A2D. - self.is_causal = False # <<< key change - - # Replace causal lower-triangular mask with an all-True mask - # so eager/_upcast path will not zero-out future positions. 
- if hasattr(self, "bias"): - full_bias = torch.ones_like(self.bias, dtype=torch.bool) - self.register_buffer("bias", full_bias, persistent=False) - - -class A2DGPT2Model(transformers.GPT2Model): - - def __init__(self, config): - super().__init__(config) - - # >>> A2D modification: - # Replace original causal GPT2Attention with the non-causal version above. - for i, block in enumerate(self.h): - block.attn = A2DGPT2Attention(config, is_cross_attention=False, layer_idx=i) - if config.add_cross_attention and hasattr(block, "crossattention"): - block.crossattention = A2DGPT2Attention(config, is_cross_attention=True, layer_idx=i) - - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - **kwargs, - ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]: - r""" - input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input - sequence tokens in the vocabulary. - - If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as - `input_ids`. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. 
- - [What are input IDs?](../glossary#input-ids) - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - batch_size = input_ids.shape[0] - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - batch_size = inputs_embeds.shape[0] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - device = input_ids.device if input_ids is not None else inputs_embeds.device - - if token_type_ids is not None: - token_type_ids = token_type_ids.view(-1, input_shape[-1]) - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." - ) - use_cache = False - - # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder - if use_cache: - if past_key_values is None: - past_key_values = DynamicCache(config=self.config) - elif isinstance(past_key_values, tuple): - logger.warning_once( - "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. " - "You should pass an instance of `Cache` instead, e.g. " - "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." 
- ) - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - - if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache): - past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.config)) - - if inputs_embeds is None: - inputs_embeds = self.wte(input_ids) - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device) - - # Attention mask. - # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel - if attention_mask is not None and attention_mask.ndim < 4: - attention_mask = attention_mask.view(batch_size, -1) - - # ------------------------------------------------------------- - # ORIGINAL CAUSAL CODE REMOVED BY YOU - # (kept as comment, no modification) - # ------------------------------------------------------------- - - # ------------------------------------------------------------- - # NEW CODE (bidirectional, padding-only mask) - # (kept exactly as you wrote) - # ------------------------------------------------------------- - if attention_mask is None: - attention_mask = torch.ones( - inputs_embeds.shape[:2], - device=inputs_embeds.device, - dtype=torch.long, - ) - - if not ( - isinstance(attention_mask, BlockMask) - or (isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 4) - ): - attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype) - # ------------------------------------------------------------- - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, 
seq_length, seq_length] - _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None - if self.config.add_cross_attention and encoder_hidden_states is not None: - encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() - encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) - if encoder_attention_mask is None: - encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - if _use_sdpa: - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - elif self._attn_implementation != "flash_attention_2": - encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) - else: - encoder_attention_mask = None - - head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if token_type_ids is not None: - token_type_embeds = self.wte(token_type_ids) - hidden_states = hidden_states + token_type_embeds - - hidden_states = self.drop(hidden_states) - - output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - - all_self_attentions = () if output_attentions else None - all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None - all_hidden_states = () if output_hidden_states else None - for i, block in enumerate(self.h): - # Model parallel - if self.model_parallel: - torch.cuda.set_device(hidden_states.device) - if isinstance(head_mask, torch.Tensor): - head_mask = head_mask.to(hidden_states.device) - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - outputs = block( - hidden_states, - past_key_values if not (self.gradient_checkpointing and self.training) else None, - cache_position, - attention_mask, # (unchanged) pass your full-mask 4D mask - head_mask[i], - encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - 
**kwargs, - ) - - hidden_states = outputs[0] - - if output_attentions: - all_self_attentions = all_self_attentions + (outputs[1],) - if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[2],) - - # Model Parallel: If it's the last layer for that device, put things on the next device - if self.model_parallel: - for k, v in self.device_map.items(): - if i == v[-1] and "cuda:" + str(k) != self.last_device: - hidden_states = hidden_states.to("cuda:" + str(k + 1)) - - hidden_states = self.ln_f(hidden_states) - - hidden_states = hidden_states.view(output_shape) - # Add last hidden state - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - past_key_values = past_key_values if use_cache else None - if not return_dict: - return tuple( - v - for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] - if v is not None - ) - - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=past_key_values, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - -class A2DGPT2LMHeadModel(transformers.GPT2LMHeadModel): - config: A2DGPT2Config - - def __init__(self, config): - transformers.GPT2PreTrainedModel.__init__(self, config) - self.transformer = A2DGPT2Model(config) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) - - self.model_parallel = False - self.device_map = None - self.post_init() - - -transformers.AutoConfig.register("a2d-gpt2", A2DGPT2Config) -transformers.AutoModel.register(A2DGPT2Config, A2DGPT2LMHeadModel) -transformers.AutoModelForMaskedLM.register(A2DGPT2Config, A2DGPT2LMHeadModel) - - -if __name__ == "__main__": - import dllm - import torch - from transformers import AutoModel - - # Load a config from a local path (either a directory containing config.json, or the file itself) - config_path = 
dllm.utils.resolve_with_base_env( - "openai-community/gpt2", "BASE_MODELS_DIR" - ) - config = A2DGPT2Config.from_pretrained(config_path) - if hasattr(config, "auto_map"): - delattr(config, "auto_map") - if hasattr(config, "architectures"): - delattr(config, "architectures") - - torch.set_default_device("cuda") - model = A2DGPT2LMHeadModel(config) - model.save_pretrained("models-tmp/a2d-gpt2") - auto_model = AutoModel.from_pretrained("models-tmp/a2d-gpt2") diff --git a/dllm/pipelines/dream/eval.py b/dllm/pipelines/dream/eval.py index 9309aa09..ac1b52ca 100644 --- a/dllm/pipelines/dream/eval.py +++ b/dllm/pipelines/dream/eval.py @@ -214,9 +214,9 @@ def generate_until( # tokenize prompt_ids = [ - self.tokenizer( - p, return_tensors="pt", padding=False - ).input_ids.squeeze().to(self.device) + self.tokenizer(p, return_tensors="pt", padding=False) + .input_ids.squeeze() + .to(self.device) for p in prompts ] prompt_lens = [len(p_id) for p_id in prompt_ids] diff --git a/dllm/pipelines/dream/trainer.py b/dllm/pipelines/dream/trainer.py index f32e1386..f1c6066b 100644 --- a/dllm/pipelines/dream/trainer.py +++ b/dllm/pipelines/dream/trainer.py @@ -1,4 +1,5 @@ from typing import Any +from dataclasses import dataclass import torch @@ -6,21 +7,21 @@ def cart_weight( - masked_indices: torch.Tensor, t: torch.Tensor, p: float = 0.3 + masked_mask: torch.Tensor, t: torch.Tensor, p: float = 0.3 ) -> torch.Tensor: """ Optimized CART weight computation using matrix operations. Args: - masked_indices (torch.Tensor): (b, l) bool tensor indicating masked positions. + masked_mask (torch.Tensor): (b, l) bool tensor indicating masked positions. t (torch.Tensor): (b,) time steps (0-1 sampled uniformly). Not directly used in CART. p (float): Parameter of geometric distribution (0 < p <= 1). Returns: torch.Tensor: (b, l) float tensor of weights. 
""" - b, l = masked_indices.shape - device = masked_indices.device + b, l = masked_mask.shape + device = masked_mask.device idx = torch.arange(l, device=device) dist_matrix = (idx[None, :] - idx[:, None]).abs() - 1 @@ -31,9 +32,9 @@ def cart_weight( ).exp() * 0.5 # Ensure numerical stability geo_matrix.masked_fill_(dist_matrix == 0, 0.0) # ignore distance = 0 - valid_mask = (~masked_indices).float() # (b, l), 1 = unmasked + valid_mask = (~masked_mask).float() # (b, l), 1 = unmasked weights = valid_mask @ geo_matrix.T # (b, l) - weights = weights * masked_indices.float() + weights = weights * masked_mask.float() return weights @@ -42,25 +43,20 @@ class DreamTrainer(MDLMTrainer): DreamTrainer: specialization of MDLMTrainer for Dream training. """ - def __init__( - self, - loss_weight_type: str = "cart[geo_p:0.3]", - *args, - **kwargs, - ): - super().__init__( - loss_weight_type=loss_weight_type, - *args, - **kwargs, - ) + @dataclass + class DreamConfig(MDLMTrainer.MDLMConfig): + loss_weight_type: str = "cart[geo_p:0.3]" + right_shift_logits: bool = True - self.right_shift_logits = True + def __post_init__(self): + super().__post_init__() + assert self.right_shift_logits def _compute_loss_weights( self, t: torch.Tensor, inputs: dict[str, Any], - masked_indices: torch.Tensor, + masked_mask: torch.Tensor, *args, **kwargs, ) -> torch.Tensor: @@ -70,12 +66,12 @@ def _compute_loss_weights( match = re.search(r"geo_p:(0\.\d+)", self.loss_weight_type) geo_p = float(match.group(1)) if match else 0.3 - loss_weights = cart_weight(masked_indices, t, p=geo_p) + loss_weights = cart_weight(masked_mask, t, p=geo_p) else: loss_weights = super()._compute_loss_weights( t=t, inputs=inputs, - masked_indices=masked_indices, + masked_mask=masked_mask, *args, **kwargs, ) diff --git a/dllm/pipelines/editflow/trainer.py b/dllm/pipelines/editflow/trainer.py index 2635539f..285da50f 100644 --- a/dllm/pipelines/editflow/trainer.py +++ b/dllm/pipelines/editflow/trainer.py @@ -8,6 +8,7 @@ from 
dllm.core.schedulers import BaseKappaScheduler, CubicKappaScheduler from dllm.pipelines.editflow.utils import pad_1d +from dllm.utils.configs import TrainingArguments BLANK = -1 @@ -211,20 +212,25 @@ class EditFlowTrainer(transformers.Trainer): True intensities are w * rate_hat, with w = kappa_dot(t) / (1 - kappa(t)). """ + @dataclass + class EditFlowConfig(TrainingArguments): + time_epsilon: float = 1e-3 + normalize_per_position: bool = True + max_w: float = 20.0 + def __init__( self, - *args, + args: EditFlowConfig, scheduler: BaseKappaScheduler | None = None, - normalize_per_position: bool = True, - time_epsilon: float = 1e-3, - max_w: float | None = None, + *pargs, **kwargs, ): - self.scheduler = scheduler or CubicKappaScheduler() - self.normalize_per_position = normalize_per_position - self.time_epsilon = time_epsilon - self.max_w = max_w - super().__init__(*args, **kwargs) + super().__init__(args=args, *pargs, **kwargs) + + self.scheduler = scheduler if scheduler is not None else CubicKappaScheduler() + self.time_epsilon = args.time_epsilon + self.normalize_per_position = args.normalize_per_position + self.max_w = args.max_w def compute_loss( self, diff --git a/dllm/pipelines/llada/eval.py b/dllm/pipelines/llada/eval.py index 9cb1fe4a..cf46007b 100644 --- a/dllm/pipelines/llada/eval.py +++ b/dllm/pipelines/llada/eval.py @@ -364,9 +364,7 @@ def generate_until(self, requests: list[Instance]) -> list[str]: for instance in tqdm(requests, desc="Generating..."): context, gen_kwargs = instance.args # type: ignore prompt_ids = self.tokenizer(context)["input_ids"] - prompt = [ - torch.tensor(prompt_ids, device=self.device, dtype=torch.long) - ] + prompt = [torch.tensor(prompt_ids, device=self.device, dtype=torch.long)] stop_tokens = gen_kwargs["until"] generated_ids = sampler.sample( inputs=prompt, diff --git a/dllm/tools/preprocess_sft_dataset.py b/dllm/tools/preprocess_sft_dataset.py index b0b65d20..be8ffdfb 100644 --- a/dllm/tools/preprocess_sft_dataset.py +++ 
b/dllm/tools/preprocess_sft_dataset.py @@ -25,7 +25,7 @@ class ScriptArguments: """Preprocess SFT dataset.""" model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base" - sft_map_fn_path: str = "dllm.utils.default_mdlm_sft_map_fn" + sft_map_fn_path: str = "dllm.utils.default_sft_map_fn" dataset_args: str = "HuggingFaceTB/smoltalk" # required output_dir: str = "data/sft/llada/smoltalk" # required mask_prompt_loss: bool = True # Mask prompt tokens in labels with -100 diff --git a/dllm/utils/configs.py b/dllm/utils/configs.py index f8e84359..b9e7916c 100644 --- a/dllm/utils/configs.py +++ b/dllm/utils/configs.py @@ -70,6 +70,7 @@ class TrainingArguments(transformers.TrainingArguments): def __post_init__(self): super().__post_init__() + self.run_name = self.run_name or self.output_dir if self.group_by_length: logger.info( "training_args.group_by_length=True: preprocessing " diff --git a/dllm/utils/data.py b/dllm/utils/data.py index 39d7794f..d69d6b7d 100644 --- a/dllm/utils/data.py +++ b/dllm/utils/data.py @@ -222,7 +222,7 @@ def clip_right(row): raise NotImplementedError -def default_mdlm_sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict: +def default_sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict: """ Build input_ids and labels for SFT. 
diff --git a/dllm/utils/models.py b/dllm/utils/models.py index 4878ad32..77460f0d 100644 --- a/dllm/utils/models.py +++ b/dllm/utils/models.py @@ -140,11 +140,7 @@ def get_tokenizer(model_args) -> transformers.PreTrainedTokenizer: {% endif %} """ - elif issubclass(model_cls, LLaDAMoEModelLM): - tokenizer.add_special_tokens({"mask_token": "<|mask|>"}) - tokenizer.eot_token = "<|role_end|>" - tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token) - elif issubclass(model_cls, LLaDA2MoeModelLM): + elif issubclass(model_cls, (LLaDAMoEModelLM, LLaDA2MoeModelLM)): tokenizer.add_special_tokens({"mask_token": "<|mask|>"}) tokenizer.eot_token = "<|role_end|>" tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token) @@ -187,15 +183,10 @@ def get_tokenizer(model_args) -> transformers.PreTrainedTokenizer: tokenizer.add_special_tokens({"mask_token": "<|mask|>"}) tokenizer.eot_token = "<|eot_id|>" tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token) - elif issubclass(model_cls, A2DQwen2LMHeadModel): - tokenizer.add_special_tokens({"mask_token": "<|mask|>"}) - tokenizer.eot_token = "<|im_end|>" - tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token) - elif issubclass(model_cls, A2DQwen3LMHeadModel): + elif issubclass(model_cls, (A2DQwen2LMHeadModel, A2DQwen3LMHeadModel)): tokenizer.add_special_tokens({"mask_token": "<|mask|>"}) tokenizer.eot_token = "<|im_end|>" tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token) - tokenizer.chat_template = "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, 
return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in content %}\n {%- set reasoning_content = content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- set content = content.split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and 
content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\n' }}\n {{- '\n\n\n\n' }}\n{%- endif %}" else: print_main("no tokenizer customization for model class:", model_cls) return tokenizer diff --git a/examples/a2d/bd3lm/pt.py b/examples/a2d/bd3lm/pt.py index bc5dbe81..7866a79c 100644 --- a/examples/a2d/bd3lm/pt.py +++ b/examples/a2d/bd3lm/pt.py @@ -60,15 +60,14 @@ class DataArguments(dllm.utils.DataArguments): @dataclass -class TrainingArguments(dllm.utils.TrainingArguments): +class TrainingArguments(dllm.core.trainers.BD3LMTrainer.BD3LMConfig): output_dir: str = "models/a2d/Qwen3-0.6B/bd3lm/tiny-shakespeare" num_train_epochs: int = 20 learning_rate: float = 1e-4 per_device_train_batch_size: int = 16 per_device_eval_batch_size: int = 16 - # a2d-specific + # bd3lm block_size: int = 32 - right_shift_logits: bool = False def train(): @@ -123,8 +122,6 @@ def train(): train_dataset=dataset["train"], eval_dataset=dataset.get("test", None), args=training_args, - block_size=training_args.block_size, - right_shift_logits=training_args.right_shift_logits, data_collator=transformers.DataCollatorForSeq2Seq( tokenizer, return_tensors="pt", diff --git a/examples/a2d/bd3lm/sft.py 
b/examples/a2d/bd3lm/sft.py index b97b329d..63e80053 100644 --- a/examples/a2d/bd3lm/sft.py +++ b/examples/a2d/bd3lm/sft.py @@ -55,16 +55,15 @@ class DataArguments(dllm.utils.DataArguments): @dataclass -class TrainingArguments(dllm.utils.TrainingArguments): +class TrainingArguments(dllm.core.trainers.BD3LMTrainer.BD3LMConfig): output_dir: str = "models/a2d/Qwen3-0.6B/mdlm/alpaca" group_by_length: bool = True num_train_epochs: int = 20 learning_rate: float = 1e-4 per_device_train_batch_size: int = 16 per_device_eval_batch_size: int = 16 - # a2d-specific + # bd3lm block_size: int = 32 - right_shift_logits: bool = False def train(): @@ -89,7 +88,7 @@ def train(): ) if not data_args.load_preprocessed_data: map_fn = partial( - dllm.utils.default_mdlm_sft_map_fn, + dllm.utils.default_sft_map_fn, tokenizer=tokenizer, mask_prompt_loss=data_args.mask_prompt_loss, ) @@ -110,8 +109,6 @@ def train(): train_dataset=dataset["train"], eval_dataset=dataset.get("test", None), args=training_args, - block_size=training_args.block_size, - right_shift_logits=training_args.right_shift_logits, data_collator=( dllm.core.trainers.bd3lm.AppendEOSBlockWrapper( transformers.DataCollatorForSeq2Seq( diff --git a/examples/a2d/mdlm/pt.py b/examples/a2d/mdlm/pt.py index 4554ef83..43d86f2c 100644 --- a/examples/a2d/mdlm/pt.py +++ b/examples/a2d/mdlm/pt.py @@ -60,14 +60,12 @@ class DataArguments(dllm.utils.DataArguments): @dataclass -class TrainingArguments(dllm.utils.TrainingArguments): +class TrainingArguments(dllm.core.trainers.MDLMTrainer.MDLMConfig): output_dir: str = "models/a2d/Qwen3-0.6B/mdlm/tiny-shakespeare" num_train_epochs: int = 20 learning_rate: float = 1e-4 per_device_train_batch_size: int = 16 per_device_eval_batch_size: int = 16 - # a2d-specific - right_shift_logits: bool = False def train(): @@ -122,7 +120,6 @@ def train(): train_dataset=dataset["train"], eval_dataset=dataset.get("test", None), args=training_args, - right_shift_logits=training_args.right_shift_logits, 
data_collator=transformers.DataCollatorForSeq2Seq( tokenizer, return_tensors="pt", diff --git a/examples/a2d/mdlm/sft.py b/examples/a2d/mdlm/sft.py index 69c7582c..f74ef33e 100644 --- a/examples/a2d/mdlm/sft.py +++ b/examples/a2d/mdlm/sft.py @@ -55,15 +55,13 @@ class DataArguments(dllm.utils.DataArguments): @dataclass -class TrainingArguments(dllm.utils.TrainingArguments): +class TrainingArguments(dllm.core.trainers.MDLMTrainer.MDLMConfig): output_dir: str = "models/a2d/Qwen3-0.6B/mdlm/alpaca" group_by_length: bool = True learning_rate: float = 1e-4 num_train_epochs: int = 20 per_device_train_batch_size: int = 16 per_device_eval_batch_size: int = 16 - # a2d-specific - right_shift_logits: bool = False def train(): @@ -88,7 +86,7 @@ def train(): ) if not data_args.load_preprocessed_data: map_fn = partial( - dllm.utils.default_mdlm_sft_map_fn, + dllm.utils.default_sft_map_fn, tokenizer=tokenizer, mask_prompt_loss=data_args.mask_prompt_loss, ) @@ -109,7 +107,6 @@ def train(): train_dataset=dataset["train"], eval_dataset=dataset.get("test", None), args=training_args, - right_shift_logits=training_args.right_shift_logits, data_collator=( dllm.utils.NoAttentionMaskWrapper( # padded should be visible transformers.DataCollatorForSeq2Seq( diff --git a/examples/bert/pt.py b/examples/bert/pt.py index 302d6b34..39cb906b 100644 --- a/examples/bert/pt.py +++ b/examples/bert/pt.py @@ -60,7 +60,7 @@ class DataArguments(dllm.utils.DataArguments): @dataclass -class TrainingArguments(dllm.utils.TrainingArguments): +class TrainingArguments(dllm.core.trainers.MDLMTrainer.MDLMConfig): output_dir: str = "models/ModernBERT-base/tiny-shakespeare" num_train_epochs: int = 20 learning_rate: float = 1e-4 diff --git a/examples/bert/sft.py b/examples/bert/sft.py index 038b9c44..878cfca2 100644 --- a/examples/bert/sft.py +++ b/examples/bert/sft.py @@ -55,7 +55,7 @@ class DataArguments(dllm.utils.DataArguments): @dataclass -class TrainingArguments(dllm.utils.TrainingArguments): +class 
TrainingArguments(dllm.core.trainers.MDLMTrainer.MDLMConfig): output_dir: str = "models/ModernBERT-large/alpaca" group_by_length: bool = True num_train_epochs: int = 20 @@ -86,7 +86,7 @@ def train(): ) if not data_args.load_preprocessed_data: map_fn = partial( - dllm.utils.default_mdlm_sft_map_fn, + dllm.utils.default_sft_map_fn, tokenizer=tokenizer, mask_prompt_loss=data_args.mask_prompt_loss, ) diff --git a/examples/dream/README.md b/examples/dream/README.md index ce406ebb..7bcb34d3 100644 --- a/examples/dream/README.md +++ b/examples/dream/README.md @@ -88,7 +88,7 @@ We tried our best to reproduce [`Dream-v0-Instruct-7B`](https://huggingface.co/D # Preprocessing SFT data (optional, but can avoid redundant preprocessing for multi-node training) python dllm/tools/preprocess_sft_dataset.py \ --model_name_or_path "Dream-org/Dream-v0-Base-7B" \ - --sft_map_fn_path "dllm.utils.default_mdlm_sft_map_fn" \ + --sft_map_fn_path "dllm.utils.default_sft_map_fn" \ --dataset_args "allenai/tulu-3-sft-mixture" \ --output_dir "data/sft/dream/tulu-3-sft-mixture" \ --num_proc 64 diff --git a/examples/dream/pt.py b/examples/dream/pt.py index 9a44a33e..5247213c 100644 --- a/examples/dream/pt.py +++ b/examples/dream/pt.py @@ -65,7 +65,7 @@ class DataArguments(dllm.utils.DataArguments): @dataclass -class TrainingArguments(dllm.utils.TrainingArguments): +class TrainingArguments(dllm.pipelines.dream.DreamTrainer.DreamConfig): output_dir: str = ( "models/Dream-v0-Base-7B/dclm-baseline-1.0[train:10_000_000,test:10_000]" ) @@ -142,7 +142,6 @@ def train(): train_dataset=dataset["train"], eval_dataset=dataset.get("test", None), args=training_args, - loss_weight_type=training_args.loss_weight_type, data_collator=transformers.DataCollatorForSeq2Seq( tokenizer, return_tensors="pt", diff --git a/examples/dream/sft.py b/examples/dream/sft.py index a48aa9e9..98a44eb1 100644 --- a/examples/dream/sft.py +++ b/examples/dream/sft.py @@ -75,7 +75,7 @@ class DataArguments(dllm.utils.DataArguments): 
@dataclass -class TrainingArguments(dllm.utils.TrainingArguments): +class TrainingArguments(dllm.pipelines.dream.DreamTrainer.DreamConfig): output_dir: str = ( "models/Dream-v0-Base-7B/tulu-3-sft-mixture[train:10000,test:1000]" ) @@ -120,7 +120,7 @@ def train(): ) if not data_args.load_preprocessed_data: map_fn = partial( - dllm.utils.default_mdlm_sft_map_fn, + dllm.utils.default_sft_map_fn, tokenizer=tokenizer, mask_prompt_loss=data_args.mask_prompt_loss, ) @@ -141,7 +141,6 @@ def train(): train_dataset=dataset["train"], eval_dataset=dataset.get("test", None), args=training_args, - loss_weight_type=training_args.loss_weight_type, data_collator=dream.utils.DreamSFTCollator( tokenizer, return_tensors="pt", diff --git a/examples/editflow/pt.py b/examples/editflow/pt.py index 0cbb463a..a6b0927c 100644 --- a/examples/editflow/pt.py +++ b/examples/editflow/pt.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field import accelerate +import transformers import dllm from dllm.pipelines import editflow @@ -31,7 +32,7 @@ class DataArguments(dllm.utils.DataArguments): @dataclass -class TrainingArguments(dllm.utils.TrainingArguments): +class TrainingArguments(editflow.EditFlowTrainer.EditFlowConfig): output_dir: str = None # overwrite this num_train_epochs: int = 10 learning_rate: float = 1e-4 @@ -43,18 +44,10 @@ class TrainingArguments(dllm.utils.TrainingArguments): metadata={ "help": ( "The scheduler class controlling κ(t). 
" - "Available options: see `dllm/utils/schedulers/kappa.py`" + "Available options: see `dllm/core/schedulers/kappa.py`" ) }, ) - normalize_per_position: bool = field( - default=True, - metadata={"help": "Whether to normalize the loss per position."}, - ) - max_w: float = field( - default=20.0, - metadata={"help": "The maximum weight (κ'(t) / (1 - κ(t))) for the loss."}, - ) x0_sampler: str = field( default="masks[length:128]", metadata={ @@ -66,11 +59,12 @@ class TrainingArguments(dllm.utils.TrainingArguments): ) -def train( - model_args: ModelArguments, - data_args: DataArguments, - training_args: TrainingArguments, -): +def train(): + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() # necessary when batch does not contain "labels" field training_args.label_names = [] # necessary when batch contains customized fields @@ -129,11 +123,13 @@ def _no_flops(*args, **kwargs): scheduler=dllm.core.schedulers.make_kappa_scheduler( training_args.scheduler_cls ), - normalize_per_position=training_args.normalize_per_position, - max_w=training_args.max_w, ) trainer.train() trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) trainer.processing_class.save_pretrained( os.path.join(training_args.output_dir, "checkpoint-final") ) + + +if __name__ == "__main__": + train() diff --git a/examples/editflow/sft.py b/examples/editflow/sft.py index 1daff739..c7ec1cbf 100644 --- a/examples/editflow/sft.py +++ b/examples/editflow/sft.py @@ -3,6 +3,7 @@ from functools import partial import accelerate +import transformers import dllm from dllm.pipelines import editflow @@ -26,7 +27,7 @@ class DataArguments(dllm.utils.DataArguments): @dataclass -class TrainingArguments(dllm.utils.TrainingArguments): +class 
TrainingArguments(editflow.EditFlowTrainer.EditFlowConfig): output_dir: str = None # overwrite this num_train_epochs: float = 10 learning_rate: float = 1e-4 @@ -38,18 +39,10 @@ class TrainingArguments(dllm.utils.TrainingArguments): metadata={ "help": ( "The scheduler class controlling κ(t). " - "Available options: see `dllm/utils/schedulers/kappa.py`" + "Available options: see `dllm/core/schedulers/kappa.py`" ) }, ) - normalize_per_position: bool = field( - default=True, - metadata={"help": "Whether to normalize the loss per position."}, - ) - max_w: float = field( - default=20.0, - metadata={"help": "The maximum weight (κ'(t) / (1 - κ(t))) for the loss."}, - ) x0_sampler: str = field( default="masks[length:64]", metadata={ @@ -88,11 +81,12 @@ def sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict: return {"input_ids": prompt_response_tokens} -def train( - model_args: ModelArguments, - data_args: DataArguments, - training_args: TrainingArguments, -): +def train(): + # ----- Argument parsing ------------------------------------------------------- + parser = transformers.HfArgumentParser( + (ModelArguments, DataArguments, TrainingArguments) + ) + model_args, data_args, training_args = parser.parse_args_into_dataclasses() # necessary when batch does not contain "labels" field training_args.label_names = [] # necessary when batch contains customized fields @@ -146,11 +140,13 @@ def _no_flops(*args, **kwargs): scheduler=dllm.core.schedulers.make_kappa_scheduler( training_args.scheduler_cls ), - normalize_per_position=training_args.normalize_per_position, - max_w=training_args.max_w, ) trainer.train() trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final")) trainer.processing_class.save_pretrained( os.path.join(training_args.output_dir, "checkpoint-final") ) + + +if __name__ == "__main__": + train() diff --git a/examples/llada/README.md b/examples/llada/README.md index d8626d31..494fa22f 100644 --- a/examples/llada/README.md +++ 
b/examples/llada/README.md @@ -109,7 +109,7 @@ Though LLaDA is trained on proprietary data, we tried our best to reproduce [`LL # Preprocessing SFT data (optional, but can avoid redundant preprocessing for multi-node training) python dllm/tools/preprocess_sft_dataset.py \ --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \ - --sft_map_fn_path "dllm.utils.default_mdlm_sft_map_fn" \ + --sft_map_fn_path "dllm.utils.default_sft_map_fn" \ --dataset_args "allenai/tulu-3-sft-mixture" \ --output_dir "data/sft/llada/tulu-3-sft-mixture" \ --num_proc 64 diff --git a/examples/llada/pt.py b/examples/llada/pt.py index 246f7097..06243d55 100644 --- a/examples/llada/pt.py +++ b/examples/llada/pt.py @@ -67,7 +67,7 @@ class DataArguments(dllm.utils.DataArguments): @dataclass -class TrainingArguments(dllm.utils.TrainingArguments): +class TrainingArguments(dllm.core.trainers.MDLMTrainer.MDLMConfig): output_dir: str = ( "models/LLaDA-8B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]" ) diff --git a/examples/llada/sft.py b/examples/llada/sft.py index e0830e5f..b3a9f1e0 100644 --- a/examples/llada/sft.py +++ b/examples/llada/sft.py @@ -55,7 +55,7 @@ class DataArguments(dllm.utils.DataArguments): @dataclass -class TrainingArguments(dllm.utils.TrainingArguments): +class TrainingArguments(dllm.core.trainers.MDLMTrainer.MDLMConfig): output_dir: str = "models/LLaDA-8B-Base/tulu-3-sft-mixture[train:10000,test:1000]" group_by_length: bool = True num_train_epochs: float = 5 @@ -86,7 +86,7 @@ def train(): ) if not data_args.load_preprocessed_data: map_fn = partial( - dllm.utils.default_mdlm_sft_map_fn, + dllm.utils.default_sft_map_fn, tokenizer=tokenizer, mask_prompt_loss=data_args.mask_prompt_loss, ) diff --git a/pyproject.toml b/pyproject.toml index 9f134248..288d2284 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "peft==0.17.1", "datasets==4.2.0", "sentencepiece==0.2.0", + "torchmetrics", "tyro", "wandb", "omegaconf", @@ -33,6 +34,10 @@ optional = [ 
"vllm==0.8.5.post1", "flash-attn==2.8.3", ] +rl = [ + "trl==0.26.0", + "math_verify==0.8.0", +] [tool.black] line-length = 88 diff --git a/scripts/tests/test_attention.py b/scripts/tests/test_attention.py index 4a690426..e1819ca1 100644 --- a/scripts/tests/test_attention.py +++ b/scripts/tests/test_attention.py @@ -662,7 +662,7 @@ def test_bd3lm_staircase_attention_kvcache_equivalence( (A) in one full 8-token forward pass (B) in two incremental passes (4 tokens → KV cache → 4 tokens) """ - from dllm.core.samplers.bd3lm import build_staircase_attention_mask + from dllm.core.samplers.bd3lm import _prepare_for_sampling torch.manual_seed(0) device = "cuda" if torch.cuda.is_available() else "cpu" @@ -699,7 +699,7 @@ def test_bd3lm_staircase_attention_kvcache_equivalence( # ------------------------------ # 3. Build staircase mask + positions for the full sequence # ------------------------------ - attn_full, pos_full = build_staircase_attention_mask( + attn_full, pos_full = _prepare_for_sampling( x_full, block_size=block_size, pad_token_id=pad_token_id ) # attn_full: [1, 1, 8, 8] @@ -808,7 +808,7 @@ def test_bd3lm_concat_equivalence_when_noised_equals_input( NOTE: We set block_size == seq_len so x_t tokens attend only within x_t (single block), making the first-half computation equivalent to a standard full-attention forward. """ - from dllm.core.trainers.bd3lm import block_diff_mask + from dllm.core.trainers.bd3lm import _bd3lm_attention_mask torch.manual_seed(0) device = "cuda" if torch.cuda.is_available() else "cpu" @@ -858,7 +858,7 @@ def test_bd3lm_concat_equivalence_when_noised_equals_input( pos_cat = torch.cat([pos, pos], dim=1) # [1, 2L] L2 = 2 * seq_len - attn_bd = block_diff_mask( + attn_bd = _bd3lm_attention_mask( b=None, h=None, q_idx=torch.arange(L2, device=device)[:, None],