diff --git a/README.md b/README.md
index e87bc6eb..724f5d7d 100644
--- a/README.md
+++ b/README.md
@@ -219,7 +219,7 @@ See [Features](#features) for specific training recipes.
# Preprocess SFT data
python dllm/tools/preprocess_sft_dataset.py \
--model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
- --sft_map_fn_path "dllm.utils.default_mdlm_sft_map_fn" \
+ --sft_map_fn_path "dllm.utils.default_sft_map_fn" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "data/sft/llada/tulu-3-sft-mixture" \
--num_proc 64
@@ -238,7 +238,7 @@ See [Features](#features) for specific training recipes.
# Preprocess SFT data
+ python dllm/tools/preprocess_sft_dataset.py \
+ --model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
- + --sft_map_fn_path "dllm.utils.default_mdlm_sft_map_fn" \
+ + --sft_map_fn_path "dllm.utils.default_sft_map_fn" \
+ --dataset_args "allenai/tulu-3-sft-mixture" \
+ --output_dir "data/sft/llada/tulu-3-sft-mixture" \
+ --num_proc 64
diff --git a/dllm/core/samplers/bd3lm.py b/dllm/core/samplers/bd3lm.py
index 15d30b1d..ec2a3f96 100644
--- a/dllm/core/samplers/bd3lm.py
+++ b/dllm/core/samplers/bd3lm.py
@@ -13,7 +13,7 @@
from dllm.core.samplers.utils import add_gumbel_noise, get_num_transfer_tokens
-def build_staircase_attention_mask(
+def _prepare_for_sampling(
x: torch.Tensor,
block_size: int,
pad_token_id: int,
@@ -81,7 +81,7 @@ def build_staircase_attention_mask(
return attn_mask, position_ids
-def diffusion_step_block(
+def _diffusion_step_block(
logits: torch.Tensor, # [B, L, V]
x_block: torch.Tensor, # [B, L]
mask_block: torch.Tensor, # [B, L] bool
@@ -201,9 +201,10 @@ def sample(
mask_id = self.tokenizer.mask_token_id
bos_id = self.tokenizer.bos_token_id
pad_id = self.tokenizer.pad_token_id # used as padding here
+ eos_id = self.tokenizer.eos_token_id
# ---- normalize inputs to tensors ----
- # If right_shift_logits is true and a sequence has length 0, replace that sequence with [eos].
+ # If right_shift_logits is true and a sequence has length 0, replace that sequence with [bos].
if right_shift_logits:
inputs = [
[bos_id] if isinstance(p, list) and len(p) == 0 else p for p in inputs
@@ -228,16 +229,21 @@ def sample(
# ==========================================================
# 1) Initialize with prompt only (left padded with pad_id)
+ # pad prefix length to a multiple of block_size
# ==========================================================
+ padded_prompt_len = (
+ (max_prompt_len + block_size - 1) // block_size
+ ) * block_size
+
x = torch.full(
- (B, max_prompt_len),
+ (B, padded_prompt_len),
pad_id,
dtype=torch.long,
device=self.model.device,
)
for b, p in enumerate(inputs):
L = prompt_lens[b]
- offset = max_prompt_len - L # left padding
+ offset = padded_prompt_len - L # left padding
x[b, offset : offset + L] = p
# Tokens considered "given" for unconditional branch in CFG.
@@ -248,6 +254,9 @@ def sample(
)
unmasked_index = unmasked_index & (~keep_mask)
+ # track done per sequence (EOS)
+ done = torch.zeros((B,), dtype=torch.bool, device=self.model.device)
+
# ---- block scheduling ----
num_blocks = math.ceil(max_new_tokens / block_size)
if steps_per_block is None:
@@ -260,15 +269,13 @@ def sample(
# 2) Block-by-block generation loop
# ==========================================================
for b_idx in range(num_blocks):
- # Align sampling block to physical block boundaries
+ if done.all():
+ break
+
T_prefix = x.shape[1] # current total length before appending this block
- offset = T_prefix % block_size
- if offset == 0:
- block_room = block_size
- else:
- block_room = block_size - offset
- cur_block_len = min(block_room, max_new_tokens - generated)
+ # With padded_prompt_len aligned, we always append whole blocks (except possibly final)
+ cur_block_len = min(block_size, max_new_tokens - generated)
if cur_block_len <= 0:
break
@@ -278,7 +285,7 @@ def sample(
x_prefix = x # [B, T_prefix]
B_cur, T_prefix = x_prefix.shape
- prefix_attn, prefix_pos = build_staircase_attention_mask(
+ prefix_attn, prefix_pos = _prepare_for_sampling(
x=x_prefix,
block_size=block_size,
pad_token_id=pad_id,
@@ -317,6 +324,13 @@ def sample(
new_block = torch.full(
(B, cur_block_len), mask_id, dtype=torch.long, device=self.model.device
)
+        # NOTE(review): sequences already marked done still receive mask tokens
+        # for the new block and are re-denoised below. An alternative is to fill
+        # their block with pad_id instead, which would skip useless denoising,
+        # but that requires confirming the staircase attention mask treats pad
+        # tokens inside the active block correctly. TODO: verify and either
+        # pad-fill done rows or accept the wasted compute.
+
x = torch.cat([x, new_block], dim=1) # [B, T_prefix + cur_block_len]
unmasked_index = torch.cat(
@@ -342,7 +356,7 @@ def sample(
effective_steps = num_transfer_tokens.size(1)
# Full staircase attention mask + pos for prefix + current block
- full_attention_mask, full_position_ids = build_staircase_attention_mask(
+ full_attention_mask, full_position_ids = _prepare_for_sampling(
x=x,
block_size=block_size,
pad_token_id=pad_id,
@@ -406,7 +420,7 @@ def sample(
logits_block = shifted
# ---- One diffusion step over this block ----
- x_block_updated = diffusion_step_block(
+ x_block_updated = _diffusion_step_block(
logits=logits_block,
x_block=x_block,
mask_block=mask_block,
@@ -421,8 +435,10 @@ def sample(
if histories is not None:
histories.append(x.clone())
- if self.tokenizer.eos_token_id in x[:, T_prefix:T_total]:
- break
+ # per-sequence EOS stopping (after finishing denoising the block)
+ if eos_id is not None:
+ eos_in_block = (x[:, T_prefix:T_total] == eos_id).any(dim=1)
+ done = done | eos_in_block
generated += cur_block_len
@@ -437,7 +453,7 @@ def sample(
@torch.no_grad()
def infill(
self,
- inputs: list[torch.Tensor, list],
+ inputs: list[torch.Tensor | list],
config: SamplerConfig | None = None,
**kwargs,
) -> SamplerOutput:
diff --git a/dllm/core/trainers/bd3lm.py b/dllm/core/trainers/bd3lm.py
index 5beb8c36..b6f9b95b 100644
--- a/dllm/core/trainers/bd3lm.py
+++ b/dllm/core/trainers/bd3lm.py
@@ -17,7 +17,6 @@
from dllm.utils.collators import CollatorWrapper
from .mdlm import MDLMTrainer
-from .utils import EpochPPLMeter
@dataclass
@@ -40,7 +39,7 @@ def before(self, features):
return features
-def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
+def _bd3lm_attention_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
"""
Constructs the specialized block diffusion attention mask for training
composed of three masks:
@@ -85,17 +84,18 @@ def block_diff_mask(b, h, q_idx, kv_idx, block_size=None, n=None):
class BD3LMTrainer(MDLMTrainer):
+ @dataclass
+ class BD3LMConfig(MDLMTrainer.MDLMConfig):
+ block_size: int = 32
+
def __init__(
self,
- block_size: int = 32,
- *args,
+ args: BD3LMConfig,
+ *pargs,
**kwargs,
):
- super().__init__(*args, **kwargs)
- self.block_size = block_size
-
- self.epoch_meter = EpochPPLMeter(self, train_prefix="train", eval_prefix="eval")
- self.add_callback(self.epoch_meter)
+ super().__init__(args=args, *pargs, **kwargs)
+ self.block_size = args.block_size
def compute_loss(
self,
@@ -126,7 +126,7 @@ def compute_loss(
inputs.get("attention_mask", None),
)
b, l = input_ids.shape
- token_cnt_per_seq = torch.sum(labels != -100, dim=1, keepdim=True) # [b, 1]
+ maskable_mask = labels != -100 # [b, l]
# === 1. Sample diffusion timesteps ===
# Each example draws a random timestep t ∈ [ε, 1), where ε avoids degenerate values near 0.
@@ -139,12 +139,12 @@ def compute_loss(
# === 2. Apply stochastic masking ===
# Tokens are masked independently according to p_mask(t).
# Positions with label = -100 are excluded (ignored in loss).
- masked_indices = (torch.rand((b, l), device=input_ids.device) < p_mask) & (
- labels != -100
- )
+ masked_mask = (
+ torch.rand((b, l), device=input_ids.device) < p_mask
+ ) & maskable_mask
# Replace masked tokens with the special [MASK] token.
noised_input_ids = torch.where(
- masked_indices, self.processing_class.mask_token_id, input_ids
+ masked_mask, self.processing_class.mask_token_id, input_ids
)
# === 3. Forward pass through the model (block-diffusion) ===
@@ -155,7 +155,7 @@ def compute_loss(
# [TODO]: others like flash attention 2
if self.accelerator.unwrap_model(model).config._attn_implementation == "sdpa":
- attention_mask = block_diff_mask(
+ attention_mask = _bd3lm_attention_mask(
b=None,
h=None,
q_idx=torch.arange(l * 2)[:, None],
@@ -174,7 +174,7 @@ def compute_loss(
from torch.nn.attention.flex_attention import create_block_mask
attention_mask = create_block_mask(
- partial(block_diff_mask, block_size=self.block_size, n=l),
+ partial(_bd3lm_attention_mask, block_size=self.block_size, n=l),
B=None,
H=None,
Q_LEN=l * 2,
@@ -201,46 +201,51 @@ def compute_loss(
# === 4. Handle degenerate cases (no tokens masked) ===
# If no positions were masked, return a zero loss to keep gradients valid.
# This step is necessary for Deepspeed Zero-{2,3}
- if not masked_indices.any():
- self.epoch_meter.update(
- split="train" if model.training else "eval",
- nll_sum=logits.sum() * 0.0,
- token_cnt=token_cnt_per_seq.sum(),
- )
- return (
- (logits.sum() * 0.0, outputs) if return_outputs else logits.sum() * 0.0
+ if not masked_mask.any():
+ zero = logits.sum() * 0.0 # scalar, grad-safe
+ self.meter.update(
+ split="train" if model.training else "eval",
+ value=torch.zeros_like(maskable_mask, dtype=logits.dtype),
+ weight=maskable_mask.to(dtype=logits.dtype).detach(),
)
+ return (zero, outputs) if return_outputs else zero
# === 5. Compute per-token loss weights ===
# Depending on the configuration, weights may depend on timestep t
# (e.g., scheduler-based) or be uniform (ones).
loss_weights = self._compute_loss_weights(
- t=t, inputs=inputs, masked_indices=masked_indices
+ t=t, inputs=inputs, masked_mask=masked_mask
)
# === 6. Compute weighted cross-entropy ===
- # Only masked tokens contribute to the loss.
- assert (input_ids[masked_indices] == labels[masked_indices]).all()
- token_loss = F.cross_entropy(
- logits[masked_indices], input_ids[masked_indices], reduction="none"
+ # Sanity check: ensure input_ids and labels match at valid positions
+ assert (
+ input_ids[maskable_mask] == labels[maskable_mask]
+ ).all(), "Mismatch between input_ids and labels at valid positions"
+
+ token_nll = F.cross_entropy(
+ logits.transpose(1, 2), # [b, V, l]
+ input_ids, # [b, l]
+ reduction="none", # [b, l]
+ )
+ token_nll = token_nll * loss_weights * masked_mask.to(token_nll.dtype) # [b, l]
+
+ self.meter.update(
+ split="train" if model.training else "eval",
+ value=token_nll.detach(),
+ weight=maskable_mask.to(dtype=logits.dtype).detach(),
)
- token_loss = token_loss * loss_weights[masked_indices]
# === 7. Normalize loss ===
- if self.loss_normalization_type == "batch":
- token_loss_normalized = token_loss / b
- elif self.loss_normalization_type == "sequence":
- token_loss_normalized = token_loss / token_cnt_per_seq.expand(-1, l)[masked_indices] / b
- elif self.loss_normalization_type == "token":
- token_loss_normalized = token_loss / token_cnt_per_seq.sum()
+ if self.loss_norm_type == "batch":
+ token_nll /= b
+ elif self.loss_norm_type == "sequence":
+ token_nll /= maskable_mask.sum(-1, keepdim=True).clamp_min(1) * b
+ elif self.loss_norm_type == "token":
+ token_nll /= maskable_mask.sum().clamp_min(1)
else:
- raise ValueError("Invalid loss_normalization_type.")
- loss = token_loss_normalized.sum()
+ raise ValueError("Invalid loss_norm_type.")
+ loss = token_nll.sum()
# === 8. Return final loss (and optionally model outputs) ===
- self.epoch_meter.update(
- split="train" if model.training else "eval",
- nll_sum=token_loss.sum(),
- token_cnt=token_cnt_per_seq.sum(),
- ) # `nll_sum / token_cnt` is equivalent to `loss` when `self.loss_normalization_type == "token"``
return (loss, outputs) if return_outputs else loss
diff --git a/dllm/core/trainers/mdlm.py b/dllm/core/trainers/mdlm.py
index 316669d8..5bb7d771 100644
--- a/dllm/core/trainers/mdlm.py
+++ b/dllm/core/trainers/mdlm.py
@@ -9,6 +9,7 @@
"""
from typing import Any
+from dataclasses import dataclass
import torch
import torch.nn as nn
@@ -16,35 +17,48 @@
import transformers
from dllm.core.schedulers import BaseAlphaScheduler, LinearAlphaScheduler
+from dllm.utils.configs import TrainingArguments
from dllm.utils.data import prepend_bos
-from .utils import EpochPPLMeter
+from .utils import NLLMetric, PerplexityMetric, OnEvaluateMetricsCallback
class MDLMTrainer(transformers.Trainer):
+ @dataclass
+ class MDLMConfig(TrainingArguments):
+ time_epsilon: float = 1e-3
+ loss_weight_type: str = "scheduler" # "scheduler", "uniform"
+ loss_norm_type: str = "sequence" # "batch", "sequence", "token"
+ right_shift_logits: bool = False
+
def __init__(
self,
+ args: MDLMConfig,
scheduler: BaseAlphaScheduler | None = None,
- time_epsilon: float = 1e-3,
- loss_weight_type: str = "scheduler", # "scheduler", "uniform"
- loss_normalization_type: str = "sequence", # "batch", "sequence", "token"
- right_shift_logits: bool = False,
- *args,
+ *pargs,
**kwargs,
):
- super().__init__(*args, **kwargs)
+ super().__init__(args=args, *pargs, **kwargs)
- if not (0.0 < time_epsilon < 1.0):
+ if not (0.0 < args.time_epsilon < 1.0):
raise ValueError("time_epsilon must be in (0, 1)")
- self.scheduler = scheduler or LinearAlphaScheduler()
- self.time_epsilon = time_epsilon
- self.loss_weight_type = loss_weight_type
- self.loss_normalization_type = loss_normalization_type
- self.right_shift_logits = right_shift_logits
-
- self.epoch_meter = EpochPPLMeter(self, train_prefix="train", eval_prefix="eval")
- self.add_callback(self.epoch_meter)
+ self.scheduler = scheduler if scheduler is not None else LinearAlphaScheduler()
+ self.time_epsilon = args.time_epsilon
+ self.loss_weight_type = args.loss_weight_type
+ self.loss_norm_type = args.loss_norm_type
+ self.right_shift_logits = args.right_shift_logits
+
+        # Dataset-level diffusion NLL/PPL, accumulated per split and logged on evaluate.
+ self.meter = OnEvaluateMetricsCallback(
+ trainer=self,
+ splits=("train", "eval"),
+ metrics_map={
+ "diff_nll": NLLMetric(),
+ "diff_ppl": PerplexityMetric(),
+ },
+ )
+ self.add_callback(self.meter)
def _preprocess_inputs(self, inputs):
if self.right_shift_logits:
@@ -133,7 +147,7 @@ def compute_loss(
inputs.get("attention_mask", None),
)
b, l = input_ids.shape
- token_cnt_per_seq = torch.sum(labels != -100, dim=1, keepdim=True) # [b, 1]
+ maskable_mask = labels != -100 # [b, l]
# === 1. Sample diffusion timesteps ===
# Each example draws a random timestep t ∈ [ε, 1), where ε avoids degenerate values near 0.
@@ -146,12 +160,12 @@ def compute_loss(
# === 2. Apply stochastic masking ===
# Tokens are masked independently according to p_mask(t).
# Positions with label = -100 are excluded (ignored in loss).
- masked_indices = (torch.rand((b, l), device=input_ids.device) < p_mask) & (
- labels != -100
- )
+ masked_mask = (
+ torch.rand((b, l), device=input_ids.device) < p_mask
+ ) & maskable_mask
# Replace masked tokens with the special [MASK] token.
noised_input_ids = torch.where(
- masked_indices, self.processing_class.mask_token_id, input_ids
+ masked_mask, self.processing_class.mask_token_id, input_ids
)
# === 3. Forward pass through the model ===
@@ -163,46 +177,51 @@ def compute_loss(
# === 4. Handle degenerate cases (no tokens masked) ===
# If no positions were masked, return a zero loss to keep gradients valid.
# This step is necessary for Deepspeed Zero-{2,3}
- if not masked_indices.any():
- self.epoch_meter.update(
- split="train" if model.training else "eval",
- nll_sum=logits.sum() * 0.0,
- token_cnt=token_cnt_per_seq.sum(),
- )
- return (
- (logits.sum() * 0.0, outputs) if return_outputs else logits.sum() * 0.0
+ if not masked_mask.any():
+ zero = logits.sum() * 0.0 # scalar, grad-safe
+ self.meter.update(
+ split="train" if model.training else "eval",
+ value=torch.zeros_like(maskable_mask, dtype=logits.dtype),
+ weight=maskable_mask.to(dtype=logits.dtype).detach(),
)
+ return (zero, outputs) if return_outputs else zero
# === 5. Compute per-token loss weights ===
# Depending on the configuration, weights may depend on timestep t
# (e.g., scheduler-based) or be uniform (ones).
loss_weights = self._compute_loss_weights(
- t=t, inputs=inputs, masked_indices=masked_indices
+ t=t, inputs=inputs, masked_mask=masked_mask
)
# === 6. Compute weighted cross-entropy ===
- # Only masked tokens contribute to the loss.
- assert (input_ids[masked_indices] == labels[masked_indices]).all()
- token_loss = F.cross_entropy(
- logits[masked_indices], input_ids[masked_indices], reduction="none"
+ # Sanity check: ensure input_ids and labels match at valid positions
+ assert (
+ input_ids[maskable_mask] == labels[maskable_mask]
+ ).all(), "Mismatch between input_ids and labels at valid positions"
+
+ token_nll = F.cross_entropy(
+ logits.transpose(1, 2), # [b, V, l]
+ input_ids, # [b, l]
+ reduction="none", # [b, l]
+ )
+ token_nll = token_nll * loss_weights * masked_mask.to(token_nll.dtype) # [b, l]
+
+ self.meter.update(
+ split="train" if model.training else "eval",
+ value=token_nll.detach(),
+ weight=maskable_mask.to(dtype=logits.dtype).detach(),
)
- token_loss = token_loss * loss_weights[masked_indices]
# === 7. Normalize loss ===
- if self.loss_normalization_type == "batch":
- token_loss_normalized = token_loss / b
- elif self.loss_normalization_type == "sequence":
- token_loss_normalized = token_loss / token_cnt_per_seq.expand(-1, l)[masked_indices] / b
- elif self.loss_normalization_type == "token":
- token_loss_normalized = token_loss / token_cnt_per_seq.sum()
+ if self.loss_norm_type == "batch":
+ token_nll /= b
+ elif self.loss_norm_type == "sequence":
+ token_nll /= maskable_mask.sum(-1, keepdim=True).clamp_min(1) * b
+ elif self.loss_norm_type == "token":
+ token_nll /= maskable_mask.sum().clamp_min(1)
else:
- raise ValueError("Invalid loss_normalization_type.")
- loss = token_loss_normalized.sum()
+ raise ValueError("Invalid loss_norm_type.")
+ loss = token_nll.sum()
# === 8. Return final loss (and optionally model outputs) ===
- self.epoch_meter.update(
- split="train" if model.training else "eval",
- nll_sum=token_loss.sum(),
- token_cnt=token_cnt_per_seq.sum(),
- ) # `nll_sum / token_cnt` is equivalent to `loss` when `self.loss_normalization_type == "token"``
return (loss, outputs) if return_outputs else loss
diff --git a/dllm/core/trainers/utils.py b/dllm/core/trainers/utils.py
index 2cb9ceff..ce150437 100644
--- a/dllm/core/trainers/utils.py
+++ b/dllm/core/trainers/utils.py
@@ -1,30 +1,18 @@
import math
-
import torch
import transformers
class EpochPPLMeter(transformers.TrainerCallback):
"""
- Keeps running sums for dataset-level NLL/token and logs PPL once per epoch.
-
- Usage:
- - Trainer calls: self.ppl_meter.update(split, nll_sum, token_cnt)
- - Callback hooks:
- * on_epoch_begin: reset train accumulators
- * on_epoch_end: finalize+log train PPL
- * on_evaluate: finalize+log eval PPL (one per evaluate call)
+ Keeps running sums for dataset-level NLL/token and logs PPL.
+ Convention:
+ - Train: keys are unprefixed, e.g. "diff_nll", "diff_ppl"
+ - Eval : keys are prefixed with "eval_", e.g. "eval_diff_nll", "eval_diff_ppl"
"""
- def __init__(
- self,
- trainer: "transformers.Trainer",
- train_prefix: str = "train",
- eval_prefix: str = "eval",
- ):
+ def __init__(self, trainer: "transformers.Trainer"):
self.trainer = trainer
- self.train_prefix = train_prefix
- self.eval_prefix = eval_prefix
self._train_nll_sum = 0.0
self._train_token_cnt = 0.0
@@ -41,8 +29,9 @@ def reset(self, split: str) -> None:
else:
raise ValueError(f"Unknown split={split}")
- def update(self, split: str, nll_sum: torch.Tensor, token_cnt: torch.Tensor) -> None:
- # detach -> float64 -> python float
+ def update(
+ self, split: str, nll_sum: torch.Tensor, token_cnt: torch.Tensor
+ ) -> None:
nll_sum_f = float(nll_sum.detach().double().cpu().item())
tok_cnt_f = float(token_cnt.detach().double().cpu().item())
@@ -60,9 +49,8 @@ def _finalize(self, split: str):
All-reduce (sum) across processes, then compute:
mean_nll = total_nll / total_tokens
ppl = exp(mean_nll)
-
- Returns (mean_nll, ppl) as python floats, or (None, None) if no tokens.
- Also resets that split after finalizing.
+ Returns (mean_nll, ppl) or (None, None) if no tokens.
+ Resets the split accumulators when called.
"""
if split == "train":
local_nll, local_tok = self._train_nll_sum, self._train_token_cnt
@@ -93,22 +81,44 @@ def _finalize(self, split: str):
# ---- callback hooks ----
- def on_epoch_begin(self, args, state, control, **kwargs):
+ def on_train_begin(self, args, state, control, **kwargs):
self.reset("train")
return control
- def on_epoch_end(self, args, state, control, **kwargs):
- mean_nll, ppl = self._finalize("train")
- if mean_nll is not None and self.trainer.is_world_process_zero():
- logs = {f"{self.train_prefix}_nll": mean_nll, f"{self.train_prefix}_ppl": ppl}
- self.trainer.log(logs)
- print(f"[epoch {state.epoch}] {self.train_prefix}_nll={mean_nll:.6f} {self.train_prefix}_ppl={ppl:.6f}")
+ def on_evaluate_begin(self, args, state, control, **kwargs):
+ self.reset("eval")
return control
def on_evaluate(self, args, state, control, metrics=None, **kwargs):
- mean_nll, ppl = self._finalize("eval")
- if mean_nll is not None and self.trainer.is_world_process_zero():
- logs = {f"{self.eval_prefix}_nll": mean_nll, f"{self.eval_prefix}_ppl": ppl}
- self.trainer.log(logs)
- print(f"[epoch {state.epoch}] {self.eval_prefix}_nll={mean_nll:.6f} {self.eval_prefix}_ppl={ppl:.6f}")
+ train_mean_nll, train_ppl = self._finalize("train")
+ eval_mean_nll, eval_ppl = self._finalize("eval")
+
+ if self.trainer.is_world_process_zero():
+ logs = {}
+
+ # TRAIN: NO "train_" prefix
+ if train_mean_nll is not None:
+ logs.update(
+ {
+ "diff_nll": train_mean_nll,
+ "diff_ppl": train_ppl,
+ }
+ )
+
+ # EVAL: MUST be "eval_" prefixed
+ if eval_mean_nll is not None:
+ logs.update(
+ {
+ "eval_diff_nll": eval_mean_nll,
+ "eval_diff_ppl": eval_ppl,
+ }
+ )
+
+ if logs:
+ self.trainer.log(logs)
+ print(
+ f"[step {state.global_step} epoch {state.epoch}] "
+ + " ".join(f"{k}={v:.6f}" for k, v in logs.items())
+ )
+
return control
diff --git a/dllm/core/trainers/utils/__init__.py b/dllm/core/trainers/utils/__init__.py
new file mode 100644
index 00000000..b01daf8b
--- /dev/null
+++ b/dllm/core/trainers/utils/__init__.py
@@ -0,0 +1,3 @@
+from . import meters, metrics
+from .meters import *
+from .metrics import *
diff --git a/dllm/core/trainers/utils/meters.py b/dllm/core/trainers/utils/meters.py
new file mode 100644
index 00000000..aaf5d99f
--- /dev/null
+++ b/dllm/core/trainers/utils/meters.py
@@ -0,0 +1,187 @@
+from __future__ import annotations
+
+from typing import Any, Dict, Iterable, Optional
+import copy
+
+import torch
+import transformers
+import torchmetrics
+
+
+def _ddp_initialized() -> bool:
+ return torch.distributed.is_available() and torch.distributed.is_initialized()
+
+
+class BaseMetricsCallback(transformers.TrainerCallback):
+ """
+ Generic split-aware metric accumulator for HF Trainer.
+
+ Fixes vs old version:
+ 1) Per-split metrics are independent (deep-copied) to avoid train/eval contamination.
+ 2) DDP-safe: sync/compute/reset run on ALL ranks; only rank0 logs/prints (no deadlock).
+ 3) Metrics are moved to trainer device (avoids CPU/GPU mismatch).
+ 4) Optional dtype set if metric supports it.
+
+ Smart key prefixing:
+ - split == "train": no prefix (e.g., "loss", "ppl")
+ - otherwise : f"{split}_" prefix (e.g., "eval_loss", "test_ppl")
+
+ You provide:
+ - metrics_map:
+ * {name: Metric} -> broadcast to all splits (deep-copied per split)
+ * {split: {name: Metric}} -> per-split metrics
+ update():
+ - calls metric.update(*args, **kwargs) by default
+ """
+
+ def __init__(
+ self,
+ trainer: "transformers.Trainer",
+ splits: Iterable[str] = ("train", "eval"),
+ metrics_map: Optional[Dict[str, Any]] = None,
+ dtype: torch.dtype = torch.float64,
+ ):
+ self.trainer = trainer
+ self.splits = tuple(splits)
+ metrics_map = metrics_map or {}
+
+ # Create per-split independent metric dicts
+ self._metrics: Dict[str, Dict[str, torchmetrics.Metric]] = {}
+
+ # Detect broadcast map: {name: Metric}
+ is_broadcast = len(metrics_map) > 0 and all(
+ isinstance(v, torchmetrics.Metric) for v in metrics_map.values()
+ )
+
+ device = getattr(self.trainer.args, "device", torch.device("cpu"))
+
+ for split in self.splits:
+ if is_broadcast:
+ # IMPORTANT: deepcopy so each split has independent state
+ mdict = {k: copy.deepcopy(v) for k, v in metrics_map.items()}
+ else:
+ mdict = {
+ k: copy.deepcopy(v) for k, v in metrics_map.get(split, {}).items()
+ }
+
+ # Configure dtype / device
+ for m in mdict.values():
+ # Many torchmetrics ignore this, but keep your hook
+ if hasattr(m, "set_dtype"):
+ m.set_dtype(dtype)
+ # Ensure state buffers are on the right device
+ try:
+ m.to(device)
+ except Exception:
+ pass
+
+ self._metrics[split] = mdict
+
+ # ---------- key naming ----------
+
+ @staticmethod
+ def key_for(split: str, name: str) -> str:
+ return name if split == "train" else f"{split}_{name}"
+
+ # ---------- lifecycle ----------
+
+ def reset(self, split: str) -> None:
+ for m in self._metrics[split].values():
+ m.reset()
+
+ @torch.no_grad()
+ def update(self, split: str, *args, **kwargs) -> None:
+ for m in self._metrics[split].values():
+ m.update(*args, **kwargs)
+
+ @torch.no_grad()
+ def finalize(self, split: str) -> Dict[str, float]:
+ """
+ DDP-safe finalize:
+ - Must be called on ALL ranks (because sync uses collectives).
+ - Returns local dict of python floats.
+ - Resets split metrics.
+ """
+ mdict = self._metrics[split]
+
+ # Make sure metrics live on current device (in case trainer device changes)
+ device = getattr(self.trainer.args, "device", torch.device("cpu"))
+ for m in mdict.values():
+ try:
+ m.to(device)
+ except Exception:
+ pass
+
+ # Sync across ranks (collectives) -- MUST run on all ranks
+ if _ddp_initialized():
+ for m in mdict.values():
+ # torchmetrics usually has sync/unsync; prefer sync if available
+ if hasattr(m, "sync"):
+ m.sync()
+
+ out: Dict[str, float] = {}
+ for name, m in mdict.items():
+ v = m.compute()
+ if isinstance(v, torch.Tensor):
+ if v.numel() == 0:
+ continue
+ v = v.detach()
+ v = v.item() if v.numel() == 1 else v.double().mean().cpu().item()
+ out[name] = float(v)
+
+ # IMPORTANT: reset after compute so next window starts clean
+ self.reset(split)
+ return out
+
+ @torch.no_grad()
+ def log_and_print(
+ self,
+ state: transformers.TrainerState,
+ splits: Iterable[str] | None = None,
+ ) -> None:
+ """
+ DDP-safe:
+ - finalize() (and thus sync/compute/reset) runs on ALL ranks
+ - only rank0 logs/prints
+ """
+ splits = self.splits if splits is None else tuple(splits)
+
+ # All ranks finalize (avoid DDP deadlock)
+ all_vals: Dict[str, Dict[str, float]] = {}
+ for split in splits:
+ if split in self._metrics:
+ all_vals[split] = self.finalize(split)
+
+ # Only rank0 logs/prints
+ if not self.trainer.is_world_process_zero():
+ return
+
+ logs: Dict[str, float] = {}
+ for split, vals in all_vals.items():
+ logs.update({self.key_for(split, k): v for k, v in vals.items()})
+
+ if logs:
+ self.trainer.log(logs)
+ print(
+ f"[step {state.global_step} epoch {state.epoch}] "
+ + " ".join(f"{k}={v:.6f}" for k, v in logs.items())
+ )
+
+ # ---------- HF callback hooks (optional defaults) ----------
+
+ def on_train_begin(self, args, state, control, **kwargs):
+ if "train" in self._metrics:
+ self.reset("train")
+ return control
+
+ def on_evaluate_begin(self, args, state, control, **kwargs):
+ if "eval" in self._metrics:
+ self.reset("eval")
+ return control
+
+
+class OnEvaluateMetricsCallback(BaseMetricsCallback):
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
+ # Log both train + eval by default (matches your previous behavior).
+ self.log_and_print(state, splits=("train", "eval"))
+ return control
diff --git a/dllm/core/trainers/utils/metrics.py b/dllm/core/trainers/utils/metrics.py
new file mode 100644
index 00000000..926093b6
--- /dev/null
+++ b/dllm/core/trainers/utils/metrics.py
@@ -0,0 +1,12 @@
+import torch
+import torchmetrics
+
+
+class NLLMetric(torchmetrics.aggregation.MeanMetric):
+ pass
+
+
+class PerplexityMetric(NLLMetric):
+ def compute(self) -> torch.Tensor:
+ mean_nll = super().compute()
+ return torch.exp(mean_nll)
diff --git a/dllm/data/s1k.py b/dllm/data/s1k.py
new file mode 100644
index 00000000..a68f87cc
--- /dev/null
+++ b/dllm/data/s1k.py
@@ -0,0 +1,43 @@
+from datasets import DatasetDict, load_dataset
+
+# messages = [
+# {"role": "user", "content": "Solve 13 * 17"},
+# {
+# "role": "assistant",
+# "reasoning_content": "We need to multiply 13 and 17 step by step.",
+# "content": "13 * 17 = 221."
+# }
+# ]
+
+
+def load_dataset_s1k(dataset_name_or_path: str) -> DatasetDict:
+
+ dataset = load_dataset(dataset_name_or_path)
+
+ def map_fn(example):
+
+ return {
+ "messages": [
+ {"role": "user", "content": example["question"]},
+ {
+ "role": "assistant",
+ "reasoning_content": example["thinking_trajectories"][0],
+ "content": example["attempt"],
+ },
+ ]
+ }
+
+ dataset = dataset.map(
+ map_fn, remove_columns=dataset["train"].column_names, num_proc=4
+ )
+ return dataset
+
+
+if __name__ == "__main__":
+ from dllm.utils import resolve_with_base_env
+
+ dataset_name_or_path = resolve_with_base_env(
+ "simplescaling/s1K", "BASE_DATASETS_DIR"
+ )
+ dataset = load_dataset_s1k(dataset_name_or_path)
+    print(dataset)
diff --git a/dllm/pipelines/a2d/convert.py b/dllm/pipelines/a2d/convert.py
index 1b0ffc64..38ddf2a7 100644
--- a/dllm/pipelines/a2d/convert.py
+++ b/dllm/pipelines/a2d/convert.py
@@ -6,7 +6,6 @@
import dllm
A2D_CONFIG_MAP = {
- # "gpt2": dllm.pipelines.a2d.A2DGPT2Config,
"llama": dllm.pipelines.a2d.A2DLlamaConfig,
"qwen2": dllm.pipelines.a2d.A2DQwen2Config,
"qwen3": dllm.pipelines.a2d.A2DQwen3Config,
@@ -17,6 +16,7 @@
class ScriptArguments:
model_name_or_path: str = "Qwen/Qwen2.5-0.5B"
output_dir: str = "models/a2d/Qwen2.5-0.5B"
+ random_init: bool = False
def __post_init__(self):
self.model_name_or_path = dllm.utils.resolve_with_base_env(
@@ -54,10 +54,14 @@ def main():
tgt_config = tgt_config_cls(**cfg_dict)
with dllm.utils.init_device_context_manager():
- # Build A2D model
tgt_model = transformers.AutoModel.from_config(tgt_config)
- # Direct weight copy (models match exactly)
- tgt_model.load_state_dict(src_model.state_dict())
+
+ if not args.random_init:
+ missing, unexpected = tgt_model.load_state_dict(
+ src_model.state_dict(), strict=False
+ )
+ print("missing:", missing)
+ print("unexpected:", unexpected)
# Save model and config
tgt_model.save_pretrained(args.output_dir)
diff --git a/dllm/pipelines/a2d/models/gpt2/modeling_gpt2.py b/dllm/pipelines/a2d/models/gpt2/modeling_gpt2.py
deleted file mode 100644
index ffeffb8d..00000000
--- a/dllm/pipelines/a2d/models/gpt2/modeling_gpt2.py
+++ /dev/null
@@ -1,304 +0,0 @@
-from typing import Optional, Union
-
-import torch
-from torch import nn
-
-import transformers
-from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
-from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-from transformers.utils import logging
-from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
-
-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention as GPT2Attention
-
-if transformers.utils.is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE as flex_default_block_size
- from torch.nn.attention.flex_attention import BlockMask, create_block_mask
-else:
- BlockMask = torch.Tensor
-
-logger = logging.get_logger(__name__)
-
-
-class A2DGPT2Config(transformers.GPT2Config):
- model_type = "a2d-gpt2" # <- NEW model_type
-
-
-# >>> A2D modification:
-# Minimal override of GPT2Attention to disable the internal causal mask while
-# keeping all original behavior / comments intact.
-class A2DGPT2Attention(GPT2Attention):
- def __init__(self, config, is_cross_attention=False, layer_idx=None):
- super().__init__(config=config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
-
- # Disable causal behavior for all attention backends (eager/sdpa/flash)
- # This ensures full bidirectional attention as required by A2D.
- self.is_causal = False # <<< key change
-
- # Replace causal lower-triangular mask with an all-True mask
- # so eager/_upcast path will not zero-out future positions.
- if hasattr(self, "bias"):
- full_bias = torch.ones_like(self.bias, dtype=torch.bool)
- self.register_buffer("bias", full_bias, persistent=False)
-
-
-class A2DGPT2Model(transformers.GPT2Model):
-
- def __init__(self, config):
- super().__init__(config)
-
- # >>> A2D modification:
- # Replace original causal GPT2Attention with the non-causal version above.
- for i, block in enumerate(self.h):
- block.attn = A2DGPT2Attention(config, is_cross_attention=False, layer_idx=i)
- if config.add_cross_attention and hasattr(block, "crossattention"):
- block.crossattention = A2DGPT2Attention(config, is_cross_attention=True, layer_idx=i)
-
- def forward(
- self,
- input_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- token_type_ids: Optional[torch.LongTensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- head_mask: Optional[torch.FloatTensor] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **kwargs,
- ) -> Union[tuple, BaseModelOutputWithPastAndCrossAttentions]:
- r"""
- input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
- `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
- `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
- sequence tokens in the vocabulary.
-
- If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
- `input_ids`.
-
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
- [`PreTrainedTokenizer.__call__`] for details.
-
- [What are input IDs?](../glossary#input-ids)
- """
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
- elif input_ids is not None:
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
- input_shape = input_ids.size()
- input_ids = input_ids.view(-1, input_shape[-1])
- batch_size = input_ids.shape[0]
- elif inputs_embeds is not None:
- input_shape = inputs_embeds.size()[:-1]
- batch_size = inputs_embeds.shape[0]
- else:
- raise ValueError("You have to specify either input_ids or inputs_embeds")
-
- device = input_ids.device if input_ids is not None else inputs_embeds.device
-
- if token_type_ids is not None:
- token_type_ids = token_type_ids.view(-1, input_shape[-1])
-
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers."
- )
- use_cache = False
-
- # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
- if use_cache:
- if past_key_values is None:
- past_key_values = DynamicCache(config=self.config)
- elif isinstance(past_key_values, tuple):
- logger.warning_once(
- "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
- "You should pass an instance of `Cache` instead, e.g. "
- "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
- )
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-
- if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
- past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.config))
-
- if inputs_embeds is None:
- inputs_embeds = self.wte(input_ids)
-
- if cache_position is None:
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
- )
- if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
-
- position_embeds = self.wpe(position_ids)
- hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
-
- # Attention mask.
- # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
- if attention_mask is not None and attention_mask.ndim < 4:
- attention_mask = attention_mask.view(batch_size, -1)
-
- # -------------------------------------------------------------
- # ORIGINAL CAUSAL CODE REMOVED BY YOU
- # (kept as comment, no modification)
- # -------------------------------------------------------------
-
- # -------------------------------------------------------------
- # NEW CODE (bidirectional, padding-only mask)
- # (kept exactly as you wrote)
- # -------------------------------------------------------------
- if attention_mask is None:
- attention_mask = torch.ones(
- inputs_embeds.shape[:2],
- device=inputs_embeds.device,
- dtype=torch.long,
- )
-
- if not (
- isinstance(attention_mask, BlockMask)
- or (isinstance(attention_mask, torch.Tensor) and attention_mask.ndim == 4)
- ):
- attention_mask = _prepare_4d_attention_mask(attention_mask, self.dtype)
- # -------------------------------------------------------------
-
- # If a 2D or 3D attention mask is provided for the cross-attention
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
- _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
- if self.config.add_cross_attention and encoder_hidden_states is not None:
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
- if encoder_attention_mask is None:
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
- if _use_sdpa:
- encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
- mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
- )
- elif self._attn_implementation != "flash_attention_2":
- encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
- else:
- encoder_attention_mask = None
-
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
- if token_type_ids is not None:
- token_type_embeds = self.wte(token_type_ids)
- hidden_states = hidden_states + token_type_embeds
-
- hidden_states = self.drop(hidden_states)
-
- output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
-
- all_self_attentions = () if output_attentions else None
- all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
- all_hidden_states = () if output_hidden_states else None
- for i, block in enumerate(self.h):
- # Model parallel
- if self.model_parallel:
- torch.cuda.set_device(hidden_states.device)
- if isinstance(head_mask, torch.Tensor):
- head_mask = head_mask.to(hidden_states.device)
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- outputs = block(
- hidden_states,
- past_key_values if not (self.gradient_checkpointing and self.training) else None,
- cache_position,
- attention_mask, # (unchanged) pass your full-mask 4D mask
- head_mask[i],
- encoder_hidden_states,
- encoder_attention_mask=encoder_attention_mask,
- use_cache=use_cache,
- output_attentions=output_attentions,
- **kwargs,
- )
-
- hidden_states = outputs[0]
-
- if output_attentions:
- all_self_attentions = all_self_attentions + (outputs[1],)
- if self.config.add_cross_attention:
- all_cross_attentions = all_cross_attentions + (outputs[2],)
-
- # Model Parallel: If it's the last layer for that device, put things on the next device
- if self.model_parallel:
- for k, v in self.device_map.items():
- if i == v[-1] and "cuda:" + str(k) != self.last_device:
- hidden_states = hidden_states.to("cuda:" + str(k + 1))
-
- hidden_states = self.ln_f(hidden_states)
-
- hidden_states = hidden_states.view(output_shape)
- # Add last hidden state
- if output_hidden_states:
- all_hidden_states = all_hidden_states + (hidden_states,)
-
- past_key_values = past_key_values if use_cache else None
- if not return_dict:
- return tuple(
- v
- for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
- if v is not None
- )
-
- return BaseModelOutputWithPastAndCrossAttentions(
- last_hidden_state=hidden_states,
- past_key_values=past_key_values,
- hidden_states=all_hidden_states,
- attentions=all_self_attentions,
- cross_attentions=all_cross_attentions,
- )
-
-
-class A2DGPT2LMHeadModel(transformers.GPT2LMHeadModel):
- config: A2DGPT2Config
-
- def __init__(self, config):
- transformers.GPT2PreTrainedModel.__init__(self, config)
- self.transformer = A2DGPT2Model(config)
- self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-
- self.model_parallel = False
- self.device_map = None
- self.post_init()
-
-
-transformers.AutoConfig.register("a2d-gpt2", A2DGPT2Config)
-transformers.AutoModel.register(A2DGPT2Config, A2DGPT2LMHeadModel)
-transformers.AutoModelForMaskedLM.register(A2DGPT2Config, A2DGPT2LMHeadModel)
-
-
-if __name__ == "__main__":
- import dllm
- import torch
- from transformers import AutoModel
-
- # Load a config from a local path (either a directory containing config.json, or the file itself)
- config_path = dllm.utils.resolve_with_base_env(
- "openai-community/gpt2", "BASE_MODELS_DIR"
- )
- config = A2DGPT2Config.from_pretrained(config_path)
- if hasattr(config, "auto_map"):
- delattr(config, "auto_map")
- if hasattr(config, "architectures"):
- delattr(config, "architectures")
-
- torch.set_default_device("cuda")
- model = A2DGPT2LMHeadModel(config)
- model.save_pretrained("models-tmp/a2d-gpt2")
- auto_model = AutoModel.from_pretrained("models-tmp/a2d-gpt2")
diff --git a/dllm/pipelines/dream/eval.py b/dllm/pipelines/dream/eval.py
index 9309aa09..ac1b52ca 100644
--- a/dllm/pipelines/dream/eval.py
+++ b/dllm/pipelines/dream/eval.py
@@ -214,9 +214,9 @@ def generate_until(
# tokenize
prompt_ids = [
- self.tokenizer(
- p, return_tensors="pt", padding=False
- ).input_ids.squeeze().to(self.device)
+ self.tokenizer(p, return_tensors="pt", padding=False)
+ .input_ids.squeeze()
+ .to(self.device)
for p in prompts
]
prompt_lens = [len(p_id) for p_id in prompt_ids]
diff --git a/dllm/pipelines/dream/trainer.py b/dllm/pipelines/dream/trainer.py
index f32e1386..f1c6066b 100644
--- a/dllm/pipelines/dream/trainer.py
+++ b/dllm/pipelines/dream/trainer.py
@@ -1,4 +1,5 @@
from typing import Any
+from dataclasses import dataclass
import torch
@@ -6,21 +7,21 @@
def cart_weight(
- masked_indices: torch.Tensor, t: torch.Tensor, p: float = 0.3
+ masked_mask: torch.Tensor, t: torch.Tensor, p: float = 0.3
) -> torch.Tensor:
"""
Optimized CART weight computation using matrix operations.
Args:
- masked_indices (torch.Tensor): (b, l) bool tensor indicating masked positions.
+ masked_mask (torch.Tensor): (b, l) bool tensor indicating masked positions.
t (torch.Tensor): (b,) time steps (0-1 sampled uniformly). Not directly used in CART.
p (float): Parameter of geometric distribution (0 < p <= 1).
Returns:
torch.Tensor: (b, l) float tensor of weights.
"""
- b, l = masked_indices.shape
- device = masked_indices.device
+ b, l = masked_mask.shape
+ device = masked_mask.device
idx = torch.arange(l, device=device)
dist_matrix = (idx[None, :] - idx[:, None]).abs() - 1
@@ -31,9 +32,9 @@ def cart_weight(
).exp() * 0.5 # Ensure numerical stability
geo_matrix.masked_fill_(dist_matrix == 0, 0.0) # ignore distance = 0
- valid_mask = (~masked_indices).float() # (b, l), 1 = unmasked
+ valid_mask = (~masked_mask).float() # (b, l), 1 = unmasked
weights = valid_mask @ geo_matrix.T # (b, l)
- weights = weights * masked_indices.float()
+ weights = weights * masked_mask.float()
return weights
@@ -42,25 +43,20 @@ class DreamTrainer(MDLMTrainer):
DreamTrainer: specialization of MDLMTrainer for Dream training.
"""
- def __init__(
- self,
- loss_weight_type: str = "cart[geo_p:0.3]",
- *args,
- **kwargs,
- ):
- super().__init__(
- loss_weight_type=loss_weight_type,
- *args,
- **kwargs,
- )
+ @dataclass
+ class DreamConfig(MDLMTrainer.MDLMConfig):
+ loss_weight_type: str = "cart[geo_p:0.3]"
+ right_shift_logits: bool = True
- self.right_shift_logits = True
+ def __post_init__(self):
+ super().__post_init__()
+ assert self.right_shift_logits
def _compute_loss_weights(
self,
t: torch.Tensor,
inputs: dict[str, Any],
- masked_indices: torch.Tensor,
+ masked_mask: torch.Tensor,
*args,
**kwargs,
) -> torch.Tensor:
@@ -70,12 +66,12 @@ def _compute_loss_weights(
match = re.search(r"geo_p:(0\.\d+)", self.loss_weight_type)
geo_p = float(match.group(1)) if match else 0.3
- loss_weights = cart_weight(masked_indices, t, p=geo_p)
+ loss_weights = cart_weight(masked_mask, t, p=geo_p)
else:
loss_weights = super()._compute_loss_weights(
t=t,
inputs=inputs,
- masked_indices=masked_indices,
+ masked_mask=masked_mask,
*args,
**kwargs,
)
diff --git a/dllm/pipelines/editflow/trainer.py b/dllm/pipelines/editflow/trainer.py
index 2635539f..285da50f 100644
--- a/dllm/pipelines/editflow/trainer.py
+++ b/dllm/pipelines/editflow/trainer.py
@@ -8,6 +8,7 @@
from dllm.core.schedulers import BaseKappaScheduler, CubicKappaScheduler
from dllm.pipelines.editflow.utils import pad_1d
+from dllm.utils.configs import TrainingArguments
BLANK = -1
@@ -211,20 +212,25 @@ class EditFlowTrainer(transformers.Trainer):
True intensities are w * rate_hat, with w = kappa_dot(t) / (1 - kappa(t)).
"""
+ @dataclass
+ class EditFlowConfig(TrainingArguments):
+ time_epsilon: float = 1e-3
+ normalize_per_position: bool = True
+ max_w: float = 20.0
+
def __init__(
self,
- *args,
+ args: EditFlowConfig,
scheduler: BaseKappaScheduler | None = None,
- normalize_per_position: bool = True,
- time_epsilon: float = 1e-3,
- max_w: float | None = None,
+ *pargs,
**kwargs,
):
- self.scheduler = scheduler or CubicKappaScheduler()
- self.normalize_per_position = normalize_per_position
- self.time_epsilon = time_epsilon
- self.max_w = max_w
- super().__init__(*args, **kwargs)
+ super().__init__(args=args, *pargs, **kwargs)
+
+ self.scheduler = scheduler if scheduler is not None else CubicKappaScheduler()
+ self.time_epsilon = args.time_epsilon
+ self.normalize_per_position = args.normalize_per_position
+ self.max_w = args.max_w
def compute_loss(
self,
diff --git a/dllm/pipelines/llada/eval.py b/dllm/pipelines/llada/eval.py
index 9cb1fe4a..cf46007b 100644
--- a/dllm/pipelines/llada/eval.py
+++ b/dllm/pipelines/llada/eval.py
@@ -364,9 +364,7 @@ def generate_until(self, requests: list[Instance]) -> list[str]:
for instance in tqdm(requests, desc="Generating..."):
context, gen_kwargs = instance.args # type: ignore
prompt_ids = self.tokenizer(context)["input_ids"]
- prompt = [
- torch.tensor(prompt_ids, device=self.device, dtype=torch.long)
- ]
+ prompt = [torch.tensor(prompt_ids, device=self.device, dtype=torch.long)]
stop_tokens = gen_kwargs["until"]
generated_ids = sampler.sample(
inputs=prompt,
diff --git a/dllm/tools/preprocess_sft_dataset.py b/dllm/tools/preprocess_sft_dataset.py
index b0b65d20..be8ffdfb 100644
--- a/dllm/tools/preprocess_sft_dataset.py
+++ b/dllm/tools/preprocess_sft_dataset.py
@@ -25,7 +25,7 @@ class ScriptArguments:
"""Preprocess SFT dataset."""
model_name_or_path: str = "GSAI-ML/LLaDA-8B-Base"
- sft_map_fn_path: str = "dllm.utils.default_mdlm_sft_map_fn"
+ sft_map_fn_path: str = "dllm.utils.default_sft_map_fn"
dataset_args: str = "HuggingFaceTB/smoltalk" # required
output_dir: str = "data/sft/llada/smoltalk" # required
mask_prompt_loss: bool = True # Mask prompt tokens in labels with -100
diff --git a/dllm/utils/configs.py b/dllm/utils/configs.py
index f8e84359..b9e7916c 100644
--- a/dllm/utils/configs.py
+++ b/dllm/utils/configs.py
@@ -70,6 +70,7 @@ class TrainingArguments(transformers.TrainingArguments):
def __post_init__(self):
super().__post_init__()
+ self.run_name = self.run_name or self.output_dir
if self.group_by_length:
logger.info(
"training_args.group_by_length=True: preprocessing "
diff --git a/dllm/utils/data.py b/dllm/utils/data.py
index 39d7794f..d69d6b7d 100644
--- a/dllm/utils/data.py
+++ b/dllm/utils/data.py
@@ -222,7 +222,7 @@ def clip_right(row):
raise NotImplementedError
-def default_mdlm_sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict:
+def default_sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict:
"""
Build input_ids and labels for SFT.
diff --git a/dllm/utils/models.py b/dllm/utils/models.py
index 4878ad32..77460f0d 100644
--- a/dllm/utils/models.py
+++ b/dllm/utils/models.py
@@ -140,11 +140,7 @@ def get_tokenizer(model_args) -> transformers.PreTrainedTokenizer:
{% endif %}
"""
- elif issubclass(model_cls, LLaDAMoEModelLM):
- tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
- tokenizer.eot_token = "<|role_end|>"
- tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
- elif issubclass(model_cls, LLaDA2MoeModelLM):
+ elif issubclass(model_cls, (LLaDAMoEModelLM, LLaDA2MoeModelLM)):
tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
tokenizer.eot_token = "<|role_end|>"
tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
@@ -187,15 +183,10 @@ def get_tokenizer(model_args) -> transformers.PreTrainedTokenizer:
tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
tokenizer.eot_token = "<|eot_id|>"
tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
- elif issubclass(model_cls, A2DQwen2LMHeadModel):
- tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
- tokenizer.eot_token = "<|im_end|>"
- tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
- elif issubclass(model_cls, A2DQwen3LMHeadModel):
+ elif issubclass(model_cls, (A2DQwen2LMHeadModel, A2DQwen3LMHeadModel)):
tokenizer.add_special_tokens({"mask_token": "<|mask|>"})
tokenizer.eot_token = "<|im_end|>"
tokenizer.eot_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eot_token)
- tokenizer.chat_template = "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in content %}\n {%- set reasoning_content = content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- set content = content.split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n 
{%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\n' }}\n {{- '\n\n\n\n' }}\n{%- endif %}"
else:
print_main("no tokenizer customization for model class:", model_cls)
return tokenizer
diff --git a/examples/a2d/bd3lm/pt.py b/examples/a2d/bd3lm/pt.py
index bc5dbe81..7866a79c 100644
--- a/examples/a2d/bd3lm/pt.py
+++ b/examples/a2d/bd3lm/pt.py
@@ -60,15 +60,14 @@ class DataArguments(dllm.utils.DataArguments):
@dataclass
-class TrainingArguments(dllm.utils.TrainingArguments):
+class TrainingArguments(dllm.core.trainers.BD3LMTrainer.BD3LMConfig):
output_dir: str = "models/a2d/Qwen3-0.6B/bd3lm/tiny-shakespeare"
num_train_epochs: int = 20
learning_rate: float = 1e-4
per_device_train_batch_size: int = 16
per_device_eval_batch_size: int = 16
- # a2d-specific
+ # bd3lm
block_size: int = 32
- right_shift_logits: bool = False
def train():
@@ -123,8 +122,6 @@ def train():
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
- block_size=training_args.block_size,
- right_shift_logits=training_args.right_shift_logits,
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
diff --git a/examples/a2d/bd3lm/sft.py b/examples/a2d/bd3lm/sft.py
index b97b329d..63e80053 100644
--- a/examples/a2d/bd3lm/sft.py
+++ b/examples/a2d/bd3lm/sft.py
@@ -55,16 +55,15 @@ class DataArguments(dllm.utils.DataArguments):
@dataclass
-class TrainingArguments(dllm.utils.TrainingArguments):
+class TrainingArguments(dllm.core.trainers.BD3LMTrainer.BD3LMConfig):
output_dir: str = "models/a2d/Qwen3-0.6B/mdlm/alpaca"
group_by_length: bool = True
num_train_epochs: int = 20
learning_rate: float = 1e-4
per_device_train_batch_size: int = 16
per_device_eval_batch_size: int = 16
- # a2d-specific
+ # bd3lm
block_size: int = 32
- right_shift_logits: bool = False
def train():
@@ -89,7 +88,7 @@ def train():
)
if not data_args.load_preprocessed_data:
map_fn = partial(
- dllm.utils.default_mdlm_sft_map_fn,
+ dllm.utils.default_sft_map_fn,
tokenizer=tokenizer,
mask_prompt_loss=data_args.mask_prompt_loss,
)
@@ -110,8 +109,6 @@ def train():
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
- block_size=training_args.block_size,
- right_shift_logits=training_args.right_shift_logits,
data_collator=(
dllm.core.trainers.bd3lm.AppendEOSBlockWrapper(
transformers.DataCollatorForSeq2Seq(
diff --git a/examples/a2d/mdlm/pt.py b/examples/a2d/mdlm/pt.py
index 4554ef83..43d86f2c 100644
--- a/examples/a2d/mdlm/pt.py
+++ b/examples/a2d/mdlm/pt.py
@@ -60,14 +60,12 @@ class DataArguments(dllm.utils.DataArguments):
@dataclass
-class TrainingArguments(dllm.utils.TrainingArguments):
+class TrainingArguments(dllm.core.trainers.MDLMTrainer.MDLMConfig):
output_dir: str = "models/a2d/Qwen3-0.6B/mdlm/tiny-shakespeare"
num_train_epochs: int = 20
learning_rate: float = 1e-4
per_device_train_batch_size: int = 16
per_device_eval_batch_size: int = 16
- # a2d-specific
- right_shift_logits: bool = False
def train():
@@ -122,7 +120,6 @@ def train():
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
- right_shift_logits=training_args.right_shift_logits,
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
diff --git a/examples/a2d/mdlm/sft.py b/examples/a2d/mdlm/sft.py
index 69c7582c..f74ef33e 100644
--- a/examples/a2d/mdlm/sft.py
+++ b/examples/a2d/mdlm/sft.py
@@ -55,15 +55,13 @@ class DataArguments(dllm.utils.DataArguments):
@dataclass
-class TrainingArguments(dllm.utils.TrainingArguments):
+class TrainingArguments(dllm.core.trainers.MDLMTrainer.MDLMConfig):
output_dir: str = "models/a2d/Qwen3-0.6B/mdlm/alpaca"
group_by_length: bool = True
learning_rate: float = 1e-4
num_train_epochs: int = 20
per_device_train_batch_size: int = 16
per_device_eval_batch_size: int = 16
- # a2d-specific
- right_shift_logits: bool = False
def train():
@@ -88,7 +86,7 @@ def train():
)
if not data_args.load_preprocessed_data:
map_fn = partial(
- dllm.utils.default_mdlm_sft_map_fn,
+ dllm.utils.default_sft_map_fn,
tokenizer=tokenizer,
mask_prompt_loss=data_args.mask_prompt_loss,
)
@@ -109,7 +107,6 @@ def train():
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
- right_shift_logits=training_args.right_shift_logits,
data_collator=(
dllm.utils.NoAttentionMaskWrapper( # padded should be visible
transformers.DataCollatorForSeq2Seq(
diff --git a/examples/bert/pt.py b/examples/bert/pt.py
index 302d6b34..39cb906b 100644
--- a/examples/bert/pt.py
+++ b/examples/bert/pt.py
@@ -60,7 +60,7 @@ class DataArguments(dllm.utils.DataArguments):
@dataclass
-class TrainingArguments(dllm.utils.TrainingArguments):
+class TrainingArguments(dllm.core.trainers.MDLMTrainer.MDLMConfig):
output_dir: str = "models/ModernBERT-base/tiny-shakespeare"
num_train_epochs: int = 20
learning_rate: float = 1e-4
diff --git a/examples/bert/sft.py b/examples/bert/sft.py
index 038b9c44..878cfca2 100644
--- a/examples/bert/sft.py
+++ b/examples/bert/sft.py
@@ -55,7 +55,7 @@ class DataArguments(dllm.utils.DataArguments):
@dataclass
-class TrainingArguments(dllm.utils.TrainingArguments):
+class TrainingArguments(dllm.core.trainers.MDLMTrainer.MDLMConfig):
output_dir: str = "models/ModernBERT-large/alpaca"
group_by_length: bool = True
num_train_epochs: int = 20
@@ -86,7 +86,7 @@ def train():
)
if not data_args.load_preprocessed_data:
map_fn = partial(
- dllm.utils.default_mdlm_sft_map_fn,
+ dllm.utils.default_sft_map_fn,
tokenizer=tokenizer,
mask_prompt_loss=data_args.mask_prompt_loss,
)
diff --git a/examples/dream/README.md b/examples/dream/README.md
index ce406ebb..7bcb34d3 100644
--- a/examples/dream/README.md
+++ b/examples/dream/README.md
@@ -88,7 +88,7 @@ We tried our best to reproduce [`Dream-v0-Instruct-7B`](https://huggingface.co/D
# Preprocessing SFT data (optional, but can avoid redundant preprocessing for multi-node training)
python dllm/tools/preprocess_sft_dataset.py \
--model_name_or_path "Dream-org/Dream-v0-Base-7B" \
- --sft_map_fn_path "dllm.utils.default_mdlm_sft_map_fn" \
+ --sft_map_fn_path "dllm.utils.default_sft_map_fn" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "data/sft/dream/tulu-3-sft-mixture" \
--num_proc 64
diff --git a/examples/dream/pt.py b/examples/dream/pt.py
index 9a44a33e..5247213c 100644
--- a/examples/dream/pt.py
+++ b/examples/dream/pt.py
@@ -65,7 +65,7 @@ class DataArguments(dllm.utils.DataArguments):
@dataclass
-class TrainingArguments(dllm.utils.TrainingArguments):
+class TrainingArguments(dllm.pipelines.dream.DreamTrainer.DreamConfig):
output_dir: str = (
"models/Dream-v0-Base-7B/dclm-baseline-1.0[train:10_000_000,test:10_000]"
)
@@ -142,7 +142,6 @@ def train():
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
- loss_weight_type=training_args.loss_weight_type,
data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
diff --git a/examples/dream/sft.py b/examples/dream/sft.py
index a48aa9e9..98a44eb1 100644
--- a/examples/dream/sft.py
+++ b/examples/dream/sft.py
@@ -75,7 +75,7 @@ class DataArguments(dllm.utils.DataArguments):
@dataclass
-class TrainingArguments(dllm.utils.TrainingArguments):
+class TrainingArguments(dllm.pipelines.dream.DreamTrainer.DreamConfig):
output_dir: str = (
"models/Dream-v0-Base-7B/tulu-3-sft-mixture[train:10000,test:1000]"
)
@@ -120,7 +120,7 @@ def train():
)
if not data_args.load_preprocessed_data:
map_fn = partial(
- dllm.utils.default_mdlm_sft_map_fn,
+ dllm.utils.default_sft_map_fn,
tokenizer=tokenizer,
mask_prompt_loss=data_args.mask_prompt_loss,
)
@@ -141,7 +141,6 @@ def train():
train_dataset=dataset["train"],
eval_dataset=dataset.get("test", None),
args=training_args,
- loss_weight_type=training_args.loss_weight_type,
data_collator=dream.utils.DreamSFTCollator(
tokenizer,
return_tensors="pt",
diff --git a/examples/editflow/pt.py b/examples/editflow/pt.py
index 0cbb463a..a6b0927c 100644
--- a/examples/editflow/pt.py
+++ b/examples/editflow/pt.py
@@ -3,6 +3,7 @@
from dataclasses import dataclass, field
import accelerate
+import transformers
import dllm
from dllm.pipelines import editflow
@@ -31,7 +32,7 @@ class DataArguments(dllm.utils.DataArguments):
@dataclass
-class TrainingArguments(dllm.utils.TrainingArguments):
+class TrainingArguments(editflow.EditFlowTrainer.EditFlowConfig):
output_dir: str = None # overwrite this
num_train_epochs: int = 10
learning_rate: float = 1e-4
@@ -43,18 +44,10 @@ class TrainingArguments(dllm.utils.TrainingArguments):
metadata={
"help": (
"The scheduler class controlling κ(t). "
- "Available options: see `dllm/utils/schedulers/kappa.py`"
+ "Available options: see `dllm/core/schedulers/kappa.py`"
)
},
)
- normalize_per_position: bool = field(
- default=True,
- metadata={"help": "Whether to normalize the loss per position."},
- )
- max_w: float = field(
- default=20.0,
- metadata={"help": "The maximum weight (κ'(t) / (1 - κ(t))) for the loss."},
- )
x0_sampler: str = field(
default="masks[length:128]",
metadata={
@@ -66,11 +59,12 @@ class TrainingArguments(dllm.utils.TrainingArguments):
)
-def train(
- model_args: ModelArguments,
- data_args: DataArguments,
- training_args: TrainingArguments,
-):
+def train():
+ # ----- Argument parsing -------------------------------------------------------
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments)
+ )
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# necessary when batch does not contain "labels" field
training_args.label_names = []
# necessary when batch contains customized fields
@@ -129,11 +123,13 @@ def _no_flops(*args, **kwargs):
scheduler=dllm.core.schedulers.make_kappa_scheduler(
training_args.scheduler_cls
),
- normalize_per_position=training_args.normalize_per_position,
- max_w=training_args.max_w,
)
trainer.train()
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
trainer.processing_class.save_pretrained(
os.path.join(training_args.output_dir, "checkpoint-final")
)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/examples/editflow/sft.py b/examples/editflow/sft.py
index 1daff739..c7ec1cbf 100644
--- a/examples/editflow/sft.py
+++ b/examples/editflow/sft.py
@@ -3,6 +3,7 @@
from functools import partial
import accelerate
+import transformers
import dllm
from dllm.pipelines import editflow
@@ -26,7 +27,7 @@ class DataArguments(dllm.utils.DataArguments):
@dataclass
-class TrainingArguments(dllm.utils.TrainingArguments):
+class TrainingArguments(editflow.EditFlowTrainer.EditFlowConfig):
output_dir: str = None # overwrite this
num_train_epochs: float = 10
learning_rate: float = 1e-4
@@ -38,18 +39,10 @@ class TrainingArguments(dllm.utils.TrainingArguments):
metadata={
"help": (
"The scheduler class controlling κ(t). "
- "Available options: see `dllm/utils/schedulers/kappa.py`"
+ "Available options: see `dllm/core/schedulers/kappa.py`"
)
},
)
- normalize_per_position: bool = field(
- default=True,
- metadata={"help": "Whether to normalize the loss per position."},
- )
- max_w: float = field(
- default=20.0,
- metadata={"help": "The maximum weight (κ'(t) / (1 - κ(t))) for the loss."},
- )
x0_sampler: str = field(
default="masks[length:64]",
metadata={
@@ -88,11 +81,12 @@ def sft_map_fn(row, *, tokenizer, mask_prompt_loss: bool = True) -> dict:
return {"input_ids": prompt_response_tokens}
-def train(
- model_args: ModelArguments,
- data_args: DataArguments,
- training_args: TrainingArguments,
-):
+def train():
+ # ----- Argument parsing -------------------------------------------------------
+ parser = transformers.HfArgumentParser(
+ (ModelArguments, DataArguments, TrainingArguments)
+ )
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# necessary when batch does not contain "labels" field
training_args.label_names = []
# necessary when batch contains customized fields
@@ -146,11 +140,13 @@ def _no_flops(*args, **kwargs):
scheduler=dllm.core.schedulers.make_kappa_scheduler(
training_args.scheduler_cls
),
- normalize_per_position=training_args.normalize_per_position,
- max_w=training_args.max_w,
)
trainer.train()
trainer.save_model(os.path.join(training_args.output_dir, "checkpoint-final"))
trainer.processing_class.save_pretrained(
os.path.join(training_args.output_dir, "checkpoint-final")
)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/examples/llada/README.md b/examples/llada/README.md
index d8626d31..494fa22f 100644
--- a/examples/llada/README.md
+++ b/examples/llada/README.md
@@ -109,7 +109,7 @@ Though LLaDA is trained on proprietary data, we tried our best to reproduce [`LL
# Preprocessing SFT data (optional, but can avoid redundant preprocessing for multi-node training)
python dllm/tools/preprocess_sft_dataset.py \
--model_name_or_path "GSAI-ML/LLaDA-8B-Base" \
- --sft_map_fn_path "dllm.utils.default_mdlm_sft_map_fn" \
+ --sft_map_fn_path "dllm.utils.default_sft_map_fn" \
--dataset_args "allenai/tulu-3-sft-mixture" \
--output_dir "data/sft/llada/tulu-3-sft-mixture" \
--num_proc 64
diff --git a/examples/llada/pt.py b/examples/llada/pt.py
index 246f7097..06243d55 100644
--- a/examples/llada/pt.py
+++ b/examples/llada/pt.py
@@ -67,7 +67,7 @@ class DataArguments(dllm.utils.DataArguments):
@dataclass
-class TrainingArguments(dllm.utils.TrainingArguments):
+class TrainingArguments(dllm.core.trainers.MDLMTrainer.MDLMConfig):
output_dir: str = (
"models/LLaDA-8B-Base/dclm-baseline-1.0[train:10_000_000,test:10_000]"
)
diff --git a/examples/llada/sft.py b/examples/llada/sft.py
index e0830e5f..b3a9f1e0 100644
--- a/examples/llada/sft.py
+++ b/examples/llada/sft.py
@@ -55,7 +55,7 @@ class DataArguments(dllm.utils.DataArguments):
@dataclass
-class TrainingArguments(dllm.utils.TrainingArguments):
+class TrainingArguments(dllm.core.trainers.MDLMTrainer.MDLMConfig):
output_dir: str = "models/LLaDA-8B-Base/tulu-3-sft-mixture[train:10000,test:1000]"
group_by_length: bool = True
num_train_epochs: float = 5
@@ -86,7 +86,7 @@ def train():
)
if not data_args.load_preprocessed_data:
map_fn = partial(
- dllm.utils.default_mdlm_sft_map_fn,
+ dllm.utils.default_sft_map_fn,
tokenizer=tokenizer,
mask_prompt_loss=data_args.mask_prompt_loss,
)
diff --git a/pyproject.toml b/pyproject.toml
index 9f134248..288d2284 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,6 +18,7 @@ dependencies = [
"peft==0.17.1",
"datasets==4.2.0",
"sentencepiece==0.2.0",
+ "torchmetrics",
"tyro",
"wandb",
"omegaconf",
@@ -33,6 +34,10 @@ optional = [
"vllm==0.8.5.post1",
"flash-attn==2.8.3",
]
+rl = [
+ "trl==0.26.0",
+ "math_verify==0.8.0",
+]
[tool.black]
line-length = 88
diff --git a/scripts/tests/test_attention.py b/scripts/tests/test_attention.py
index 4a690426..e1819ca1 100644
--- a/scripts/tests/test_attention.py
+++ b/scripts/tests/test_attention.py
@@ -662,7 +662,7 @@ def test_bd3lm_staircase_attention_kvcache_equivalence(
(A) in one full 8-token forward pass
(B) in two incremental passes (4 tokens → KV cache → 4 tokens)
"""
- from dllm.core.samplers.bd3lm import build_staircase_attention_mask
+ from dllm.core.samplers.bd3lm import _prepare_for_sampling
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -699,7 +699,7 @@ def test_bd3lm_staircase_attention_kvcache_equivalence(
# ------------------------------
# 3. Build staircase mask + positions for the full sequence
# ------------------------------
- attn_full, pos_full = build_staircase_attention_mask(
+ attn_full, pos_full = _prepare_for_sampling(
x_full, block_size=block_size, pad_token_id=pad_token_id
)
# attn_full: [1, 1, 8, 8]
@@ -808,7 +808,7 @@ def test_bd3lm_concat_equivalence_when_noised_equals_input(
NOTE: We set block_size == seq_len so x_t tokens attend only within x_t (single block),
making the first-half computation equivalent to a standard full-attention forward.
"""
- from dllm.core.trainers.bd3lm import block_diff_mask
+ from dllm.core.trainers.bd3lm import _bd3lm_attention_mask
torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -858,7 +858,7 @@ def test_bd3lm_concat_equivalence_when_noised_equals_input(
pos_cat = torch.cat([pos, pos], dim=1) # [1, 2L]
L2 = 2 * seq_len
- attn_bd = block_diff_mask(
+ attn_bd = _bd3lm_attention_mask(
b=None,
h=None,
q_idx=torch.arange(L2, device=device)[:, None],